From cdc4abfb36cc309581181087267daf44f29783b1 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 24 Jul 2025 13:56:33 +0300 Subject: [PATCH] Add comprehensive hooks system for MCPHost lifecycle events (#111) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add comprehensive hooks system for MCPHost lifecycle events Implements a flexible hooks system based on Anthropic Claude Code specification: - **Hook Events**: PreToolUse, PostToolUse, UserPromptSubmit, Stop - **Hook Types**: Command execution with JSON input/output - **Configuration**: XDG-compliant with layered config support - **Security**: Command validation, timeout controls, safe execution - **Common Fields**: Consistent session ID, timestamps, model info across all hooks Key features: - Hooks receive JSON via stdin and can control flow via stdout - Pattern matching for tool-specific hooks (regex support) - Enhanced Stop hook with agent response and metadata - Centralized session management with consistent IDs - Built-in examples for logging, validation, and monitoring This enables users to: - Log and audit all tool usage and prompts - Implement custom security policies - Track usage metrics and model performance - Integrate with external systems - Build custom workflows around MCPHost 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode * Enable hooks in script mode Previously, hooks were only initialized and executed in normal mode but not in script mode. This was because script mode had its own execution path that bypassed the hook initialization code. This fix: - Adds hook initialization to runScriptMode function - Creates hook executor with proper session ID and model info - Passes the hook executor to runAgenticLoop Now hooks work consistently across all execution modes (normal, script, and interactive), ensuring uniform behavior for logging, validation, and monitoring. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode * Remove unnecessary hooks.local.yml pattern The .local.yml pattern adds unnecessary complexity. Users who want project-specific hooks that aren't committed to git can simply add .mcphost/ to their .gitignore. This simplifies the hooks configuration loading and makes it clearer that: - Global user hooks go in ~/.config/mcphost/hooks.yml - Project-specific hooks go in .mcphost/hooks.yml - Git ignore management is left to the user 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode * Fix hooks test isolation and add --no-hooks flag - Fix TestLoadHooksConfig by setting temporary XDG_CONFIG_HOME to prevent loading global hooks - Add --no-hooks flag to disable all hooks execution across all modes - Update README with documentation for the new flag - Add test to verify hooks loading behavior This allows users to temporarily disable hooks for security or debugging purposes. 🤖 Generated with [opencode](https://opencode.ai) Co-Authored-By: opencode --------- Co-authored-by: opencode --- README.md | 66 ++++++ cmd/hooks.go | 169 ++++++++++++++ cmd/root.go | 210 ++++++++++++++++-- cmd/script.go | 19 +- examples/hooks/bash-validator.py | 57 +++++ examples/hooks/mcp-monitor.py | 72 ++++++ examples/hooks/prompt-logger.sh | 23 ++ internal/hooks/config.go | 132 +++++++++++ internal/hooks/config_test.go | 224 +++++++++++++++++++ internal/hooks/events.go | 32 +++ internal/hooks/executor.go | 254 ++++++++++++++++++++++ internal/hooks/executor_test.go | 193 ++++++++++++++++ internal/hooks/schemas.go | 55 +++++ internal/hooks/testdata/invalid-hooks.yml | 10 + internal/hooks/testdata/valid-hooks.yml | 21 ++ internal/hooks/validator.go | 130 +++++++++++ internal/hooks/validator_test.go | 251 +++++++++++++++++++++ 17 files changed, 1897 insertions(+), 21 deletions(-) create mode 100644 cmd/hooks.go create mode 100755 examples/hooks/bash-validator.py create mode 100755 examples/hooks/mcp-monitor.py create mode 100755 examples/hooks/prompt-logger.sh create mode 100644 internal/hooks/config.go create mode 100644 internal/hooks/config_test.go create mode 100644 internal/hooks/events.go create mode 100644 internal/hooks/executor.go create mode 100644 internal/hooks/executor_test.go create mode 100644 internal/hooks/schemas.go create mode 100644 internal/hooks/testdata/invalid-hooks.yml create mode 100644 internal/hooks/testdata/valid-hooks.yml create mode 100644 internal/hooks/validator.go create mode 100644 internal/hooks/validator_test.go diff --git a/README.md b/README.md index f5580426..97e2f342 100644 --- a/README.md +++ b/README.md @@ -558,6 +558,72 @@ See `examples/scripts/` for sample scripts: - `example-script.sh` - Script with custom MCP servers - `simple-script.sh` - Script using default config fallback +### Hooks System + +MCPHost supports a powerful hooks system that allows you to execute custom commands at specific points during execution. This enables security policies, logging, custom integrations, and automated workflows. + +#### Quick Start + +1. Initialize a hooks configuration: + ```bash + mcphost hooks init + ``` + +2. View active hooks: + ```bash + mcphost hooks list + ``` + +3. Validate your configuration: + ```bash + mcphost hooks validate + ``` + +#### Configuration + +Hooks are configured in YAML files with the following precedence (highest to lowest): +- `.mcphost/hooks.yml` (project-specific hooks) +- `$XDG_CONFIG_HOME/mcphost/hooks.yml` (user global hooks, defaults to `~/.config/mcphost/hooks.yml`) + +Example configuration: +```yaml +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "/usr/local/bin/validate-bash.py" + timeout: 5 + + UserPromptSubmit: + - hooks: + - type: command + command: "~/.mcphost/hooks/log-prompt.sh" +``` + +#### Available Hook Events + +- **PreToolUse**: Before any tool execution (bash, fetch, todo, MCP tools) +- **PostToolUse**: After tool execution completes +- **UserPromptSubmit**: When user submits a prompt +- **Stop**: When the agent finishes responding +- **SubagentStop**: When a subagent (Task tool) finishes +- **Notification**: When MCPHost sends notifications + +#### Security + +⚠️ **WARNING**: Hooks execute arbitrary commands on your system. Only use hooks from trusted sources and always review hook commands before enabling them. + +To temporarily disable all hooks, use the `--no-hooks` flag: +```bash +mcphost --no-hooks +``` + +See the example hook scripts in `examples/hooks/`: +- `bash-validator.py` - Validates and blocks dangerous bash commands +- `prompt-logger.sh` - Logs all user prompts with timestamps +- `mcp-monitor.py` - Monitors and enforces policies on MCP tool usage + ### Non-Interactive Mode Run a single prompt and exit - perfect for scripting and automation: diff --git a/cmd/hooks.go b/cmd/hooks.go new file mode 100644 index 00000000..3e5dad19 --- /dev/null +++ b/cmd/hooks.go @@ -0,0 +1,169 @@ +package cmd + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/mark3labs/mcphost/internal/hooks" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +var hooksCmd = &cobra.Command{ + Use: "hooks", + Short: "Manage MCPHost hooks", + Long: "Commands for managing and testing MCPHost hooks configuration", +} + +var hooksListCmd = &cobra.Command{ + Use: "list", + Short: "List all configured hooks", + RunE: func(cmd *cobra.Command, args []string) error { + config, err := hooks.LoadHooksConfig() + if err != nil { + return fmt.Errorf("loading hooks config: %w", err) + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "EVENT\tMATCHER\tCOMMAND\tTIMEOUT") + + for event, matchers := range config.Hooks { + for _, matcher := range matchers { + for _, hook := range matcher.Hooks { + timeout := "60s" + if hook.Timeout > 0 { + timeout = fmt.Sprintf("%ds", hook.Timeout) + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", + event, matcher.Matcher, hook.Command, timeout) + } + } + } + + return w.Flush() + }, +} + +var hooksValidateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate hooks configuration", + RunE: func(cmd *cobra.Command, args []string) error { + config, err := hooks.LoadHooksConfig() + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Additional validation + if err := hooks.ValidateHookConfig(config); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + fmt.Println("✓ Hooks configuration is valid") + return nil + }, +} + +var hooksInitCmd = &cobra.Command{ + Use: "init", + Short: "Generate example hooks configuration", + RunE: func(cmd *cobra.Command, args []string) error { + example := &hooks.HookConfig{ + Hooks: map[hooks.HookEvent][]hooks.HookMatcher{ + // PreToolUse - runs before any tool execution + hooks.PreToolUse: { + { + Matcher: "bash.*", + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] $ " + .tool_input.command' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-commands.log"`, + Timeout: 5, + }, + }, + }, + { + Matcher: ".*", // Log all tool usage + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), event: "pre", tool: .tool_name, input: .tool_input}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/all-tools.jsonl"`, + Timeout: 5, + }, + }, + }, + }, + // PostToolUse - runs after tool execution completes + hooks.PostToolUse: { + { + Matcher: "bash.*", + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), cmd: .tool_input.command, exit: .tool_response._meta.exit, stdout: (.tool_response._meta.stdout | rtrimstr("\n") | .[0:100]), stderr: (.tool_response._meta.stderr | rtrimstr("\n"))}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-audit.jsonl"`, + Timeout: 5, + }, + }, + }, + { + Matcher: "mcp__.*", // Log MCP tool responses + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), tool: .tool_name, response_preview: (.tool_response | tostring | .[0:200])}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/mcp-tools.jsonl"`, + Timeout: 5, + }, + }, + }, + }, + // UserPromptSubmit - runs when user submits a prompt + hooks.UserPromptSubmit: { + { + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] " + .prompt' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/prompts.log"`, + }, + }, + }, + }, + // Stop - runs when the main agent finishes responding + hooks.Stop: { + { + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] Session " + .session_id + " stopped"' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/sessions.log"`, + }, + }, + }, + }, + }, + } + + // Create .mcphost directory if it doesn't exist + if err := os.MkdirAll(".mcphost", 0755); err != nil { + return fmt.Errorf("creating .mcphost directory: %w", err) + } + + // Write example configuration + data, err := yaml.Marshal(example) + if err != nil { + return fmt.Errorf("marshaling example: %w", err) + } + + if err := os.WriteFile(".mcphost/hooks.yml", data, 0644); err != nil { + return fmt.Errorf("writing example: %w", err) + } + + fmt.Println("Created .mcphost/hooks.yml with example configuration") + return nil + }, +} + +func init() { + rootCmd.AddCommand(hooksCmd) + hooksCmd.AddCommand(hooksListCmd) + hooksCmd.AddCommand(hooksValidateCmd) + hooksCmd.AddCommand(hooksInitCmd) +} diff --git a/cmd/root.go b/cmd/root.go index bc55131c..8bf8426d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,10 +9,12 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/cloudwego/eino/schema" "github.com/mark3labs/mcphost/internal/agent" "github.com/mark3labs/mcphost/internal/config" + "github.com/mark3labs/mcphost/internal/hooks" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/session" "github.com/mark3labs/mcphost/internal/tokens" @@ -50,6 +52,9 @@ var ( // Ollama-specific parameters numGPU int32 mainGPU int32 + + // Hooks control + noHooks bool ) // agentUIAdapter adapts agent.Agent to ui.AgentInterface @@ -177,6 +182,19 @@ func initConfig() { // Set environment variable prefix viper.SetEnvPrefix("MCPHOST") viper.AutomaticEnv() + + // Load hooks configuration unless disabled + if !viper.GetBool("no-hooks") { + hooksConfig, err := hooks.LoadHooksConfig() + if err != nil { + // Hooks are optional, so just log a warning + if debugMode { + fmt.Fprintf(os.Stderr, "Warning: Failed to load hooks configuration: %v\n", err) + } + } else { + viper.Set("hooks", hooksConfig) + } + } } // loadConfigWithEnvSubstitution loads a config file with environment variable substitution @@ -230,6 +248,8 @@ func init() { BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display") rootCmd.PersistentFlags(). BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling") + rootCmd.PersistentFlags(). + BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution") // Session management flags rootCmd.PersistentFlags(). @@ -261,6 +281,7 @@ func init() { viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps")) viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream")) viper.BindPFlag("compact", rootCmd.PersistentFlags().Lookup("compact")) + viper.BindPFlag("no-hooks", rootCmd.PersistentFlags().Lookup("no-hooks")) viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url")) viper.BindPFlag("provider-api-key", rootCmd.PersistentFlags().Lookup("provider-api-key")) viper.BindPFlag("max-tokens", rootCmd.PersistentFlags().Lookup("max-tokens")) @@ -374,6 +395,7 @@ func runNormalMode(ctx context.Context) error { } defer mcpAgent.Close() + // Initialize hook executor if hooks are configured // Get model name for display modelString := viper.GetString("model") parts := strings.SplitN(modelString, ":", 2) @@ -382,6 +404,20 @@ func runNormalMode(ctx context.Context) error { modelName = parts[1] } + var hookExecutor *hooks.Executor + if hooksConfig := viper.Get("hooks"); hooksConfig != nil { + if hc, ok := hooksConfig.(*hooks.HookConfig); ok { + // Generate a session ID for this run + sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix()) + transcriptPath := "" // We could add transcript logging later + hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath) + + // Set model and interactive mode + hookExecutor.SetModel(modelString) + hookExecutor.SetInteractive(promptFlag == "") // Interactive if no prompt flag + } + } + // Create an adapter for the agent to match the UI interface agentAdapter := &agentUIAdapter{agent: mcpAgent} @@ -608,7 +644,7 @@ func runNormalMode(ctx context.Context) error { // Check if running in non-interactive mode if promptFlag != "" { - return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager) + return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager, hookExecutor) } // Quiet mode is not allowed in interactive mode @@ -616,7 +652,7 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet flag can only be used with --prompt/-p") } - return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager) + return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor) } // AgenticLoopConfig configures the behavior of the unified agentic loop @@ -672,9 +708,30 @@ func replaceMessagesHistory(messages *[]*schema.Message, sessionManager *session } // runAgenticLoop handles all execution modes with a single unified loop -func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error { +func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error { // Handle initial prompt for non-interactive modes if !config.IsInteractive && config.InitialPrompt != "" { + // Execute UserPromptSubmit hooks for non-interactive mode + if hookExecutor != nil { + input := &hooks.UserPromptSubmitInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit), + Prompt: config.InitialPrompt, + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err) + } + } + + // Check if hook blocked the prompt + if hookOutput != nil && hookOutput.Decision == "block" { + return fmt.Errorf("prompt blocked by hook: %s", hookOutput.Reason) + } + } + // Display user message (skip if quiet) if !config.Quiet && cli != nil { cli.DisplayUserMessage(config.InitialPrompt) @@ -684,7 +741,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes tempMessages := append(messages, schema.UserMessage(config.InitialPrompt)) // Process the initial prompt with tool calls - _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config) + _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor) if err != nil { // Check if this was a user cancellation if err.Error() == "generation cancelled by user" && cli != nil { @@ -712,14 +769,14 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes // Interactive loop (or continuation after non-interactive) if config.IsInteractive { - return runInteractiveLoop(ctx, mcpAgent, cli, messages, config) + return runInteractiveLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } return nil } // runAgenticStep processes a single step of the agentic loop (handles tool calls) -func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, []*schema.Message, error) { +func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) (*schema.Message, []*schema.Message, error) { var currentSpinner *ui.Spinner // Start initial spinner (skip if quiet) @@ -762,9 +819,17 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes streamingStarted = false streamingContent.Reset() + // Variables to store tool information for hooks + var currentToolName string + var currentToolArgs string + result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages, // Tool call handler - called when a tool is about to be executed func(toolName, toolArgs string) { + // Store tool info for use in execution handler + currentToolName = toolName + currentToolArgs = toolArgs + if !config.Quiet && cli != nil { // Stop spinner before displaying tool call if currentSpinner != nil { @@ -776,22 +841,66 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes }, // Tool execution handler - called when tool execution starts/ends func(toolName string, isStarting bool) { - if !config.Quiet && cli != nil { - if isStarting { + if isStarting { + // Execute PreToolUse hooks + if hookExecutor != nil { + input := &hooks.PreToolUseInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.PreToolUse), + ToolName: currentToolName, + ToolInput: json.RawMessage(currentToolArgs), + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.PreToolUse, input) + if err != nil { + // Log error but don't fail the tool execution + if debugMode { + fmt.Fprintf(os.Stderr, "Hook execution error: %v\n", err) + } + } + + // Check if hook blocked the execution + if hookOutput != nil && hookOutput.Decision == "block" { + // We need a way to cancel the tool execution + // For now, just log it + if !config.Quiet && cli != nil { + cli.DisplayInfo(fmt.Sprintf("Tool execution blocked by hook: %s", hookOutput.Reason)) + } + } + } + + if !config.Quiet && cli != nil { // Start spinner for tool execution currentSpinner = ui.NewSpinner(fmt.Sprintf("Executing %s...", toolName)) currentSpinner.Start() - } else { - // Stop spinner when tool execution completes - if currentSpinner != nil { - currentSpinner.Stop() - currentSpinner = nil - } + } + } else { + // Stop spinner when tool execution completes + if !config.Quiet && cli != nil && currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil } } }, // Tool result handler - called when a tool execution completes func(toolName, toolArgs, result string, isError bool) { + // Execute PostToolUse hooks + if hookExecutor != nil && result != "" { + input := &hooks.PostToolUseInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.PostToolUse), + ToolName: currentToolName, + ToolInput: json.RawMessage(currentToolArgs), + ToolResponse: json.RawMessage(result), + } + + _, err := hookExecutor.ExecuteHooks(ctx, hooks.PostToolUse, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "PostToolUse hook execution error: %v\n", err) + } + } + } + if !config.Quiet && cli != nil { // Parse tool result content - it might be JSON-encoded MCP content resultContent := result @@ -921,12 +1030,49 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes cli.DisplayUsageAfterResponse() } + // Execute Stop hook after agent has finished responding + executeStopHook(hookExecutor, response, "completed", config.ModelName) + // Return the final response and all conversation messages return response, conversationMessages, nil } +// executeStopHook executes the Stop hook if a hook executor is available +func executeStopHook(hookExecutor *hooks.Executor, response *schema.Message, stopReason string, modelName string) { + if hookExecutor != nil { + // Prepare metadata + var meta json.RawMessage + if response != nil { + metaData := map[string]interface{}{ + "model": modelName, + "role": string(response.Role), + "has_tool_calls": len(response.ToolCalls) > 0, + } + if metaBytes, err := json.Marshal(metaData); err == nil { + meta = json.RawMessage(metaBytes) + } + } + + responseContent := "" + if response != nil { + responseContent = response.Content + } + + input := &hooks.StopInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.Stop), + StopHookActive: true, + Response: responseContent, + StopReason: stopReason, + Meta: meta, + } + + // Execute Stop hook (ignore errors as we're exiting anyway) + hookExecutor.ExecuteHooks(context.Background(), hooks.Stop, input) + } +} + // runInteractiveLoop handles the interactive portion of the agentic loop -func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error { +func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error { for { // Get user input prompt, err := cli.GetPrompt() @@ -942,6 +1088,30 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, continue } + // Execute UserPromptSubmit hooks + if hookExecutor != nil { + input := &hooks.UserPromptSubmitInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit), + Prompt: prompt, + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err) + } + } + + // Check if hook blocked the prompt + if hookOutput != nil && hookOutput.Decision == "block" { + if cli != nil { + cli.DisplayInfo(fmt.Sprintf("Prompt blocked: %s", hookOutput.Reason)) + } + continue // Skip this prompt + } + } + // Handle slash commands if cli.IsSlashCommand(prompt) { result := cli.HandleSlashCommand(prompt, config.ServerNames, config.ToolNames) @@ -965,7 +1135,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, tempMessages := append(messages, schema.UserMessage(prompt)) // Process the user input with tool calls - _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config) + _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor) if err != nil { // Check if this was a user cancellation if err.Error() == "generation cancelled by user" { @@ -983,7 +1153,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, } // runNonInteractiveMode handles the non-interactive mode execution -func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager) error { +func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager, hookExecutor *hooks.Executor) error { // Prepare data for slash commands (needed if continuing to interactive mode) var serverNames []string for name := range mcpConfig.MCPServers { @@ -1011,11 +1181,11 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C SessionManager: sessionManager, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } // runInteractiveMode handles the interactive mode execution -func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager) error { +func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error { // Configure and run unified agentic loop config := AgenticLoopConfig{ IsInteractive: true, @@ -1029,5 +1199,5 @@ func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, SessionManager: sessionManager, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } diff --git a/cmd/script.go b/cmd/script.go index 8ee02262..a0e357b4 100644 --- a/cmd/script.go +++ b/cmd/script.go @@ -8,10 +8,12 @@ import ( "os" "regexp" "strings" + "time" "github.com/cloudwego/eino/schema" "github.com/mark3labs/mcphost/internal/agent" "github.com/mark3labs/mcphost/internal/config" + "github.com/mark3labs/mcphost/internal/hooks" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/ui" "github.com/spf13/cobra" @@ -652,6 +654,21 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, cli.DisplayDebugConfig(debugConfig) } + // Initialize hooks + var hookExecutor *hooks.Executor + if hooksConfig := viper.Get("hooks"); hooksConfig != nil { + if hc, ok := hooksConfig.(*hooks.HookConfig); ok { + // Generate a session ID for this run + sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix()) + transcriptPath := "" // We could add transcript logging later + hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath) + + // Set model and interactive mode + hookExecutor.SetModel(finalModel) + hookExecutor.SetInteractive(prompt == "") + } + } + // Prepare data for slash commands var serverNames []string for name := range mcpConfig.MCPServers { @@ -679,5 +696,5 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, MCPConfig: mcpConfig, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } diff --git a/examples/hooks/bash-validator.py b/examples/hooks/bash-validator.py new file mode 100755 index 00000000..e3120a24 --- /dev/null +++ b/examples/hooks/bash-validator.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +Validates bash commands before execution. +Blocks dangerous commands and suggests alternatives. +""" +import json +import sys +import re + +# Define validation rules +DANGEROUS_PATTERNS = [ + (r'\brm\s+-rf\s+/', "Dangerous command: rm -rf /"), + (r'\bdd\s+.*\bof=/dev/[sh]d[a-z]', "Direct disk write detected"), + (r'>\s*/dev/null\s+2>&1', "Consider using proper error handling instead of discarding stderr"), +] + +SUGGEST_ALTERNATIVES = { + r'\bgrep\b': "Use 'rg' (ripgrep) for better performance", + r'\bfind\s+.*-name': "Use 'fd' for faster file finding", +} + +def main(): + try: + # Read input + input_data = json.load(sys.stdin) + + # Only process bash commands + if input_data.get('tool_name') != 'bash': + sys.exit(0) + + command = json.loads(input_data.get('tool_input', '{}')).get('command', '') + + # Check dangerous patterns + for pattern, message in DANGEROUS_PATTERNS: + if re.search(pattern, command, re.IGNORECASE): + print(message, file=sys.stderr) + sys.exit(2) # Block execution + + # Suggest alternatives + suggestions = [] + for pattern, suggestion in SUGGEST_ALTERNATIVES.items(): + if re.search(pattern, command): + suggestions.append(suggestion) + + if suggestions: + output = { + "decision": "approve", + "reason": "Command approved. Suggestions: " + "; ".join(suggestions) + } + print(json.dumps(output)) + + except Exception as e: + print(f"Hook error: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/hooks/mcp-monitor.py b/examples/hooks/mcp-monitor.py new file mode 100755 index 00000000..94f86590 --- /dev/null +++ b/examples/hooks/mcp-monitor.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +""" +Monitors MCP tool usage and enforces policies. +""" +import json +import sys +import re +import os +from datetime import datetime + +# Define MCP tool policies +BLOCKED_MCP_TOOLS = [ + "mcp__github__delete_.*", # Block all GitHub delete operations + "mcp__aws__.*_production", # Block production AWS operations +] + +RATE_LIMITS = { + "mcp__openai__.*": (10, 60), # 10 calls per 60 seconds +} + +def check_rate_limit(tool_name, limits): + # This is a simplified example - real implementation would need persistent storage + # For now, just log the attempt + return True + +def main(): + try: + input_data = json.load(sys.stdin) + tool_name = input_data.get('tool_name', '') + + # Check if tool is blocked + for pattern in BLOCKED_MCP_TOOLS: + if re.match(pattern, tool_name): + output = { + "decision": "block", + "reason": f"Tool {tool_name} is blocked by security policy" + } + print(json.dumps(output)) + sys.exit(0) + + # Check rate limits + for pattern, (limit, window) in RATE_LIMITS.items(): + if re.match(pattern, tool_name): + if not check_rate_limit(tool_name, (limit, window)): + output = { + "decision": "block", + "reason": f"Rate limit exceeded: {limit} calls per {window}s" + } + print(json.dumps(output)) + sys.exit(0) + + # Log MCP tool usage + log_entry = { + "timestamp": datetime.now().isoformat(), + "tool": tool_name, + "input": input_data.get('tool_input', {}) + } + + # Use XDG_CONFIG_HOME if set, otherwise default to ~/.config + config_home = os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config')) + log_dir = os.path.join(config_home, 'mcphost', 'logs') + os.makedirs(log_dir, exist_ok=True) + + with open(os.path.join(log_dir, "mcp-usage.jsonl"), "a") as f: + f.write(json.dumps(log_entry) + "\n") + + except Exception as e: + print(f"Hook error: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/hooks/prompt-logger.sh b/examples/hooks/prompt-logger.sh new file mode 100755 index 00000000..2c51fdf2 --- /dev/null +++ b/examples/hooks/prompt-logger.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Logs all user prompts with timestamp + +# Read JSON input +input=$(cat) + +# Extract prompt using jq (ensure jq is installed) +prompt=$(echo "$input" | jq -r '.prompt // empty') + +if [ -n "$prompt" ]; then + # Use XDG_CONFIG_HOME if set, otherwise default to ~/.config + CONFIG_DIR="${XDG_CONFIG_HOME:-$HOME/.config}" + LOG_DIR="$CONFIG_DIR/mcphost/logs" + + # Create log directory if it doesn't exist + mkdir -p "$LOG_DIR" + + # Log with timestamp + echo "[$(date '+%Y-%m-%d %H:%M:%S')] $prompt" >> "$LOG_DIR/prompts.log" +fi + +# Always allow prompt to continue +exit 0 \ No newline at end of file diff --git a/internal/hooks/config.go b/internal/hooks/config.go new file mode 100644 index 00000000..dc59c15e --- /dev/null +++ b/internal/hooks/config.go @@ -0,0 +1,132 @@ +package hooks + +import ( + "encoding/json" + "fmt" + "github.com/mark3labs/mcphost/internal/config" + "gopkg.in/yaml.v3" + "os" + "path/filepath" +) + +// HookConfig represents the complete hooks configuration +type HookConfig struct { + Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"` +} + +// HookMatcher matches specific tools and defines hooks to execute +type HookMatcher struct { + Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"` + Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"` + Hooks []HookEntry `yaml:"hooks" json:"hooks"` +} + +// HookEntry defines a single hook command +type HookEntry struct { + Type string `yaml:"type" json:"type"` + Command string `yaml:"command" json:"command"` + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` +} + +// LoadHooksConfig loads and merges hook configurations from multiple sources +func LoadHooksConfig(customPaths ...string) (*HookConfig, error) { + // Get config directory following XDG Base Directory specification + configDir := getConfigDir() + + // Define search paths in order of precedence (lowest to highest) + searchPaths := []string{ + filepath.Join(configDir, "mcphost", "hooks.json"), + filepath.Join(configDir, "mcphost", "hooks.yml"), + ".mcphost/hooks.json", + ".mcphost/hooks.yml", + } + + // Add custom paths with highest precedence + searchPaths = append(searchPaths, customPaths...) + + merged := &HookConfig{ + Hooks: make(map[HookEvent][]HookMatcher), + } + + for _, path := range searchPaths { + if _, err := os.Stat(path); os.IsNotExist(err) { + continue + } + + // Read file content + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading %s: %w", path, err) + } + + // Apply environment substitution + envSubstituter := &config.EnvSubstituter{} + substituted, err := envSubstituter.SubstituteEnvVars(string(content)) + if err != nil { + return nil, fmt.Errorf("substituting env vars in %s: %w", path, err) + } + + // Parse configuration + var cfg HookConfig + if filepath.Ext(path) == ".json" { + err = json.Unmarshal([]byte(substituted), &cfg) + } else { + err = yaml.Unmarshal([]byte(substituted), &cfg) + } + if err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + + // Merge configurations + mergeHookConfigs(merged, &cfg) + } + + return merged, nil +} + +// getConfigDir returns the configuration directory following XDG Base Directory specification +func getConfigDir() string { + // Try XDG_CONFIG_HOME first + if xdgConfig := os.Getenv("XDG_CONFIG_HOME"); xdgConfig != "" { + return xdgConfig + } + + // Fall back to ~/.config + if home := os.Getenv("HOME"); home != "" { + return filepath.Join(home, ".config") + } + + // Last resort: current directory + return "." +} + +// mergeHookConfigs merges source hooks into destination +func mergeHookConfigs(dst, src *HookConfig) { + for event, matchers := range src.Hooks { + if dst.Hooks[event] == nil { + dst.Hooks[event] = matchers + continue + } + + // Handle merge strategies + for _, srcMatcher := range matchers { + if srcMatcher.Merge == "replace" { + // Replace all matchers for this event + dst.Hooks[event] = []HookMatcher{srcMatcher} + } else { + // Append or update existing matcher + found := false + for i, dstMatcher := range dst.Hooks[event] { + if dstMatcher.Matcher == srcMatcher.Matcher { + dst.Hooks[event][i] = srcMatcher + found = true + break + } + } + if !found { + dst.Hooks[event] = append(dst.Hooks[event], srcMatcher) + } + } + } + } +} diff --git a/internal/hooks/config_test.go b/internal/hooks/config_test.go new file mode 100644 index 00000000..2f9a2682 --- /dev/null +++ b/internal/hooks/config_test.go @@ -0,0 +1,224 @@ +package hooks + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestLoadHooksConfig(t *testing.T) { + // Save original XDG_CONFIG_HOME + originalXDG := os.Getenv("XDG_CONFIG_HOME") + defer func() { + if originalXDG != "" { + os.Setenv("XDG_CONFIG_HOME", originalXDG) + } else { + os.Unsetenv("XDG_CONFIG_HOME") + } + }() + + tests := []struct { + name string + files map[string]string + expected *HookConfig + wantErr bool + }{ + { + name: "single yaml file", + files: map[string]string{ + "hooks.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo test" + timeout: 5 +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: 5}, + }, + }, + }, + }, + }, + }, + { + name: "environment substitution", + files: map[string]string{ + "hooks.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "${env://TEST_HOOK_CMD:-echo default}" +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo default"}, + }, + }, + }, + }, + }, + }, + { + name: "merge multiple files", + files: map[string]string{ + "global.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "global-hook" +`, + "local.yml": ` +hooks: + PreToolUse: + - matcher: "fetch" + hooks: + - type: command + command: "local-hook" +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{{Type: "command", Command: "global-hook"}}, + }, + { + Matcher: "fetch", + Hooks: []HookEntry{{Type: "command", Command: "local-hook"}}, + }, + }, + }, + }, + }, + { + name: "invalid yaml", + files: map[string]string{ + "hooks.yml": `invalid yaml content`, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary directory for test files + tmpDir := t.TempDir() + + // Set XDG_CONFIG_HOME to a temp directory to avoid loading global hooks + testConfigDir := filepath.Join(tmpDir, "config") + os.Setenv("XDG_CONFIG_HOME", testConfigDir) + + // Write test files + var paths []string + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + paths = append(paths, path) + } + + // Load configuration + got, err := LoadHooksConfig(paths...) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("LoadHooksConfig() = %+v, want %+v", got, tt.expected) + } + }) + } +} + +func TestMatchesPattern(t *testing.T) { + tests := []struct { + pattern string + toolName string + want bool + }{ + {"", "bash", true}, // Empty pattern matches all + {"bash", "bash", true}, // Exact match + {"bash", "Bash", false}, // Case sensitive + {"bash|fetch", "bash", true}, // Regex OR + {"bash|fetch", "fetch", true}, // Regex OR + {"bash|fetch", "todo", false}, // Regex OR no match + {"mcp__.*", "mcp__filesystem__read", true}, // MCP pattern + {".*write.*", "mcp__fs__write_file", true}, // Contains pattern + {"^bash$", "bash", true}, // Anchored regex + {"^bash$", "bash2", false}, // Anchored regex no match + } + + for _, tt := range tests { + t.Run(tt.pattern+"_"+tt.toolName, func(t *testing.T) { + got := matchesPattern(tt.pattern, tt.toolName) + if got != tt.want { + t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.toolName, got, tt.want) + } + }) + } +} + +func TestNoHooksFlag(t *testing.T) { + // This test verifies that when hooks are disabled via configuration, + // the LoadHooksConfig function is not called. The actual implementation + // of this is in cmd/root.go where viper.GetBool("no-hooks") is checked. + // This test documents the expected behavior. + + // Create a test hooks file + tmpDir := t.TempDir() + hooksFile := filepath.Join(tmpDir, "hooks.yml") + content := ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo 'This should not run'" +` + if err := os.WriteFile(hooksFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + // Load the hooks config normally + config, err := LoadHooksConfig(hooksFile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify hooks are loaded + if len(config.Hooks) == 0 { + t.Error("expected hooks to be loaded") + } + + // The actual --no-hooks flag implementation is in cmd/root.go + // where it checks viper.GetBool("no-hooks") before calling LoadHooksConfig +} diff --git a/internal/hooks/events.go b/internal/hooks/events.go new file mode 100644 index 00000000..7b61397a --- /dev/null +++ b/internal/hooks/events.go @@ -0,0 +1,32 @@ +package hooks + +// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed +type HookEvent string + +const ( + // PreToolUse fires before any tool execution + PreToolUse HookEvent = "PreToolUse" + + // PostToolUse fires after tool execution completes + PostToolUse HookEvent = "PostToolUse" + + // UserPromptSubmit fires when user submits a prompt + UserPromptSubmit HookEvent = "UserPromptSubmit" + + // Stop fires when the main agent finishes responding + Stop HookEvent = "Stop" +) + +// IsValid returns true if the event is a valid hook event +func (e HookEvent) IsValid() bool { + switch e { + case PreToolUse, PostToolUse, UserPromptSubmit, Stop: + return true + } + return false +} + +// RequiresMatcher returns true if the event uses tool matchers +func (e HookEvent) RequiresMatcher() bool { + return e == PreToolUse || e == PostToolUse +} diff --git a/internal/hooks/executor.go b/internal/hooks/executor.go new file mode 100644 index 00000000..0c050f10 --- /dev/null +++ b/internal/hooks/executor.go @@ -0,0 +1,254 @@ +package hooks + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "regexp" + "sync" + "time" +) + +// Executor handles hook execution +type Executor struct { + config *HookConfig + sessionID string + transcript string + model string + interactive bool + mu sync.RWMutex +} + +// NewExecutor creates a new hook executor +func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor { + return &Executor{ + config: config, + sessionID: sessionID, + transcript: transcriptPath, + } +} + +// SetModel sets the model name for hook context +func (e *Executor) SetModel(model string) { + e.mu.Lock() + defer e.mu.Unlock() + e.model = model +} + +// SetInteractive sets whether we're in interactive mode +func (e *Executor) SetInteractive(interactive bool) { + e.mu.Lock() + defer e.mu.Unlock() + e.interactive = interactive +} + +// PopulateCommonFields fills in the common fields for any hook input +func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput { + e.mu.RLock() + defer e.mu.RUnlock() + + cwd, _ := os.Getwd() + return CommonInput{ + SessionID: e.sessionID, + TranscriptPath: e.transcript, + CWD: cwd, + HookEventName: event, + Timestamp: time.Now().Unix(), + Model: e.model, + Interactive: e.interactive, + } +} + +// ExecuteHooks runs all matching hooks for an event +func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) { + matchers, ok := e.config.Hooks[event] + if !ok || len(matchers) == 0 { + return nil, nil + } + + // Get tool name if applicable + toolName := "" + if event.RequiresMatcher() { + toolName = extractToolName(input) + } + + // Find matching hooks + var hooksToRun []HookEntry + for _, matcher := range matchers { + if matchesPattern(matcher.Matcher, toolName) { + hooksToRun = append(hooksToRun, matcher.Hooks...) + } + } + + if len(hooksToRun) == 0 { + return nil, nil + } + + // Execute hooks in parallel + results := make(chan *hookResult, len(hooksToRun)) + var wg sync.WaitGroup + + for _, hook := range hooksToRun { + wg.Add(1) + go func(h HookEntry) { + defer wg.Done() + result := e.executeHook(ctx, h, input) + results <- result + }(hook) + } + + wg.Wait() + close(results) + + // Process results + return e.processResults(results) +} + +// executeHook runs a single hook command +func (e *Executor) executeHook(ctx context.Context, hook HookEntry, input interface{}) *hookResult { + // Prepare input JSON + inputJSON, err := json.Marshal(input) + if err != nil { + return &hookResult{err: fmt.Errorf("marshaling input: %w", err)} + } + + // Set timeout + timeout := time.Duration(hook.Timeout) * time.Second + if timeout == 0 { + timeout = 60 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Create command + cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command) + cmd.Stdin = bytes.NewReader(inputJSON) + cmd.Dir = getCurrentWorkingDir() + + // Capture output + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Execute + err = cmd.Run() + + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + exitCode = -1 + } + } + + return &hookResult{ + exitCode: exitCode, + stdout: stdout.String(), + stderr: stderr.String(), + err: err, + } +} + +// matchesPattern checks if a tool name matches a pattern +func matchesPattern(pattern, toolName string) bool { + if pattern == "" { + return true // Empty pattern matches all + } + + // Try exact match first + if pattern == toolName { + return true + } + + // Try regex match + matched, err := regexp.MatchString(pattern, toolName) + if err != nil { + // Invalid regex pattern, return false + return false + } + + return matched +} + +// extractToolName gets the tool name from various input types +func extractToolName(input interface{}) string { + switch v := input.(type) { + case *PreToolUseInput: + return v.ToolName + case *PostToolUseInput: + return v.ToolName + default: + return "" + } +} + +type hookResult struct { + exitCode int + stdout string + stderr string + err error +} + +// processResults combines results from multiple hooks +func (e *Executor) processResults(results <-chan *hookResult) (*HookOutput, error) { + var finalOutput HookOutput + + for result := range results { + if result.err != nil && result.exitCode != 2 { + // Hook execution failed, skip this result + continue + } + + // Handle exit code 2 (blocking error) + if result.exitCode == 2 { + finalOutput.Decision = "block" + finalOutput.Reason = result.stderr + continueVal := false + finalOutput.Continue = &continueVal + return &finalOutput, nil + } + + // Try to parse JSON output + if result.stdout != "" { + var output HookOutput + if err := json.Unmarshal([]byte(result.stdout), &output); err == nil { + // Merge outputs (later hooks can override) + mergeHookOutputs(&finalOutput, &output) + } + } + } + + return &finalOutput, nil +} + +// mergeHookOutputs combines two hook outputs +func mergeHookOutputs(dst, src *HookOutput) { + if src.Continue != nil { + dst.Continue = src.Continue + } + if src.StopReason != "" { + dst.StopReason = src.StopReason + } + if src.Decision != "" { + dst.Decision = src.Decision + } + if src.Reason != "" { + dst.Reason = src.Reason + } + if src.SuppressOutput { + dst.SuppressOutput = true + } +} + +func getCurrentWorkingDir() string { + cwd, err := os.Getwd() + if err != nil { + return "/" + } + return cwd +} diff --git a/internal/hooks/executor_test.go b/internal/hooks/executor_test.go new file mode 100644 index 00000000..a5748268 --- /dev/null +++ b/internal/hooks/executor_test.go @@ -0,0 +1,193 @@ +package hooks + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestExecuteHooks(t *testing.T) { + // Create test scripts + tmpDir := t.TempDir() + + // Simple echo script + echoScript := filepath.Join(tmpDir, "echo.sh") + if err := os.WriteFile(echoScript, []byte(`#!/bin/bash +cat +`), 0755); err != nil { + t.Fatalf("failed to create echo script: %v", err) + } + + // Blocking script (exit code 2) + blockScript := filepath.Join(tmpDir, "block.sh") + if err := os.WriteFile(blockScript, []byte(`#!/bin/bash +echo "Blocked by policy" >&2 +exit 2 +`), 0755); err != nil { + t.Fatalf("failed to create block script: %v", err) + } + + // JSON output script + jsonScript := filepath.Join(tmpDir, "json.sh") + if err := os.WriteFile(jsonScript, []byte(`#!/bin/bash +echo '{"decision": "approve", "reason": "Approved by test"}' +`), 0755); err != nil { + t.Fatalf("failed to create json script: %v", err) + } + + tests := []struct { + name string + config *HookConfig + event HookEvent + input interface{} + expected *HookOutput + wantErr bool + }{ + { + name: "simple command execution", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: echoScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{}, + }, + { + name: "blocking hook", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: blockScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{ + Decision: "block", + Reason: "Blocked by policy\n", + Continue: boolPtr(false), + }, + }, + { + name: "JSON output parsing", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: jsonScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{ + Decision: "approve", + Reason: "Approved by test", + }, + }, + { + name: "timeout handling", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: "sleep 10", + Timeout: 1, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + executor := NewExecutor(tt.config, "test-session", "/tmp/test.jsonl") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + got, err := executor.ExecuteHooks(ctx, tt.event, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Compare outputs + if !compareHookOutputs(got, tt.expected) { + gotJSON, _ := json.MarshalIndent(got, "", " ") + expectedJSON, _ := json.MarshalIndent(tt.expected, "", " ") + t.Errorf("ExecuteHooks() output mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, expectedJSON) + } + }) + } +} + +func boolPtr(b bool) *bool { + return &b +} + +func compareHookOutputs(a, b *HookOutput) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // Compare Continue pointers + if (a.Continue == nil) != (b.Continue == nil) { + return false + } + if a.Continue != nil && *a.Continue != *b.Continue { + return false + } + + return a.StopReason == b.StopReason && + a.SuppressOutput == b.SuppressOutput && + a.Decision == b.Decision && + a.Reason == b.Reason +} diff --git a/internal/hooks/schemas.go b/internal/hooks/schemas.go new file mode 100644 index 00000000..9d79ca92 --- /dev/null +++ b/internal/hooks/schemas.go @@ -0,0 +1,55 @@ +package hooks + +import ( + "encoding/json" +) + +// CommonInput contains fields common to all hook inputs +type CommonInput struct { + SessionID string `json:"session_id"` // Unique session identifier + TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled) + CWD string `json:"cwd"` // Current working directory + HookEventName HookEvent `json:"hook_event_name"` // The hook event type + Timestamp int64 `json:"timestamp"` // Unix timestamp when hook fired + Model string `json:"model"` // AI model being used + Interactive bool `json:"interactive"` // Whether in interactive mode +} + +// PreToolUseInput is passed to PreToolUse hooks +type PreToolUseInput struct { + CommonInput + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` +} + +// PostToolUseInput is passed to PostToolUse hooks +type PostToolUseInput struct { + CommonInput + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` + ToolResponse json.RawMessage `json:"tool_response"` +} + +// UserPromptSubmitInput is passed to UserPromptSubmit hooks +type UserPromptSubmitInput struct { + CommonInput + Prompt string `json:"prompt"` +} + +// StopInput is passed to Stop hooks +type StopInput struct { + CommonInput + StopHookActive bool `json:"stop_hook_active"` + Response string `json:"response"` // The agent's final response + StopReason string `json:"stop_reason"` // "completed", "cancelled", "error" + Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info) +} + +// HookOutput represents the JSON output from a hook +type HookOutput struct { + Continue *bool `json:"continue,omitempty"` + StopReason string `json:"stopReason,omitempty"` + SuppressOutput bool `json:"suppressOutput,omitempty"` + Decision string `json:"decision,omitempty"` // "approve", "block", or "" + Reason string `json:"reason,omitempty"` +} diff --git a/internal/hooks/testdata/invalid-hooks.yml b/internal/hooks/testdata/invalid-hooks.yml new file mode 100644 index 00000000..293188bc --- /dev/null +++ b/internal/hooks/testdata/invalid-hooks.yml @@ -0,0 +1,10 @@ +hooks: + InvalidEvent: + - hooks: + - type: command + command: "echo test" + PreToolUse: + - matcher: "[invalid regex" + hooks: + - type: command + command: "echo test" \ No newline at end of file diff --git a/internal/hooks/testdata/valid-hooks.yml b/internal/hooks/testdata/valid-hooks.yml new file mode 100644 index 00000000..2bcdff74 --- /dev/null +++ b/internal/hooks/testdata/valid-hooks.yml @@ -0,0 +1,21 @@ +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo 'Executing bash command'" + timeout: 5 + - matcher: "fetch" + hooks: + - type: command + command: "echo 'Fetching URL'" + timeout: 10 + UserPromptSubmit: + - hooks: + - type: command + command: "date >> /tmp/mcphost-prompts.log" + PostToolUse: + - matcher: ".*" + hooks: + - type: command + command: "echo 'Tool execution completed'" \ No newline at end of file diff --git a/internal/hooks/validator.go b/internal/hooks/validator.go new file mode 100644 index 00000000..504b5e68 --- /dev/null +++ b/internal/hooks/validator.go @@ -0,0 +1,130 @@ +package hooks + +import ( + "fmt" + "regexp" + "strings" +) + +// Security patterns to detect potentially dangerous commands +var ( + commandInjectionPattern = regexp.MustCompile(`[;&|]|\$\(|` + "`") + pathTraversalPattern = regexp.MustCompile(`\.\.\/`) + commandSubstitutionPattern = regexp.MustCompile(`\$\([^)]+\)|` + "`" + `[^` + "`" + `]+` + "`") +) + +// validateHookCommand validates a hook command for security issues +func validateHookCommand(command string) error { + if command == "" { + return fmt.Errorf("empty command") + } + + // Check for command injection attempts + if commandInjectionPattern.MatchString(command) { + // Allow simple pipes and redirects, but check for dangerous patterns + if containsDangerousPattern(command) { + return fmt.Errorf("potential command injection detected") + } + } + + // Check for path traversal + if pathTraversalPattern.MatchString(command) { + return fmt.Errorf("path traversal detected") + } + + // Check for command substitution + if commandSubstitutionPattern.MatchString(command) { + return fmt.Errorf("command substitution detected") + } + + return nil +} + +// containsDangerousPattern checks for specific dangerous command patterns +func containsDangerousPattern(command string) bool { + dangerousPatterns := []string{ + "; rm ", + "&& rm ", + "| rm ", + "; dd ", + "&& dd ", + "| dd ", + "/dev/null 2>&1", + } + + for _, pattern := range dangerousPatterns { + if strings.Contains(command, pattern) { + return true + } + } + + // Check for multiple command separators which might indicate injection + separatorCount := 0 + for _, sep := range []string{";", "&&", "||", "|"} { + separatorCount += strings.Count(command, sep) + } + + // Allow up to 2 separators for reasonable command chaining + return separatorCount > 2 +} + +// ValidateHookConfig validates the entire hook configuration +func ValidateHookConfig(config *HookConfig) error { + if config == nil { + return fmt.Errorf("nil configuration") + } + + for event, matchers := range config.Hooks { + if !event.IsValid() { + return fmt.Errorf("invalid event: %s", event) + } + + for i, matcher := range matchers { + // Validate regex pattern if provided + if matcher.Matcher != "" { + if _, err := regexp.Compile(matcher.Matcher); err != nil { + return fmt.Errorf("invalid regex pattern in matcher %d for event %s: %w", i, event, err) + } + } + + // Validate hooks + if len(matcher.Hooks) == 0 { + return fmt.Errorf("no hooks defined for matcher %d in event %s", i, event) + } + + for j, hook := range matcher.Hooks { + if err := validateHookEntry(hook); err != nil { + return fmt.Errorf("invalid hook %d in matcher %d for event %s: %w", j, i, event, err) + } + } + } + } + + return nil +} + +// validateHookEntry validates a single hook entry +func validateHookEntry(hook HookEntry) error { + if hook.Type != "command" { + return fmt.Errorf("invalid hook type: %s (only 'command' is supported)", hook.Type) + } + + if hook.Command == "" { + return fmt.Errorf("empty command") + } + + // Basic security validation + if err := validateHookCommand(hook.Command); err != nil { + return fmt.Errorf("command validation failed: %w", err) + } + + if hook.Timeout < 0 { + return fmt.Errorf("negative timeout: %d", hook.Timeout) + } + + if hook.Timeout > 600 { // 10 minutes max + return fmt.Errorf("timeout too large: %d (max 600 seconds)", hook.Timeout) + } + + return nil +} diff --git a/internal/hooks/validator_test.go b/internal/hooks/validator_test.go new file mode 100644 index 00000000..5b62dbfd --- /dev/null +++ b/internal/hooks/validator_test.go @@ -0,0 +1,251 @@ +package hooks + +import ( + "strings" + "testing" +) + +func TestValidateHookCommand(t *testing.T) { + tests := []struct { + name string + command string + wantErr bool + errMsg string + }{ + { + name: "simple command", + command: "echo hello", + wantErr: false, + }, + { + name: "absolute path", + command: "/usr/local/bin/validator.py", + wantErr: false, + }, + { + name: "command injection attempt", + command: "echo test; rm -rf /", + wantErr: true, + errMsg: "potential command injection", + }, + { + name: "path traversal", + command: "cat ../../../etc/passwd", + wantErr: true, + errMsg: "path traversal detected", + }, + { + name: "command substitution", + command: "echo $(/bin/sh -c 'malicious')", + wantErr: true, + errMsg: "command substitution detected", + }, + { + name: "backtick substitution", + command: "echo `whoami`", + wantErr: true, + errMsg: "command substitution detected", + }, + { + name: "empty command", + command: "", + wantErr: true, + errMsg: "empty command", + }, + { + name: "simple pipe allowed", + command: "ps aux | grep process", + wantErr: false, + }, + { + name: "too many command separators", + command: "cmd1 | cmd2 && cmd3 ; cmd4", + wantErr: true, + errMsg: "potential command injection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateHookCommand(tt.command) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestValidateHookConfig(t *testing.T) { + tests := []struct { + name string + config *HookConfig + wantErr bool + errMsg string + }{ + { + name: "valid config", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "nil config", + config: nil, + wantErr: true, + errMsg: "nil configuration", + }, + { + name: "invalid event", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + "InvalidEvent": { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid event", + }, + { + name: "invalid regex pattern", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "[invalid", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid regex pattern", + }, + { + name: "no hooks defined", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{}, + }, + }, + }, + }, + wantErr: true, + errMsg: "no hooks defined", + }, + { + name: "invalid hook type", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "invalid", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid hook type", + }, + { + name: "empty command", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: ""}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "empty command", + }, + { + name: "negative timeout", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: -1}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "negative timeout", + }, + { + name: "timeout too large", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: 700}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "timeout too large", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateHookConfig(tt.config) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +}