diff --git a/README.md b/README.md index f5580426..97e2f342 100644 --- a/README.md +++ b/README.md @@ -558,6 +558,72 @@ See `examples/scripts/` for sample scripts: - `example-script.sh` - Script with custom MCP servers - `simple-script.sh` - Script using default config fallback +### Hooks System + +MCPHost supports a powerful hooks system that allows you to execute custom commands at specific points during execution. This enables security policies, logging, custom integrations, and automated workflows. + +#### Quick Start + +1. Initialize a hooks configuration: + ```bash + mcphost hooks init + ``` + +2. View active hooks: + ```bash + mcphost hooks list + ``` + +3. Validate your configuration: + ```bash + mcphost hooks validate + ``` + +#### Configuration + +Hooks are configured in YAML files with the following precedence (highest to lowest): +- `.mcphost/hooks.yml` (project-specific hooks) +- `$XDG_CONFIG_HOME/mcphost/hooks.yml` (user global hooks, defaults to `~/.config/mcphost/hooks.yml`) + +Example configuration: +```yaml +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "/usr/local/bin/validate-bash.py" + timeout: 5 + + UserPromptSubmit: + - hooks: + - type: command + command: "~/.mcphost/hooks/log-prompt.sh" +``` + +#### Available Hook Events + +- **PreToolUse**: Before any tool execution (bash, fetch, todo, MCP tools) +- **PostToolUse**: After tool execution completes +- **UserPromptSubmit**: When user submits a prompt +- **Stop**: When the agent finishes responding +- **SubagentStop**: When a subagent (Task tool) finishes +- **Notification**: When MCPHost sends notifications + +#### Security + +⚠️ **WARNING**: Hooks execute arbitrary commands on your system. Only use hooks from trusted sources and always review hook commands before enabling them. + +To temporarily disable all hooks, use the `--no-hooks` flag: +```bash +mcphost --no-hooks +``` + +See the example hook scripts in `examples/hooks/`: +- `bash-validator.py` - Validates and blocks dangerous bash commands +- `prompt-logger.sh` - Logs all user prompts with timestamps +- `mcp-monitor.py` - Monitors and enforces policies on MCP tool usage + ### Non-Interactive Mode Run a single prompt and exit - perfect for scripting and automation: diff --git a/cmd/hooks.go b/cmd/hooks.go new file mode 100644 index 00000000..3e5dad19 --- /dev/null +++ b/cmd/hooks.go @@ -0,0 +1,169 @@ +package cmd + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/mark3labs/mcphost/internal/hooks" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +var hooksCmd = &cobra.Command{ + Use: "hooks", + Short: "Manage MCPHost hooks", + Long: "Commands for managing and testing MCPHost hooks configuration", +} + +var hooksListCmd = &cobra.Command{ + Use: "list", + Short: "List all configured hooks", + RunE: func(cmd *cobra.Command, args []string) error { + config, err := hooks.LoadHooksConfig() + if err != nil { + return fmt.Errorf("loading hooks config: %w", err) + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "EVENT\tMATCHER\tCOMMAND\tTIMEOUT") + + for event, matchers := range config.Hooks { + for _, matcher := range matchers { + for _, hook := range matcher.Hooks { + timeout := "60s" + if hook.Timeout > 0 { + timeout = fmt.Sprintf("%ds", hook.Timeout) + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\n", + event, matcher.Matcher, hook.Command, timeout) + } + } + } + + return w.Flush() + }, +} + +var hooksValidateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate hooks configuration", + RunE: func(cmd *cobra.Command, args []string) error { + config, err := hooks.LoadHooksConfig() + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + // Additional validation + if err := hooks.ValidateHookConfig(config); err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + fmt.Println("✓ Hooks configuration is valid") + return nil + }, +} + +var hooksInitCmd = &cobra.Command{ + Use: "init", + Short: "Generate example hooks configuration", + RunE: func(cmd *cobra.Command, args []string) error { + example := &hooks.HookConfig{ + Hooks: map[hooks.HookEvent][]hooks.HookMatcher{ + // PreToolUse - runs before any tool execution + hooks.PreToolUse: { + { + Matcher: "bash.*", + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] $ " + .tool_input.command' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-commands.log"`, + Timeout: 5, + }, + }, + }, + { + Matcher: ".*", // Log all tool usage + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), event: "pre", tool: .tool_name, input: .tool_input}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/all-tools.jsonl"`, + Timeout: 5, + }, + }, + }, + }, + // PostToolUse - runs after tool execution completes + hooks.PostToolUse: { + { + Matcher: "bash.*", + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), cmd: .tool_input.command, exit: .tool_response._meta.exit, stdout: (.tool_response._meta.stdout | rtrimstr("\n") | .[0:100]), stderr: (.tool_response._meta.stderr | rtrimstr("\n"))}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-audit.jsonl"`, + Timeout: 5, + }, + }, + }, + { + Matcher: "mcp__.*", // Log MCP tool responses + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), tool: .tool_name, response_preview: (.tool_response | tostring | .[0:200])}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/mcp-tools.jsonl"`, + Timeout: 5, + }, + }, + }, + }, + // UserPromptSubmit - runs when user submits a prompt + hooks.UserPromptSubmit: { + { + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] " + .prompt' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/prompts.log"`, + }, + }, + }, + }, + // Stop - runs when the main agent finishes responding + hooks.Stop: { + { + Hooks: []hooks.HookEntry{ + { + Type: "command", + Command: `jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] Session " + .session_id + " stopped"' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/sessions.log"`, + }, + }, + }, + }, + }, + } + + // Create .mcphost directory if it doesn't exist + if err := os.MkdirAll(".mcphost", 0755); err != nil { + return fmt.Errorf("creating .mcphost directory: %w", err) + } + + // Write example configuration + data, err := yaml.Marshal(example) + if err != nil { + return fmt.Errorf("marshaling example: %w", err) + } + + if err := os.WriteFile(".mcphost/hooks.yml", data, 0644); err != nil { + return fmt.Errorf("writing example: %w", err) + } + + fmt.Println("Created .mcphost/hooks.yml with example configuration") + return nil + }, +} + +func init() { + rootCmd.AddCommand(hooksCmd) + hooksCmd.AddCommand(hooksListCmd) + hooksCmd.AddCommand(hooksValidateCmd) + hooksCmd.AddCommand(hooksInitCmd) +} diff --git a/cmd/root.go b/cmd/root.go index bc55131c..8bf8426d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -9,10 +9,12 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/cloudwego/eino/schema" "github.com/mark3labs/mcphost/internal/agent" "github.com/mark3labs/mcphost/internal/config" + "github.com/mark3labs/mcphost/internal/hooks" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/session" "github.com/mark3labs/mcphost/internal/tokens" @@ -50,6 +52,9 @@ var ( // Ollama-specific parameters numGPU int32 mainGPU int32 + + // Hooks control + noHooks bool ) // agentUIAdapter adapts agent.Agent to ui.AgentInterface @@ -177,6 +182,19 @@ func initConfig() { // Set environment variable prefix viper.SetEnvPrefix("MCPHOST") viper.AutomaticEnv() + + // Load hooks configuration unless disabled + if !viper.GetBool("no-hooks") { + hooksConfig, err := hooks.LoadHooksConfig() + if err != nil { + // Hooks are optional, so just log a warning + if debugMode { + fmt.Fprintf(os.Stderr, "Warning: Failed to load hooks configuration: %v\n", err) + } + } else { + viper.Set("hooks", hooksConfig) + } + } } // loadConfigWithEnvSubstitution loads a config file with environment variable substitution @@ -230,6 +248,8 @@ func init() { BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display") rootCmd.PersistentFlags(). BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling") + rootCmd.PersistentFlags(). + BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution") // Session management flags rootCmd.PersistentFlags(). @@ -261,6 +281,7 @@ func init() { viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps")) viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream")) viper.BindPFlag("compact", rootCmd.PersistentFlags().Lookup("compact")) + viper.BindPFlag("no-hooks", rootCmd.PersistentFlags().Lookup("no-hooks")) viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url")) viper.BindPFlag("provider-api-key", rootCmd.PersistentFlags().Lookup("provider-api-key")) viper.BindPFlag("max-tokens", rootCmd.PersistentFlags().Lookup("max-tokens")) @@ -374,6 +395,7 @@ func runNormalMode(ctx context.Context) error { } defer mcpAgent.Close() + // Initialize hook executor if hooks are configured // Get model name for display modelString := viper.GetString("model") parts := strings.SplitN(modelString, ":", 2) @@ -382,6 +404,20 @@ func runNormalMode(ctx context.Context) error { modelName = parts[1] } + var hookExecutor *hooks.Executor + if hooksConfig := viper.Get("hooks"); hooksConfig != nil { + if hc, ok := hooksConfig.(*hooks.HookConfig); ok { + // Generate a session ID for this run + sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix()) + transcriptPath := "" // We could add transcript logging later + hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath) + + // Set model and interactive mode + hookExecutor.SetModel(modelString) + hookExecutor.SetInteractive(promptFlag == "") // Interactive if no prompt flag + } + } + // Create an adapter for the agent to match the UI interface agentAdapter := &agentUIAdapter{agent: mcpAgent} @@ -608,7 +644,7 @@ func runNormalMode(ctx context.Context) error { // Check if running in non-interactive mode if promptFlag != "" { - return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager) + return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager, hookExecutor) } // Quiet mode is not allowed in interactive mode @@ -616,7 +652,7 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet flag can only be used with --prompt/-p") } - return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager) + return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor) } // AgenticLoopConfig configures the behavior of the unified agentic loop @@ -672,9 +708,30 @@ func replaceMessagesHistory(messages *[]*schema.Message, sessionManager *session } // runAgenticLoop handles all execution modes with a single unified loop -func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error { +func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error { // Handle initial prompt for non-interactive modes if !config.IsInteractive && config.InitialPrompt != "" { + // Execute UserPromptSubmit hooks for non-interactive mode + if hookExecutor != nil { + input := &hooks.UserPromptSubmitInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit), + Prompt: config.InitialPrompt, + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err) + } + } + + // Check if hook blocked the prompt + if hookOutput != nil && hookOutput.Decision == "block" { + return fmt.Errorf("prompt blocked by hook: %s", hookOutput.Reason) + } + } + // Display user message (skip if quiet) if !config.Quiet && cli != nil { cli.DisplayUserMessage(config.InitialPrompt) @@ -684,7 +741,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes tempMessages := append(messages, schema.UserMessage(config.InitialPrompt)) // Process the initial prompt with tool calls - _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config) + _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor) if err != nil { // Check if this was a user cancellation if err.Error() == "generation cancelled by user" && cli != nil { @@ -712,14 +769,14 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes // Interactive loop (or continuation after non-interactive) if config.IsInteractive { - return runInteractiveLoop(ctx, mcpAgent, cli, messages, config) + return runInteractiveLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } return nil } // runAgenticStep processes a single step of the agentic loop (handles tool calls) -func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, []*schema.Message, error) { +func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) (*schema.Message, []*schema.Message, error) { var currentSpinner *ui.Spinner // Start initial spinner (skip if quiet) @@ -762,9 +819,17 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes streamingStarted = false streamingContent.Reset() + // Variables to store tool information for hooks + var currentToolName string + var currentToolArgs string + result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages, // Tool call handler - called when a tool is about to be executed func(toolName, toolArgs string) { + // Store tool info for use in execution handler + currentToolName = toolName + currentToolArgs = toolArgs + if !config.Quiet && cli != nil { // Stop spinner before displaying tool call if currentSpinner != nil { @@ -776,22 +841,66 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes }, // Tool execution handler - called when tool execution starts/ends func(toolName string, isStarting bool) { - if !config.Quiet && cli != nil { - if isStarting { + if isStarting { + // Execute PreToolUse hooks + if hookExecutor != nil { + input := &hooks.PreToolUseInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.PreToolUse), + ToolName: currentToolName, + ToolInput: json.RawMessage(currentToolArgs), + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.PreToolUse, input) + if err != nil { + // Log error but don't fail the tool execution + if debugMode { + fmt.Fprintf(os.Stderr, "Hook execution error: %v\n", err) + } + } + + // Check if hook blocked the execution + if hookOutput != nil && hookOutput.Decision == "block" { + // We need a way to cancel the tool execution + // For now, just log it + if !config.Quiet && cli != nil { + cli.DisplayInfo(fmt.Sprintf("Tool execution blocked by hook: %s", hookOutput.Reason)) + } + } + } + + if !config.Quiet && cli != nil { // Start spinner for tool execution currentSpinner = ui.NewSpinner(fmt.Sprintf("Executing %s...", toolName)) currentSpinner.Start() - } else { - // Stop spinner when tool execution completes - if currentSpinner != nil { - currentSpinner.Stop() - currentSpinner = nil - } + } + } else { + // Stop spinner when tool execution completes + if !config.Quiet && cli != nil && currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil } } }, // Tool result handler - called when a tool execution completes func(toolName, toolArgs, result string, isError bool) { + // Execute PostToolUse hooks + if hookExecutor != nil && result != "" { + input := &hooks.PostToolUseInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.PostToolUse), + ToolName: currentToolName, + ToolInput: json.RawMessage(currentToolArgs), + ToolResponse: json.RawMessage(result), + } + + _, err := hookExecutor.ExecuteHooks(ctx, hooks.PostToolUse, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "PostToolUse hook execution error: %v\n", err) + } + } + } + if !config.Quiet && cli != nil { // Parse tool result content - it might be JSON-encoded MCP content resultContent := result @@ -921,12 +1030,49 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes cli.DisplayUsageAfterResponse() } + // Execute Stop hook after agent has finished responding + executeStopHook(hookExecutor, response, "completed", config.ModelName) + // Return the final response and all conversation messages return response, conversationMessages, nil } +// executeStopHook executes the Stop hook if a hook executor is available +func executeStopHook(hookExecutor *hooks.Executor, response *schema.Message, stopReason string, modelName string) { + if hookExecutor != nil { + // Prepare metadata + var meta json.RawMessage + if response != nil { + metaData := map[string]interface{}{ + "model": modelName, + "role": string(response.Role), + "has_tool_calls": len(response.ToolCalls) > 0, + } + if metaBytes, err := json.Marshal(metaData); err == nil { + meta = json.RawMessage(metaBytes) + } + } + + responseContent := "" + if response != nil { + responseContent = response.Content + } + + input := &hooks.StopInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.Stop), + StopHookActive: true, + Response: responseContent, + StopReason: stopReason, + Meta: meta, + } + + // Execute Stop hook (ignore errors as we're exiting anyway) + hookExecutor.ExecuteHooks(context.Background(), hooks.Stop, input) + } +} + // runInteractiveLoop handles the interactive portion of the agentic loop -func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error { +func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error { for { // Get user input prompt, err := cli.GetPrompt() @@ -942,6 +1088,30 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, continue } + // Execute UserPromptSubmit hooks + if hookExecutor != nil { + input := &hooks.UserPromptSubmitInput{ + CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit), + Prompt: prompt, + } + + hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input) + if err != nil { + // Log error but don't fail + if debugMode { + fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err) + } + } + + // Check if hook blocked the prompt + if hookOutput != nil && hookOutput.Decision == "block" { + if cli != nil { + cli.DisplayInfo(fmt.Sprintf("Prompt blocked: %s", hookOutput.Reason)) + } + continue // Skip this prompt + } + } + // Handle slash commands if cli.IsSlashCommand(prompt) { result := cli.HandleSlashCommand(prompt, config.ServerNames, config.ToolNames) @@ -965,7 +1135,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, tempMessages := append(messages, schema.UserMessage(prompt)) // Process the user input with tool calls - _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config) + _, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor) if err != nil { // Check if this was a user cancellation if err.Error() == "generation cancelled by user" { @@ -983,7 +1153,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, } // runNonInteractiveMode handles the non-interactive mode execution -func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager) error { +func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager, hookExecutor *hooks.Executor) error { // Prepare data for slash commands (needed if continuing to interactive mode) var serverNames []string for name := range mcpConfig.MCPServers { @@ -1011,11 +1181,11 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C SessionManager: sessionManager, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } // runInteractiveMode handles the interactive mode execution -func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager) error { +func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error { // Configure and run unified agentic loop config := AgenticLoopConfig{ IsInteractive: true, @@ -1029,5 +1199,5 @@ func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, SessionManager: sessionManager, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } diff --git a/cmd/script.go b/cmd/script.go index 8ee02262..a0e357b4 100644 --- a/cmd/script.go +++ b/cmd/script.go @@ -8,10 +8,12 @@ import ( "os" "regexp" "strings" + "time" "github.com/cloudwego/eino/schema" "github.com/mark3labs/mcphost/internal/agent" "github.com/mark3labs/mcphost/internal/config" + "github.com/mark3labs/mcphost/internal/hooks" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/ui" "github.com/spf13/cobra" @@ -652,6 +654,21 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, cli.DisplayDebugConfig(debugConfig) } + // Initialize hooks + var hookExecutor *hooks.Executor + if hooksConfig := viper.Get("hooks"); hooksConfig != nil { + if hc, ok := hooksConfig.(*hooks.HookConfig); ok { + // Generate a session ID for this run + sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix()) + transcriptPath := "" // We could add transcript logging later + hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath) + + // Set model and interactive mode + hookExecutor.SetModel(finalModel) + hookExecutor.SetInteractive(prompt == "") + } + } + // Prepare data for slash commands var serverNames []string for name := range mcpConfig.MCPServers { @@ -679,5 +696,5 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, MCPConfig: mcpConfig, } - return runAgenticLoop(ctx, mcpAgent, cli, messages, config) + return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor) } diff --git a/examples/hooks/bash-validator.py b/examples/hooks/bash-validator.py new file mode 100755 index 00000000..e3120a24 --- /dev/null +++ b/examples/hooks/bash-validator.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +Validates bash commands before execution. +Blocks dangerous commands and suggests alternatives. +""" +import json +import sys +import re + +# Define validation rules +DANGEROUS_PATTERNS = [ + (r'\brm\s+-rf\s+/', "Dangerous command: rm -rf /"), + (r'\bdd\s+.*\bof=/dev/[sh]d[a-z]', "Direct disk write detected"), + (r'>\s*/dev/null\s+2>&1', "Consider using proper error handling instead of discarding stderr"), +] + +SUGGEST_ALTERNATIVES = { + r'\bgrep\b': "Use 'rg' (ripgrep) for better performance", + r'\bfind\s+.*-name': "Use 'fd' for faster file finding", +} + +def main(): + try: + # Read input + input_data = json.load(sys.stdin) + + # Only process bash commands + if input_data.get('tool_name') != 'bash': + sys.exit(0) + + command = json.loads(input_data.get('tool_input', '{}')).get('command', '') + + # Check dangerous patterns + for pattern, message in DANGEROUS_PATTERNS: + if re.search(pattern, command, re.IGNORECASE): + print(message, file=sys.stderr) + sys.exit(2) # Block execution + + # Suggest alternatives + suggestions = [] + for pattern, suggestion in SUGGEST_ALTERNATIVES.items(): + if re.search(pattern, command): + suggestions.append(suggestion) + + if suggestions: + output = { + "decision": "approve", + "reason": "Command approved. Suggestions: " + "; ".join(suggestions) + } + print(json.dumps(output)) + + except Exception as e: + print(f"Hook error: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/hooks/mcp-monitor.py b/examples/hooks/mcp-monitor.py new file mode 100755 index 00000000..94f86590 --- /dev/null +++ b/examples/hooks/mcp-monitor.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +""" +Monitors MCP tool usage and enforces policies. +""" +import json +import sys +import re +import os +from datetime import datetime + +# Define MCP tool policies +BLOCKED_MCP_TOOLS = [ + "mcp__github__delete_.*", # Block all GitHub delete operations + "mcp__aws__.*_production", # Block production AWS operations +] + +RATE_LIMITS = { + "mcp__openai__.*": (10, 60), # 10 calls per 60 seconds +} + +def check_rate_limit(tool_name, limits): + # This is a simplified example - real implementation would need persistent storage + # For now, just log the attempt + return True + +def main(): + try: + input_data = json.load(sys.stdin) + tool_name = input_data.get('tool_name', '') + + # Check if tool is blocked + for pattern in BLOCKED_MCP_TOOLS: + if re.match(pattern, tool_name): + output = { + "decision": "block", + "reason": f"Tool {tool_name} is blocked by security policy" + } + print(json.dumps(output)) + sys.exit(0) + + # Check rate limits + for pattern, (limit, window) in RATE_LIMITS.items(): + if re.match(pattern, tool_name): + if not check_rate_limit(tool_name, (limit, window)): + output = { + "decision": "block", + "reason": f"Rate limit exceeded: {limit} calls per {window}s" + } + print(json.dumps(output)) + sys.exit(0) + + # Log MCP tool usage + log_entry = { + "timestamp": datetime.now().isoformat(), + "tool": tool_name, + "input": input_data.get('tool_input', {}) + } + + # Use XDG_CONFIG_HOME if set, otherwise default to ~/.config + config_home = os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config')) + log_dir = os.path.join(config_home, 'mcphost', 'logs') + os.makedirs(log_dir, exist_ok=True) + + with open(os.path.join(log_dir, "mcp-usage.jsonl"), "a") as f: + f.write(json.dumps(log_entry) + "\n") + + except Exception as e: + print(f"Hook error: {e}", file=sys.stderr) + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/hooks/prompt-logger.sh b/examples/hooks/prompt-logger.sh new file mode 100755 index 00000000..2c51fdf2 --- /dev/null +++ b/examples/hooks/prompt-logger.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Logs all user prompts with timestamp + +# Read JSON input +input=$(cat) + +# Extract prompt using jq (ensure jq is installed) +prompt=$(echo "$input" | jq -r '.prompt // empty') + +if [ -n "$prompt" ]; then + # Use XDG_CONFIG_HOME if set, otherwise default to ~/.config + CONFIG_DIR="${XDG_CONFIG_HOME:-$HOME/.config}" + LOG_DIR="$CONFIG_DIR/mcphost/logs" + + # Create log directory if it doesn't exist + mkdir -p "$LOG_DIR" + + # Log with timestamp + echo "[$(date '+%Y-%m-%d %H:%M:%S')] $prompt" >> "$LOG_DIR/prompts.log" +fi + +# Always allow prompt to continue +exit 0 \ No newline at end of file diff --git a/internal/hooks/config.go b/internal/hooks/config.go new file mode 100644 index 00000000..dc59c15e --- /dev/null +++ b/internal/hooks/config.go @@ -0,0 +1,132 @@ +package hooks + +import ( + "encoding/json" + "fmt" + "github.com/mark3labs/mcphost/internal/config" + "gopkg.in/yaml.v3" + "os" + "path/filepath" +) + +// HookConfig represents the complete hooks configuration +type HookConfig struct { + Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"` +} + +// HookMatcher matches specific tools and defines hooks to execute +type HookMatcher struct { + Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"` + Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"` + Hooks []HookEntry `yaml:"hooks" json:"hooks"` +} + +// HookEntry defines a single hook command +type HookEntry struct { + Type string `yaml:"type" json:"type"` + Command string `yaml:"command" json:"command"` + Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"` +} + +// LoadHooksConfig loads and merges hook configurations from multiple sources +func LoadHooksConfig(customPaths ...string) (*HookConfig, error) { + // Get config directory following XDG Base Directory specification + configDir := getConfigDir() + + // Define search paths in order of precedence (lowest to highest) + searchPaths := []string{ + filepath.Join(configDir, "mcphost", "hooks.json"), + filepath.Join(configDir, "mcphost", "hooks.yml"), + ".mcphost/hooks.json", + ".mcphost/hooks.yml", + } + + // Add custom paths with highest precedence + searchPaths = append(searchPaths, customPaths...) + + merged := &HookConfig{ + Hooks: make(map[HookEvent][]HookMatcher), + } + + for _, path := range searchPaths { + if _, err := os.Stat(path); os.IsNotExist(err) { + continue + } + + // Read file content + content, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading %s: %w", path, err) + } + + // Apply environment substitution + envSubstituter := &config.EnvSubstituter{} + substituted, err := envSubstituter.SubstituteEnvVars(string(content)) + if err != nil { + return nil, fmt.Errorf("substituting env vars in %s: %w", path, err) + } + + // Parse configuration + var cfg HookConfig + if filepath.Ext(path) == ".json" { + err = json.Unmarshal([]byte(substituted), &cfg) + } else { + err = yaml.Unmarshal([]byte(substituted), &cfg) + } + if err != nil { + return nil, fmt.Errorf("parsing %s: %w", path, err) + } + + // Merge configurations + mergeHookConfigs(merged, &cfg) + } + + return merged, nil +} + +// getConfigDir returns the configuration directory following XDG Base Directory specification +func getConfigDir() string { + // Try XDG_CONFIG_HOME first + if xdgConfig := os.Getenv("XDG_CONFIG_HOME"); xdgConfig != "" { + return xdgConfig + } + + // Fall back to ~/.config + if home := os.Getenv("HOME"); home != "" { + return filepath.Join(home, ".config") + } + + // Last resort: current directory + return "." +} + +// mergeHookConfigs merges source hooks into destination +func mergeHookConfigs(dst, src *HookConfig) { + for event, matchers := range src.Hooks { + if dst.Hooks[event] == nil { + dst.Hooks[event] = matchers + continue + } + + // Handle merge strategies + for _, srcMatcher := range matchers { + if srcMatcher.Merge == "replace" { + // Replace all matchers for this event + dst.Hooks[event] = []HookMatcher{srcMatcher} + } else { + // Append or update existing matcher + found := false + for i, dstMatcher := range dst.Hooks[event] { + if dstMatcher.Matcher == srcMatcher.Matcher { + dst.Hooks[event][i] = srcMatcher + found = true + break + } + } + if !found { + dst.Hooks[event] = append(dst.Hooks[event], srcMatcher) + } + } + } + } +} diff --git a/internal/hooks/config_test.go b/internal/hooks/config_test.go new file mode 100644 index 00000000..2f9a2682 --- /dev/null +++ b/internal/hooks/config_test.go @@ -0,0 +1,224 @@ +package hooks + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestLoadHooksConfig(t *testing.T) { + // Save original XDG_CONFIG_HOME + originalXDG := os.Getenv("XDG_CONFIG_HOME") + defer func() { + if originalXDG != "" { + os.Setenv("XDG_CONFIG_HOME", originalXDG) + } else { + os.Unsetenv("XDG_CONFIG_HOME") + } + }() + + tests := []struct { + name string + files map[string]string + expected *HookConfig + wantErr bool + }{ + { + name: "single yaml file", + files: map[string]string{ + "hooks.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo test" + timeout: 5 +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: 5}, + }, + }, + }, + }, + }, + }, + { + name: "environment substitution", + files: map[string]string{ + "hooks.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "${env://TEST_HOOK_CMD:-echo default}" +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo default"}, + }, + }, + }, + }, + }, + }, + { + name: "merge multiple files", + files: map[string]string{ + "global.yml": ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "global-hook" +`, + "local.yml": ` +hooks: + PreToolUse: + - matcher: "fetch" + hooks: + - type: command + command: "local-hook" +`, + }, + expected: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{{Type: "command", Command: "global-hook"}}, + }, + { + Matcher: "fetch", + Hooks: []HookEntry{{Type: "command", Command: "local-hook"}}, + }, + }, + }, + }, + }, + { + name: "invalid yaml", + files: map[string]string{ + "hooks.yml": `invalid yaml content`, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary directory for test files + tmpDir := t.TempDir() + + // Set XDG_CONFIG_HOME to a temp directory to avoid loading global hooks + testConfigDir := filepath.Join(tmpDir, "config") + os.Setenv("XDG_CONFIG_HOME", testConfigDir) + + // Write test files + var paths []string + for filename, content := range tt.files { + path := filepath.Join(tmpDir, filename) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + paths = append(paths, path) + } + + // Load configuration + got, err := LoadHooksConfig(paths...) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("LoadHooksConfig() = %+v, want %+v", got, tt.expected) + } + }) + } +} + +func TestMatchesPattern(t *testing.T) { + tests := []struct { + pattern string + toolName string + want bool + }{ + {"", "bash", true}, // Empty pattern matches all + {"bash", "bash", true}, // Exact match + {"bash", "Bash", false}, // Case sensitive + {"bash|fetch", "bash", true}, // Regex OR + {"bash|fetch", "fetch", true}, // Regex OR + {"bash|fetch", "todo", false}, // Regex OR no match + {"mcp__.*", "mcp__filesystem__read", true}, // MCP pattern + {".*write.*", "mcp__fs__write_file", true}, // Contains pattern + {"^bash$", "bash", true}, // Anchored regex + {"^bash$", "bash2", false}, // Anchored regex no match + } + + for _, tt := range tests { + t.Run(tt.pattern+"_"+tt.toolName, func(t *testing.T) { + got := matchesPattern(tt.pattern, tt.toolName) + if got != tt.want { + t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.toolName, got, tt.want) + } + }) + } +} + +func TestNoHooksFlag(t *testing.T) { + // This test verifies that when hooks are disabled via configuration, + // the LoadHooksConfig function is not called. The actual implementation + // of this is in cmd/root.go where viper.GetBool("no-hooks") is checked. + // This test documents the expected behavior. + + // Create a test hooks file + tmpDir := t.TempDir() + hooksFile := filepath.Join(tmpDir, "hooks.yml") + content := ` +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo 'This should not run'" +` + if err := os.WriteFile(hooksFile, []byte(content), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + // Load the hooks config normally + config, err := LoadHooksConfig(hooksFile) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify hooks are loaded + if len(config.Hooks) == 0 { + t.Error("expected hooks to be loaded") + } + + // The actual --no-hooks flag implementation is in cmd/root.go + // where it checks viper.GetBool("no-hooks") before calling LoadHooksConfig +} diff --git a/internal/hooks/events.go b/internal/hooks/events.go new file mode 100644 index 00000000..7b61397a --- /dev/null +++ b/internal/hooks/events.go @@ -0,0 +1,32 @@ +package hooks + +// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed +type HookEvent string + +const ( + // PreToolUse fires before any tool execution + PreToolUse HookEvent = "PreToolUse" + + // PostToolUse fires after tool execution completes + PostToolUse HookEvent = "PostToolUse" + + // UserPromptSubmit fires when user submits a prompt + UserPromptSubmit HookEvent = "UserPromptSubmit" + + // Stop fires when the main agent finishes responding + Stop HookEvent = "Stop" +) + +// IsValid returns true if the event is a valid hook event +func (e HookEvent) IsValid() bool { + switch e { + case PreToolUse, PostToolUse, UserPromptSubmit, Stop: + return true + } + return false +} + +// RequiresMatcher returns true if the event uses tool matchers +func (e HookEvent) RequiresMatcher() bool { + return e == PreToolUse || e == PostToolUse +} diff --git a/internal/hooks/executor.go b/internal/hooks/executor.go new file mode 100644 index 00000000..0c050f10 --- /dev/null +++ b/internal/hooks/executor.go @@ -0,0 +1,254 @@ +package hooks + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + "os/exec" + "regexp" + "sync" + "time" +) + +// Executor handles hook execution +type Executor struct { + config *HookConfig + sessionID string + transcript string + model string + interactive bool + mu sync.RWMutex +} + +// NewExecutor creates a new hook executor +func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor { + return &Executor{ + config: config, + sessionID: sessionID, + transcript: transcriptPath, + } +} + +// SetModel sets the model name for hook context +func (e *Executor) SetModel(model string) { + e.mu.Lock() + defer e.mu.Unlock() + e.model = model +} + +// SetInteractive sets whether we're in interactive mode +func (e *Executor) SetInteractive(interactive bool) { + e.mu.Lock() + defer e.mu.Unlock() + e.interactive = interactive +} + +// PopulateCommonFields fills in the common fields for any hook input +func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput { + e.mu.RLock() + defer e.mu.RUnlock() + + cwd, _ := os.Getwd() + return CommonInput{ + SessionID: e.sessionID, + TranscriptPath: e.transcript, + CWD: cwd, + HookEventName: event, + Timestamp: time.Now().Unix(), + Model: e.model, + Interactive: e.interactive, + } +} + +// ExecuteHooks runs all matching hooks for an event +func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) { + matchers, ok := e.config.Hooks[event] + if !ok || len(matchers) == 0 { + return nil, nil + } + + // Get tool name if applicable + toolName := "" + if event.RequiresMatcher() { + toolName = extractToolName(input) + } + + // Find matching hooks + var hooksToRun []HookEntry + for _, matcher := range matchers { + if matchesPattern(matcher.Matcher, toolName) { + hooksToRun = append(hooksToRun, matcher.Hooks...) + } + } + + if len(hooksToRun) == 0 { + return nil, nil + } + + // Execute hooks in parallel + results := make(chan *hookResult, len(hooksToRun)) + var wg sync.WaitGroup + + for _, hook := range hooksToRun { + wg.Add(1) + go func(h HookEntry) { + defer wg.Done() + result := e.executeHook(ctx, h, input) + results <- result + }(hook) + } + + wg.Wait() + close(results) + + // Process results + return e.processResults(results) +} + +// executeHook runs a single hook command +func (e *Executor) executeHook(ctx context.Context, hook HookEntry, input interface{}) *hookResult { + // Prepare input JSON + inputJSON, err := json.Marshal(input) + if err != nil { + return &hookResult{err: fmt.Errorf("marshaling input: %w", err)} + } + + // Set timeout + timeout := time.Duration(hook.Timeout) * time.Second + if timeout == 0 { + timeout = 60 * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Create command + cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command) + cmd.Stdin = bytes.NewReader(inputJSON) + cmd.Dir = getCurrentWorkingDir() + + // Capture output + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Execute + err = cmd.Run() + + exitCode := 0 + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + exitCode = -1 + } + } + + return &hookResult{ + exitCode: exitCode, + stdout: stdout.String(), + stderr: stderr.String(), + err: err, + } +} + +// matchesPattern checks if a tool name matches a pattern +func matchesPattern(pattern, toolName string) bool { + if pattern == "" { + return true // Empty pattern matches all + } + + // Try exact match first + if pattern == toolName { + return true + } + + // Try regex match + matched, err := regexp.MatchString(pattern, toolName) + if err != nil { + // Invalid regex pattern, return false + return false + } + + return matched +} + +// extractToolName gets the tool name from various input types +func extractToolName(input interface{}) string { + switch v := input.(type) { + case *PreToolUseInput: + return v.ToolName + case *PostToolUseInput: + return v.ToolName + default: + return "" + } +} + +type hookResult struct { + exitCode int + stdout string + stderr string + err error +} + +// processResults combines results from multiple hooks +func (e *Executor) processResults(results <-chan *hookResult) (*HookOutput, error) { + var finalOutput HookOutput + + for result := range results { + if result.err != nil && result.exitCode != 2 { + // Hook execution failed, skip this result + continue + } + + // Handle exit code 2 (blocking error) + if result.exitCode == 2 { + finalOutput.Decision = "block" + finalOutput.Reason = result.stderr + continueVal := false + finalOutput.Continue = &continueVal + return &finalOutput, nil + } + + // Try to parse JSON output + if result.stdout != "" { + var output HookOutput + if err := json.Unmarshal([]byte(result.stdout), &output); err == nil { + // Merge outputs (later hooks can override) + mergeHookOutputs(&finalOutput, &output) + } + } + } + + return &finalOutput, nil +} + +// mergeHookOutputs combines two hook outputs +func mergeHookOutputs(dst, src *HookOutput) { + if src.Continue != nil { + dst.Continue = src.Continue + } + if src.StopReason != "" { + dst.StopReason = src.StopReason + } + if src.Decision != "" { + dst.Decision = src.Decision + } + if src.Reason != "" { + dst.Reason = src.Reason + } + if src.SuppressOutput { + dst.SuppressOutput = true + } +} + +func getCurrentWorkingDir() string { + cwd, err := os.Getwd() + if err != nil { + return "/" + } + return cwd +} diff --git a/internal/hooks/executor_test.go b/internal/hooks/executor_test.go new file mode 100644 index 00000000..a5748268 --- /dev/null +++ b/internal/hooks/executor_test.go @@ -0,0 +1,193 @@ +package hooks + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" +) + +func TestExecuteHooks(t *testing.T) { + // Create test scripts + tmpDir := t.TempDir() + + // Simple echo script + echoScript := filepath.Join(tmpDir, "echo.sh") + if err := os.WriteFile(echoScript, []byte(`#!/bin/bash +cat +`), 0755); err != nil { + t.Fatalf("failed to create echo script: %v", err) + } + + // Blocking script (exit code 2) + blockScript := filepath.Join(tmpDir, "block.sh") + if err := os.WriteFile(blockScript, []byte(`#!/bin/bash +echo "Blocked by policy" >&2 +exit 2 +`), 0755); err != nil { + t.Fatalf("failed to create block script: %v", err) + } + + // JSON output script + jsonScript := filepath.Join(tmpDir, "json.sh") + if err := os.WriteFile(jsonScript, []byte(`#!/bin/bash +echo '{"decision": "approve", "reason": "Approved by test"}' +`), 0755); err != nil { + t.Fatalf("failed to create json script: %v", err) + } + + tests := []struct { + name string + config *HookConfig + event HookEvent + input interface{} + expected *HookOutput + wantErr bool + }{ + { + name: "simple command execution", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: echoScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{}, + }, + { + name: "blocking hook", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: blockScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{ + Decision: "block", + Reason: "Blocked by policy\n", + Continue: boolPtr(false), + }, + }, + { + name: "JSON output parsing", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: jsonScript, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{ + Decision: "approve", + Reason: "Approved by test", + }, + }, + { + name: "timeout handling", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: {{ + Matcher: "bash", + Hooks: []HookEntry{{ + Type: "command", + Command: "sleep 10", + Timeout: 1, + }}, + }}, + }, + }, + event: PreToolUse, + input: &PreToolUseInput{ + CommonInput: CommonInput{HookEventName: PreToolUse}, + ToolName: "bash", + }, + expected: &HookOutput{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + executor := NewExecutor(tt.config, "test-session", "/tmp/test.jsonl") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + got, err := executor.ExecuteHooks(ctx, tt.event, tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Compare outputs + if !compareHookOutputs(got, tt.expected) { + gotJSON, _ := json.MarshalIndent(got, "", " ") + expectedJSON, _ := json.MarshalIndent(tt.expected, "", " ") + t.Errorf("ExecuteHooks() output mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, expectedJSON) + } + }) + } +} + +func boolPtr(b bool) *bool { + return &b +} + +func compareHookOutputs(a, b *HookOutput) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } + + // Compare Continue pointers + if (a.Continue == nil) != (b.Continue == nil) { + return false + } + if a.Continue != nil && *a.Continue != *b.Continue { + return false + } + + return a.StopReason == b.StopReason && + a.SuppressOutput == b.SuppressOutput && + a.Decision == b.Decision && + a.Reason == b.Reason +} diff --git a/internal/hooks/schemas.go b/internal/hooks/schemas.go new file mode 100644 index 00000000..9d79ca92 --- /dev/null +++ b/internal/hooks/schemas.go @@ -0,0 +1,55 @@ +package hooks + +import ( + "encoding/json" +) + +// CommonInput contains fields common to all hook inputs +type CommonInput struct { + SessionID string `json:"session_id"` // Unique session identifier + TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled) + CWD string `json:"cwd"` // Current working directory + HookEventName HookEvent `json:"hook_event_name"` // The hook event type + Timestamp int64 `json:"timestamp"` // Unix timestamp when hook fired + Model string `json:"model"` // AI model being used + Interactive bool `json:"interactive"` // Whether in interactive mode +} + +// PreToolUseInput is passed to PreToolUse hooks +type PreToolUseInput struct { + CommonInput + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` +} + +// PostToolUseInput is passed to PostToolUse hooks +type PostToolUseInput struct { + CommonInput + ToolName string `json:"tool_name"` + ToolInput json.RawMessage `json:"tool_input"` + ToolResponse json.RawMessage `json:"tool_response"` +} + +// UserPromptSubmitInput is passed to UserPromptSubmit hooks +type UserPromptSubmitInput struct { + CommonInput + Prompt string `json:"prompt"` +} + +// StopInput is passed to Stop hooks +type StopInput struct { + CommonInput + StopHookActive bool `json:"stop_hook_active"` + Response string `json:"response"` // The agent's final response + StopReason string `json:"stop_reason"` // "completed", "cancelled", "error" + Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info) +} + +// HookOutput represents the JSON output from a hook +type HookOutput struct { + Continue *bool `json:"continue,omitempty"` + StopReason string `json:"stopReason,omitempty"` + SuppressOutput bool `json:"suppressOutput,omitempty"` + Decision string `json:"decision,omitempty"` // "approve", "block", or "" + Reason string `json:"reason,omitempty"` +} diff --git a/internal/hooks/testdata/invalid-hooks.yml b/internal/hooks/testdata/invalid-hooks.yml new file mode 100644 index 00000000..293188bc --- /dev/null +++ b/internal/hooks/testdata/invalid-hooks.yml @@ -0,0 +1,10 @@ +hooks: + InvalidEvent: + - hooks: + - type: command + command: "echo test" + PreToolUse: + - matcher: "[invalid regex" + hooks: + - type: command + command: "echo test" \ No newline at end of file diff --git a/internal/hooks/testdata/valid-hooks.yml b/internal/hooks/testdata/valid-hooks.yml new file mode 100644 index 00000000..2bcdff74 --- /dev/null +++ b/internal/hooks/testdata/valid-hooks.yml @@ -0,0 +1,21 @@ +hooks: + PreToolUse: + - matcher: "bash" + hooks: + - type: command + command: "echo 'Executing bash command'" + timeout: 5 + - matcher: "fetch" + hooks: + - type: command + command: "echo 'Fetching URL'" + timeout: 10 + UserPromptSubmit: + - hooks: + - type: command + command: "date >> /tmp/mcphost-prompts.log" + PostToolUse: + - matcher: ".*" + hooks: + - type: command + command: "echo 'Tool execution completed'" \ No newline at end of file diff --git a/internal/hooks/validator.go b/internal/hooks/validator.go new file mode 100644 index 00000000..504b5e68 --- /dev/null +++ b/internal/hooks/validator.go @@ -0,0 +1,130 @@ +package hooks + +import ( + "fmt" + "regexp" + "strings" +) + +// Security patterns to detect potentially dangerous commands +var ( + commandInjectionPattern = regexp.MustCompile(`[;&|]|\$\(|` + "`") + pathTraversalPattern = regexp.MustCompile(`\.\.\/`) + commandSubstitutionPattern = regexp.MustCompile(`\$\([^)]+\)|` + "`" + `[^` + "`" + `]+` + "`") +) + +// validateHookCommand validates a hook command for security issues +func validateHookCommand(command string) error { + if command == "" { + return fmt.Errorf("empty command") + } + + // Check for command injection attempts + if commandInjectionPattern.MatchString(command) { + // Allow simple pipes and redirects, but check for dangerous patterns + if containsDangerousPattern(command) { + return fmt.Errorf("potential command injection detected") + } + } + + // Check for path traversal + if pathTraversalPattern.MatchString(command) { + return fmt.Errorf("path traversal detected") + } + + // Check for command substitution + if commandSubstitutionPattern.MatchString(command) { + return fmt.Errorf("command substitution detected") + } + + return nil +} + +// containsDangerousPattern checks for specific dangerous command patterns +func containsDangerousPattern(command string) bool { + dangerousPatterns := []string{ + "; rm ", + "&& rm ", + "| rm ", + "; dd ", + "&& dd ", + "| dd ", + "/dev/null 2>&1", + } + + for _, pattern := range dangerousPatterns { + if strings.Contains(command, pattern) { + return true + } + } + + // Check for multiple command separators which might indicate injection + separatorCount := 0 + for _, sep := range []string{";", "&&", "||", "|"} { + separatorCount += strings.Count(command, sep) + } + + // Allow up to 2 separators for reasonable command chaining + return separatorCount > 2 +} + +// ValidateHookConfig validates the entire hook configuration +func ValidateHookConfig(config *HookConfig) error { + if config == nil { + return fmt.Errorf("nil configuration") + } + + for event, matchers := range config.Hooks { + if !event.IsValid() { + return fmt.Errorf("invalid event: %s", event) + } + + for i, matcher := range matchers { + // Validate regex pattern if provided + if matcher.Matcher != "" { + if _, err := regexp.Compile(matcher.Matcher); err != nil { + return fmt.Errorf("invalid regex pattern in matcher %d for event %s: %w", i, event, err) + } + } + + // Validate hooks + if len(matcher.Hooks) == 0 { + return fmt.Errorf("no hooks defined for matcher %d in event %s", i, event) + } + + for j, hook := range matcher.Hooks { + if err := validateHookEntry(hook); err != nil { + return fmt.Errorf("invalid hook %d in matcher %d for event %s: %w", j, i, event, err) + } + } + } + } + + return nil +} + +// validateHookEntry validates a single hook entry +func validateHookEntry(hook HookEntry) error { + if hook.Type != "command" { + return fmt.Errorf("invalid hook type: %s (only 'command' is supported)", hook.Type) + } + + if hook.Command == "" { + return fmt.Errorf("empty command") + } + + // Basic security validation + if err := validateHookCommand(hook.Command); err != nil { + return fmt.Errorf("command validation failed: %w", err) + } + + if hook.Timeout < 0 { + return fmt.Errorf("negative timeout: %d", hook.Timeout) + } + + if hook.Timeout > 600 { // 10 minutes max + return fmt.Errorf("timeout too large: %d (max 600 seconds)", hook.Timeout) + } + + return nil +} diff --git a/internal/hooks/validator_test.go b/internal/hooks/validator_test.go new file mode 100644 index 00000000..5b62dbfd --- /dev/null +++ b/internal/hooks/validator_test.go @@ -0,0 +1,251 @@ +package hooks + +import ( + "strings" + "testing" +) + +func TestValidateHookCommand(t *testing.T) { + tests := []struct { + name string + command string + wantErr bool + errMsg string + }{ + { + name: "simple command", + command: "echo hello", + wantErr: false, + }, + { + name: "absolute path", + command: "/usr/local/bin/validator.py", + wantErr: false, + }, + { + name: "command injection attempt", + command: "echo test; rm -rf /", + wantErr: true, + errMsg: "potential command injection", + }, + { + name: "path traversal", + command: "cat ../../../etc/passwd", + wantErr: true, + errMsg: "path traversal detected", + }, + { + name: "command substitution", + command: "echo $(/bin/sh -c 'malicious')", + wantErr: true, + errMsg: "command substitution detected", + }, + { + name: "backtick substitution", + command: "echo `whoami`", + wantErr: true, + errMsg: "command substitution detected", + }, + { + name: "empty command", + command: "", + wantErr: true, + errMsg: "empty command", + }, + { + name: "simple pipe allowed", + command: "ps aux | grep process", + wantErr: false, + }, + { + name: "too many command separators", + command: "cmd1 | cmd2 && cmd3 ; cmd4", + wantErr: true, + errMsg: "potential command injection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateHookCommand(tt.command) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +} + +func TestValidateHookConfig(t *testing.T) { + tests := []struct { + name string + config *HookConfig + wantErr bool + errMsg string + }{ + { + name: "valid config", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: false, + }, + { + name: "nil config", + config: nil, + wantErr: true, + errMsg: "nil configuration", + }, + { + name: "invalid event", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + "InvalidEvent": { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid event", + }, + { + name: "invalid regex pattern", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "[invalid", + Hooks: []HookEntry{ + {Type: "command", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid regex pattern", + }, + { + name: "no hooks defined", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Matcher: "bash", + Hooks: []HookEntry{}, + }, + }, + }, + }, + wantErr: true, + errMsg: "no hooks defined", + }, + { + name: "invalid hook type", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "invalid", Command: "echo test"}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "invalid hook type", + }, + { + name: "empty command", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: ""}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "empty command", + }, + { + name: "negative timeout", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: -1}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "negative timeout", + }, + { + name: "timeout too large", + config: &HookConfig{ + Hooks: map[HookEvent][]HookMatcher{ + PreToolUse: { + { + Hooks: []HookEntry{ + {Type: "command", Command: "echo test", Timeout: 700}, + }, + }, + }, + }, + }, + wantErr: true, + errMsg: "timeout too large", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateHookConfig(tt.config) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + } + }) + } +}