Add comprehensive hooks system for MCPHost lifecycle events (#111)

* Add comprehensive hooks system for MCPHost lifecycle events

Implements a flexible hooks system based on Anthropic Claude Code specification:

- **Hook Events**: PreToolUse, PostToolUse, UserPromptSubmit, Stop
- **Hook Types**: Command execution with JSON input/output
- **Configuration**: XDG-compliant with layered config support
- **Security**: Command validation, timeout controls, safe execution
- **Common Fields**: Consistent session ID, timestamps, model info across all hooks

Key features:
- Hooks receive JSON via stdin and can control flow via stdout
- Pattern matching for tool-specific hooks (regex support)
- Enhanced Stop hook with agent response and metadata
- Centralized session management with consistent IDs
- Built-in examples for logging, validation, and monitoring

This enables users to:
- Log and audit all tool usage and prompts
- Implement custom security policies
- Track usage metrics and model performance
- Integrate with external systems
- Build custom workflows around MCPHost

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <noreply@opencode.ai>

* Enable hooks in script mode

Previously, hooks were only initialized and executed in normal mode but not
in script mode. This was because script mode had its own execution path that
bypassed the hook initialization code.

This fix:
- Adds hook initialization to runScriptMode function
- Creates hook executor with proper session ID and model info
- Passes the hook executor to runAgenticLoop

Now hooks work consistently across all execution modes (normal, script, and
interactive), ensuring uniform behavior for logging, validation, and monitoring.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <noreply@opencode.ai>

* Remove unnecessary hooks.local.yml pattern

The .local.yml pattern adds unnecessary complexity. Users who want project-specific
hooks that aren't committed to git can simply add .mcphost/ to their .gitignore.

This simplifies the hooks configuration loading and makes it clearer that:
- Global user hooks go in ~/.config/mcphost/hooks.yml
- Project-specific hooks go in .mcphost/hooks.yml
- Git ignore management is left to the user

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <noreply@opencode.ai>

* Fix hooks test isolation and add --no-hooks flag

- Fix TestLoadHooksConfig by setting temporary XDG_CONFIG_HOME to prevent loading global hooks
- Add --no-hooks flag to disable all hooks execution across all modes
- Update README with documentation for the new flag
- Add test to verify hooks loading behavior

This allows users to temporarily disable hooks for security or debugging purposes.

🤖 Generated with [opencode](https://opencode.ai)

Co-Authored-By: opencode <noreply@opencode.ai>

---------

Co-authored-by: opencode <noreply@opencode.ai>
This commit is contained in:
Ed Zynda
2025-07-24 13:56:33 +03:00
committed by GitHub
parent 66b7a72281
commit cdc4abfb36
17 changed files with 1897 additions and 21 deletions
+66
View File
@@ -558,6 +558,72 @@ See `examples/scripts/` for sample scripts:
- `example-script.sh` - Script with custom MCP servers
- `simple-script.sh` - Script using default config fallback
### Hooks System
MCPHost supports a powerful hooks system that allows you to execute custom commands at specific points during execution. This enables security policies, logging, custom integrations, and automated workflows.
#### Quick Start
1. Initialize a hooks configuration:
```bash
mcphost hooks init
```
2. View active hooks:
```bash
mcphost hooks list
```
3. Validate your configuration:
```bash
mcphost hooks validate
```
#### Configuration
Hooks are configured in YAML files with the following precedence (highest to lowest):
- `.mcphost/hooks.yml` (project-specific hooks)
- `$XDG_CONFIG_HOME/mcphost/hooks.yml` (user global hooks, defaults to `~/.config/mcphost/hooks.yml`)
Example configuration:
```yaml
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "/usr/local/bin/validate-bash.py"
timeout: 5
UserPromptSubmit:
- hooks:
- type: command
command: "~/.mcphost/hooks/log-prompt.sh"
```
#### Available Hook Events
- **PreToolUse**: Before any tool execution (bash, fetch, todo, MCP tools)
- **PostToolUse**: After tool execution completes
- **UserPromptSubmit**: When user submits a prompt
- **Stop**: When the agent finishes responding
- **SubagentStop**: When a subagent (Task tool) finishes
- **Notification**: When MCPHost sends notifications
#### Security
⚠️ **WARNING**: Hooks execute arbitrary commands on your system. Only use hooks from trusted sources and always review hook commands before enabling them.
To temporarily disable all hooks, use the `--no-hooks` flag:
```bash
mcphost --no-hooks
```
See the example hook scripts in `examples/hooks/`:
- `bash-validator.py` - Validates and blocks dangerous bash commands
- `prompt-logger.sh` - Logs all user prompts with timestamps
- `mcp-monitor.py` - Monitors and enforces policies on MCP tool usage
### Non-Interactive Mode
Run a single prompt and exit - perfect for scripting and automation:
+169
View File
@@ -0,0 +1,169 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"github.com/mark3labs/mcphost/internal/hooks"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
)
var hooksCmd = &cobra.Command{
Use: "hooks",
Short: "Manage MCPHost hooks",
Long: "Commands for managing and testing MCPHost hooks configuration",
}
var hooksListCmd = &cobra.Command{
Use: "list",
Short: "List all configured hooks",
RunE: func(cmd *cobra.Command, args []string) error {
config, err := hooks.LoadHooksConfig()
if err != nil {
return fmt.Errorf("loading hooks config: %w", err)
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "EVENT\tMATCHER\tCOMMAND\tTIMEOUT")
for event, matchers := range config.Hooks {
for _, matcher := range matchers {
for _, hook := range matcher.Hooks {
timeout := "60s"
if hook.Timeout > 0 {
timeout = fmt.Sprintf("%ds", hook.Timeout)
}
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
event, matcher.Matcher, hook.Command, timeout)
}
}
}
return w.Flush()
},
}
var hooksValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate hooks configuration",
RunE: func(cmd *cobra.Command, args []string) error {
config, err := hooks.LoadHooksConfig()
if err != nil {
return fmt.Errorf("validation failed: %w", err)
}
// Additional validation
if err := hooks.ValidateHookConfig(config); err != nil {
return fmt.Errorf("validation failed: %w", err)
}
fmt.Println("✓ Hooks configuration is valid")
return nil
},
}
var hooksInitCmd = &cobra.Command{
Use: "init",
Short: "Generate example hooks configuration",
RunE: func(cmd *cobra.Command, args []string) error {
example := &hooks.HookConfig{
Hooks: map[hooks.HookEvent][]hooks.HookMatcher{
// PreToolUse - runs before any tool execution
hooks.PreToolUse: {
{
Matcher: "bash.*",
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] $ " + .tool_input.command' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-commands.log"`,
Timeout: 5,
},
},
},
{
Matcher: ".*", // Log all tool usage
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), event: "pre", tool: .tool_name, input: .tool_input}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/all-tools.jsonl"`,
Timeout: 5,
},
},
},
},
// PostToolUse - runs after tool execution completes
hooks.PostToolUse: {
{
Matcher: "bash.*",
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), cmd: .tool_input.command, exit: .tool_response._meta.exit, stdout: (.tool_response._meta.stdout | rtrimstr("\n") | .[0:100]), stderr: (.tool_response._meta.stderr | rtrimstr("\n"))}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/bash-audit.jsonl"`,
Timeout: 5,
},
},
},
{
Matcher: "mcp__.*", // Log MCP tool responses
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `jq -c '{time: now | strftime("%Y-%m-%d %H:%M:%S"), tool: .tool_name, response_preview: (.tool_response | tostring | .[0:200])}' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/mcp-tools.jsonl"`,
Timeout: 5,
},
},
},
},
// UserPromptSubmit - runs when user submits a prompt
hooks.UserPromptSubmit: {
{
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `mkdir -p "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs" && jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] " + .prompt' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/prompts.log"`,
},
},
},
},
// Stop - runs when the main agent finishes responding
hooks.Stop: {
{
Hooks: []hooks.HookEntry{
{
Type: "command",
Command: `jq -r '"[" + (now | strftime("%Y-%m-%d %H:%M:%S")) + "] Session " + .session_id + " stopped"' >> "${XDG_CONFIG_HOME:-$HOME/.config}/mcphost/logs/sessions.log"`,
},
},
},
},
},
}
// Create .mcphost directory if it doesn't exist
if err := os.MkdirAll(".mcphost", 0755); err != nil {
return fmt.Errorf("creating .mcphost directory: %w", err)
}
// Write example configuration
data, err := yaml.Marshal(example)
if err != nil {
return fmt.Errorf("marshaling example: %w", err)
}
if err := os.WriteFile(".mcphost/hooks.yml", data, 0644); err != nil {
return fmt.Errorf("writing example: %w", err)
}
fmt.Println("Created .mcphost/hooks.yml with example configuration")
return nil
},
}
func init() {
rootCmd.AddCommand(hooksCmd)
hooksCmd.AddCommand(hooksListCmd)
hooksCmd.AddCommand(hooksValidateCmd)
hooksCmd.AddCommand(hooksInitCmd)
}
+190 -20
View File
@@ -9,10 +9,12 @@ import (
"os"
"path/filepath"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/mark3labs/mcphost/internal/agent"
"github.com/mark3labs/mcphost/internal/config"
"github.com/mark3labs/mcphost/internal/hooks"
"github.com/mark3labs/mcphost/internal/models"
"github.com/mark3labs/mcphost/internal/session"
"github.com/mark3labs/mcphost/internal/tokens"
@@ -50,6 +52,9 @@ var (
// Ollama-specific parameters
numGPU int32
mainGPU int32
// Hooks control
noHooks bool
)
// agentUIAdapter adapts agent.Agent to ui.AgentInterface
@@ -177,6 +182,19 @@ func initConfig() {
// Set environment variable prefix
viper.SetEnvPrefix("MCPHOST")
viper.AutomaticEnv()
// Load hooks configuration unless disabled
if !viper.GetBool("no-hooks") {
hooksConfig, err := hooks.LoadHooksConfig()
if err != nil {
// Hooks are optional, so just log a warning
if debugMode {
fmt.Fprintf(os.Stderr, "Warning: Failed to load hooks configuration: %v\n", err)
}
} else {
viper.Set("hooks", hooksConfig)
}
}
}
// loadConfigWithEnvSubstitution loads a config file with environment variable substitution
@@ -230,6 +248,8 @@ func init() {
BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display")
rootCmd.PersistentFlags().
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
rootCmd.PersistentFlags().
BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution")
// Session management flags
rootCmd.PersistentFlags().
@@ -261,6 +281,7 @@ func init() {
viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps"))
viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream"))
viper.BindPFlag("compact", rootCmd.PersistentFlags().Lookup("compact"))
viper.BindPFlag("no-hooks", rootCmd.PersistentFlags().Lookup("no-hooks"))
viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url"))
viper.BindPFlag("provider-api-key", rootCmd.PersistentFlags().Lookup("provider-api-key"))
viper.BindPFlag("max-tokens", rootCmd.PersistentFlags().Lookup("max-tokens"))
@@ -374,6 +395,7 @@ func runNormalMode(ctx context.Context) error {
}
defer mcpAgent.Close()
// Initialize hook executor if hooks are configured
// Get model name for display
modelString := viper.GetString("model")
parts := strings.SplitN(modelString, ":", 2)
@@ -382,6 +404,20 @@ func runNormalMode(ctx context.Context) error {
modelName = parts[1]
}
var hookExecutor *hooks.Executor
if hooksConfig := viper.Get("hooks"); hooksConfig != nil {
if hc, ok := hooksConfig.(*hooks.HookConfig); ok {
// Generate a session ID for this run
sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix())
transcriptPath := "" // We could add transcript logging later
hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath)
// Set model and interactive mode
hookExecutor.SetModel(modelString)
hookExecutor.SetInteractive(promptFlag == "") // Interactive if no prompt flag
}
}
// Create an adapter for the agent to match the UI interface
agentAdapter := &agentUIAdapter{agent: mcpAgent}
@@ -608,7 +644,7 @@ func runNormalMode(ctx context.Context) error {
// Check if running in non-interactive mode
if promptFlag != "" {
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager)
return runNonInteractiveMode(ctx, mcpAgent, cli, promptFlag, modelName, messages, quietFlag, noExitFlag, mcpConfig, sessionManager, hookExecutor)
}
// Quiet mode is not allowed in interactive mode
@@ -616,7 +652,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
}
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager)
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
}
// AgenticLoopConfig configures the behavior of the unified agentic loop
@@ -672,9 +708,30 @@ func replaceMessagesHistory(messages *[]*schema.Message, sessionManager *session
}
// runAgenticLoop handles all execution modes with a single unified loop
func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error {
func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error {
// Handle initial prompt for non-interactive modes
if !config.IsInteractive && config.InitialPrompt != "" {
// Execute UserPromptSubmit hooks for non-interactive mode
if hookExecutor != nil {
input := &hooks.UserPromptSubmitInput{
CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit),
Prompt: config.InitialPrompt,
}
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input)
if err != nil {
// Log error but don't fail
if debugMode {
fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err)
}
}
// Check if hook blocked the prompt
if hookOutput != nil && hookOutput.Decision == "block" {
return fmt.Errorf("prompt blocked by hook: %s", hookOutput.Reason)
}
}
// Display user message (skip if quiet)
if !config.Quiet && cli != nil {
cli.DisplayUserMessage(config.InitialPrompt)
@@ -684,7 +741,7 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
tempMessages := append(messages, schema.UserMessage(config.InitialPrompt))
// Process the initial prompt with tool calls
_, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
_, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor)
if err != nil {
// Check if this was a user cancellation
if err.Error() == "generation cancelled by user" && cli != nil {
@@ -712,14 +769,14 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
// Interactive loop (or continuation after non-interactive)
if config.IsInteractive {
return runInteractiveLoop(ctx, mcpAgent, cli, messages, config)
return runInteractiveLoop(ctx, mcpAgent, cli, messages, config, hookExecutor)
}
return nil
}
// runAgenticStep processes a single step of the agentic loop (handles tool calls)
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) (*schema.Message, []*schema.Message, error) {
func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) (*schema.Message, []*schema.Message, error) {
var currentSpinner *ui.Spinner
// Start initial spinner (skip if quiet)
@@ -762,9 +819,17 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
streamingStarted = false
streamingContent.Reset()
// Variables to store tool information for hooks
var currentToolName string
var currentToolArgs string
result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages,
// Tool call handler - called when a tool is about to be executed
func(toolName, toolArgs string) {
// Store tool info for use in execution handler
currentToolName = toolName
currentToolArgs = toolArgs
if !config.Quiet && cli != nil {
// Stop spinner before displaying tool call
if currentSpinner != nil {
@@ -776,22 +841,66 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
},
// Tool execution handler - called when tool execution starts/ends
func(toolName string, isStarting bool) {
if !config.Quiet && cli != nil {
if isStarting {
if isStarting {
// Execute PreToolUse hooks
if hookExecutor != nil {
input := &hooks.PreToolUseInput{
CommonInput: hookExecutor.PopulateCommonFields(hooks.PreToolUse),
ToolName: currentToolName,
ToolInput: json.RawMessage(currentToolArgs),
}
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.PreToolUse, input)
if err != nil {
// Log error but don't fail the tool execution
if debugMode {
fmt.Fprintf(os.Stderr, "Hook execution error: %v\n", err)
}
}
// Check if hook blocked the execution
if hookOutput != nil && hookOutput.Decision == "block" {
// We need a way to cancel the tool execution
// For now, just log it
if !config.Quiet && cli != nil {
cli.DisplayInfo(fmt.Sprintf("Tool execution blocked by hook: %s", hookOutput.Reason))
}
}
}
if !config.Quiet && cli != nil {
// Start spinner for tool execution
currentSpinner = ui.NewSpinner(fmt.Sprintf("Executing %s...", toolName))
currentSpinner.Start()
} else {
// Stop spinner when tool execution completes
if currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
}
} else {
// Stop spinner when tool execution completes
if !config.Quiet && cli != nil && currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
}
},
// Tool result handler - called when a tool execution completes
func(toolName, toolArgs, result string, isError bool) {
// Execute PostToolUse hooks
if hookExecutor != nil && result != "" {
input := &hooks.PostToolUseInput{
CommonInput: hookExecutor.PopulateCommonFields(hooks.PostToolUse),
ToolName: currentToolName,
ToolInput: json.RawMessage(currentToolArgs),
ToolResponse: json.RawMessage(result),
}
_, err := hookExecutor.ExecuteHooks(ctx, hooks.PostToolUse, input)
if err != nil {
// Log error but don't fail
if debugMode {
fmt.Fprintf(os.Stderr, "PostToolUse hook execution error: %v\n", err)
}
}
}
if !config.Quiet && cli != nil {
// Parse tool result content - it might be JSON-encoded MCP content
resultContent := result
@@ -921,12 +1030,49 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
cli.DisplayUsageAfterResponse()
}
// Execute Stop hook after agent has finished responding
executeStopHook(hookExecutor, response, "completed", config.ModelName)
// Return the final response and all conversation messages
return response, conversationMessages, nil
}
// executeStopHook executes the Stop hook if a hook executor is available
func executeStopHook(hookExecutor *hooks.Executor, response *schema.Message, stopReason string, modelName string) {
if hookExecutor != nil {
// Prepare metadata
var meta json.RawMessage
if response != nil {
metaData := map[string]interface{}{
"model": modelName,
"role": string(response.Role),
"has_tool_calls": len(response.ToolCalls) > 0,
}
if metaBytes, err := json.Marshal(metaData); err == nil {
meta = json.RawMessage(metaBytes)
}
}
responseContent := ""
if response != nil {
responseContent = response.Content
}
input := &hooks.StopInput{
CommonInput: hookExecutor.PopulateCommonFields(hooks.Stop),
StopHookActive: true,
Response: responseContent,
StopReason: stopReason,
Meta: meta,
}
// Execute Stop hook (ignore errors as we're exiting anyway)
hookExecutor.ExecuteHooks(context.Background(), hooks.Stop, input)
}
}
// runInteractiveLoop handles the interactive portion of the agentic loop
func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig) error {
func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, messages []*schema.Message, config AgenticLoopConfig, hookExecutor *hooks.Executor) error {
for {
// Get user input
prompt, err := cli.GetPrompt()
@@ -942,6 +1088,30 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
continue
}
// Execute UserPromptSubmit hooks
if hookExecutor != nil {
input := &hooks.UserPromptSubmitInput{
CommonInput: hookExecutor.PopulateCommonFields(hooks.UserPromptSubmit),
Prompt: prompt,
}
hookOutput, err := hookExecutor.ExecuteHooks(ctx, hooks.UserPromptSubmit, input)
if err != nil {
// Log error but don't fail
if debugMode {
fmt.Fprintf(os.Stderr, "UserPromptSubmit hook execution error: %v\n", err)
}
}
// Check if hook blocked the prompt
if hookOutput != nil && hookOutput.Decision == "block" {
if cli != nil {
cli.DisplayInfo(fmt.Sprintf("Prompt blocked: %s", hookOutput.Reason))
}
continue // Skip this prompt
}
}
// Handle slash commands
if cli.IsSlashCommand(prompt) {
result := cli.HandleSlashCommand(prompt, config.ServerNames, config.ToolNames)
@@ -965,7 +1135,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
tempMessages := append(messages, schema.UserMessage(prompt))
// Process the user input with tool calls
_, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
_, conversationMessages, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config, hookExecutor)
if err != nil {
// Check if this was a user cancellation
if err.Error() == "generation cancelled by user" {
@@ -983,7 +1153,7 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
}
// runNonInteractiveMode handles the non-interactive mode execution
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager) error {
func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, prompt, modelName string, messages []*schema.Message, quiet, noExit bool, mcpConfig *config.Config, sessionManager *session.Manager, hookExecutor *hooks.Executor) error {
// Prepare data for slash commands (needed if continuing to interactive mode)
var serverNames []string
for name := range mcpConfig.MCPServers {
@@ -1011,11 +1181,11 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
SessionManager: sessionManager,
}
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor)
}
// runInteractiveMode handles the interactive mode execution
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager) error {
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error {
// Configure and run unified agentic loop
config := AgenticLoopConfig{
IsInteractive: true,
@@ -1029,5 +1199,5 @@ func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
SessionManager: sessionManager,
}
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor)
}
+18 -1
View File
@@ -8,10 +8,12 @@ import (
"os"
"regexp"
"strings"
"time"
"github.com/cloudwego/eino/schema"
"github.com/mark3labs/mcphost/internal/agent"
"github.com/mark3labs/mcphost/internal/config"
"github.com/mark3labs/mcphost/internal/hooks"
"github.com/mark3labs/mcphost/internal/models"
"github.com/mark3labs/mcphost/internal/ui"
"github.com/spf13/cobra"
@@ -652,6 +654,21 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
cli.DisplayDebugConfig(debugConfig)
}
// Initialize hooks
var hookExecutor *hooks.Executor
if hooksConfig := viper.Get("hooks"); hooksConfig != nil {
if hc, ok := hooksConfig.(*hooks.HookConfig); ok {
// Generate a session ID for this run
sessionID := fmt.Sprintf("mcphost-%d", time.Now().Unix())
transcriptPath := "" // We could add transcript logging later
hookExecutor = hooks.NewExecutor(hc, sessionID, transcriptPath)
// Set model and interactive mode
hookExecutor.SetModel(finalModel)
hookExecutor.SetInteractive(prompt == "")
}
}
// Prepare data for slash commands
var serverNames []string
for name := range mcpConfig.MCPServers {
@@ -679,5 +696,5 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
MCPConfig: mcpConfig,
}
return runAgenticLoop(ctx, mcpAgent, cli, messages, config)
return runAgenticLoop(ctx, mcpAgent, cli, messages, config, hookExecutor)
}
+57
View File
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Validates bash commands before execution.
Blocks dangerous commands and suggests alternatives.
"""
import json
import sys
import re
# Define validation rules
DANGEROUS_PATTERNS = [
(r'\brm\s+-rf\s+/', "Dangerous command: rm -rf /"),
(r'\bdd\s+.*\bof=/dev/[sh]d[a-z]', "Direct disk write detected"),
(r'>\s*/dev/null\s+2>&1', "Consider using proper error handling instead of discarding stderr"),
]
SUGGEST_ALTERNATIVES = {
r'\bgrep\b': "Use 'rg' (ripgrep) for better performance",
r'\bfind\s+.*-name': "Use 'fd' for faster file finding",
}
def main():
try:
# Read input
input_data = json.load(sys.stdin)
# Only process bash commands
if input_data.get('tool_name') != 'bash':
sys.exit(0)
command = json.loads(input_data.get('tool_input', '{}')).get('command', '')
# Check dangerous patterns
for pattern, message in DANGEROUS_PATTERNS:
if re.search(pattern, command, re.IGNORECASE):
print(message, file=sys.stderr)
sys.exit(2) # Block execution
# Suggest alternatives
suggestions = []
for pattern, suggestion in SUGGEST_ALTERNATIVES.items():
if re.search(pattern, command):
suggestions.append(suggestion)
if suggestions:
output = {
"decision": "approve",
"reason": "Command approved. Suggestions: " + "; ".join(suggestions)
}
print(json.dumps(output))
except Exception as e:
print(f"Hook error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
+72
View File
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
"""
Monitors MCP tool usage and enforces policies.
"""
import json
import sys
import re
import os
from datetime import datetime
# Define MCP tool policies
BLOCKED_MCP_TOOLS = [
"mcp__github__delete_.*", # Block all GitHub delete operations
"mcp__aws__.*_production", # Block production AWS operations
]
RATE_LIMITS = {
"mcp__openai__.*": (10, 60), # 10 calls per 60 seconds
}
def check_rate_limit(tool_name, limits):
# This is a simplified example - real implementation would need persistent storage
# For now, just log the attempt
return True
def main():
try:
input_data = json.load(sys.stdin)
tool_name = input_data.get('tool_name', '')
# Check if tool is blocked
for pattern in BLOCKED_MCP_TOOLS:
if re.match(pattern, tool_name):
output = {
"decision": "block",
"reason": f"Tool {tool_name} is blocked by security policy"
}
print(json.dumps(output))
sys.exit(0)
# Check rate limits
for pattern, (limit, window) in RATE_LIMITS.items():
if re.match(pattern, tool_name):
if not check_rate_limit(tool_name, (limit, window)):
output = {
"decision": "block",
"reason": f"Rate limit exceeded: {limit} calls per {window}s"
}
print(json.dumps(output))
sys.exit(0)
# Log MCP tool usage
log_entry = {
"timestamp": datetime.now().isoformat(),
"tool": tool_name,
"input": input_data.get('tool_input', {})
}
# Use XDG_CONFIG_HOME if set, otherwise default to ~/.config
config_home = os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config'))
log_dir = os.path.join(config_home, 'mcphost', 'logs')
os.makedirs(log_dir, exist_ok=True)
with open(os.path.join(log_dir, "mcp-usage.jsonl"), "a") as f:
f.write(json.dumps(log_entry) + "\n")
except Exception as e:
print(f"Hook error: {e}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()
+23
View File
@@ -0,0 +1,23 @@
#!/bin/bash
# Logs all user prompts with timestamp
# Read JSON input
input=$(cat)
# Extract prompt using jq (ensure jq is installed)
prompt=$(echo "$input" | jq -r '.prompt // empty')
if [ -n "$prompt" ]; then
# Use XDG_CONFIG_HOME if set, otherwise default to ~/.config
CONFIG_DIR="${XDG_CONFIG_HOME:-$HOME/.config}"
LOG_DIR="$CONFIG_DIR/mcphost/logs"
# Create log directory if it doesn't exist
mkdir -p "$LOG_DIR"
# Log with timestamp
echo "[$(date '+%Y-%m-%d %H:%M:%S')] $prompt" >> "$LOG_DIR/prompts.log"
fi
# Always allow prompt to continue
exit 0
+132
View File
@@ -0,0 +1,132 @@
package hooks
import (
"encoding/json"
"fmt"
"github.com/mark3labs/mcphost/internal/config"
"gopkg.in/yaml.v3"
"os"
"path/filepath"
)
// HookConfig represents the complete hooks configuration
type HookConfig struct {
Hooks map[HookEvent][]HookMatcher `yaml:"hooks" json:"hooks"`
}
// HookMatcher matches specific tools and defines hooks to execute
type HookMatcher struct {
Matcher string `yaml:"matcher,omitempty" json:"matcher,omitempty"`
Merge string `yaml:"_merge,omitempty" json:"_merge,omitempty"`
Hooks []HookEntry `yaml:"hooks" json:"hooks"`
}
// HookEntry defines a single hook command
type HookEntry struct {
Type string `yaml:"type" json:"type"`
Command string `yaml:"command" json:"command"`
Timeout int `yaml:"timeout,omitempty" json:"timeout,omitempty"`
}
// LoadHooksConfig loads and merges hook configurations from multiple sources
func LoadHooksConfig(customPaths ...string) (*HookConfig, error) {
// Get config directory following XDG Base Directory specification
configDir := getConfigDir()
// Define search paths in order of precedence (lowest to highest)
searchPaths := []string{
filepath.Join(configDir, "mcphost", "hooks.json"),
filepath.Join(configDir, "mcphost", "hooks.yml"),
".mcphost/hooks.json",
".mcphost/hooks.yml",
}
// Add custom paths with highest precedence
searchPaths = append(searchPaths, customPaths...)
merged := &HookConfig{
Hooks: make(map[HookEvent][]HookMatcher),
}
for _, path := range searchPaths {
if _, err := os.Stat(path); os.IsNotExist(err) {
continue
}
// Read file content
content, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading %s: %w", path, err)
}
// Apply environment substitution
envSubstituter := &config.EnvSubstituter{}
substituted, err := envSubstituter.SubstituteEnvVars(string(content))
if err != nil {
return nil, fmt.Errorf("substituting env vars in %s: %w", path, err)
}
// Parse configuration
var cfg HookConfig
if filepath.Ext(path) == ".json" {
err = json.Unmarshal([]byte(substituted), &cfg)
} else {
err = yaml.Unmarshal([]byte(substituted), &cfg)
}
if err != nil {
return nil, fmt.Errorf("parsing %s: %w", path, err)
}
// Merge configurations
mergeHookConfigs(merged, &cfg)
}
return merged, nil
}
// getConfigDir returns the configuration directory following XDG Base Directory specification
func getConfigDir() string {
// Try XDG_CONFIG_HOME first
if xdgConfig := os.Getenv("XDG_CONFIG_HOME"); xdgConfig != "" {
return xdgConfig
}
// Fall back to ~/.config
if home := os.Getenv("HOME"); home != "" {
return filepath.Join(home, ".config")
}
// Last resort: current directory
return "."
}
// mergeHookConfigs merges source hooks into destination
func mergeHookConfigs(dst, src *HookConfig) {
for event, matchers := range src.Hooks {
if dst.Hooks[event] == nil {
dst.Hooks[event] = matchers
continue
}
// Handle merge strategies
for _, srcMatcher := range matchers {
if srcMatcher.Merge == "replace" {
// Replace all matchers for this event
dst.Hooks[event] = []HookMatcher{srcMatcher}
} else {
// Append or update existing matcher
found := false
for i, dstMatcher := range dst.Hooks[event] {
if dstMatcher.Matcher == srcMatcher.Matcher {
dst.Hooks[event][i] = srcMatcher
found = true
break
}
}
if !found {
dst.Hooks[event] = append(dst.Hooks[event], srcMatcher)
}
}
}
}
}
+224
View File
@@ -0,0 +1,224 @@
package hooks
import (
"os"
"path/filepath"
"reflect"
"testing"
)
func TestLoadHooksConfig(t *testing.T) {
// Save original XDG_CONFIG_HOME
originalXDG := os.Getenv("XDG_CONFIG_HOME")
defer func() {
if originalXDG != "" {
os.Setenv("XDG_CONFIG_HOME", originalXDG)
} else {
os.Unsetenv("XDG_CONFIG_HOME")
}
}()
tests := []struct {
name string
files map[string]string
expected *HookConfig
wantErr bool
}{
{
name: "single yaml file",
files: map[string]string{
"hooks.yml": `
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "echo test"
timeout: 5
`,
},
expected: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "bash",
Hooks: []HookEntry{
{Type: "command", Command: "echo test", Timeout: 5},
},
},
},
},
},
},
{
name: "environment substitution",
files: map[string]string{
"hooks.yml": `
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "${env://TEST_HOOK_CMD:-echo default}"
`,
},
expected: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "bash",
Hooks: []HookEntry{
{Type: "command", Command: "echo default"},
},
},
},
},
},
},
{
name: "merge multiple files",
files: map[string]string{
"global.yml": `
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "global-hook"
`,
"local.yml": `
hooks:
PreToolUse:
- matcher: "fetch"
hooks:
- type: command
command: "local-hook"
`,
},
expected: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "bash",
Hooks: []HookEntry{{Type: "command", Command: "global-hook"}},
},
{
Matcher: "fetch",
Hooks: []HookEntry{{Type: "command", Command: "local-hook"}},
},
},
},
},
},
{
name: "invalid yaml",
files: map[string]string{
"hooks.yml": `invalid yaml content`,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create temporary directory for test files
tmpDir := t.TempDir()
// Set XDG_CONFIG_HOME to a temp directory to avoid loading global hooks
testConfigDir := filepath.Join(tmpDir, "config")
os.Setenv("XDG_CONFIG_HOME", testConfigDir)
// Write test files
var paths []string
for filename, content := range tt.files {
path := filepath.Join(tmpDir, filename)
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
paths = append(paths, path)
}
// Load configuration
got, err := LoadHooksConfig(paths...)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(got, tt.expected) {
t.Errorf("LoadHooksConfig() = %+v, want %+v", got, tt.expected)
}
})
}
}
func TestMatchesPattern(t *testing.T) {
tests := []struct {
pattern string
toolName string
want bool
}{
{"", "bash", true}, // Empty pattern matches all
{"bash", "bash", true}, // Exact match
{"bash", "Bash", false}, // Case sensitive
{"bash|fetch", "bash", true}, // Regex OR
{"bash|fetch", "fetch", true}, // Regex OR
{"bash|fetch", "todo", false}, // Regex OR no match
{"mcp__.*", "mcp__filesystem__read", true}, // MCP pattern
{".*write.*", "mcp__fs__write_file", true}, // Contains pattern
{"^bash$", "bash", true}, // Anchored regex
{"^bash$", "bash2", false}, // Anchored regex no match
}
for _, tt := range tests {
t.Run(tt.pattern+"_"+tt.toolName, func(t *testing.T) {
got := matchesPattern(tt.pattern, tt.toolName)
if got != tt.want {
t.Errorf("matchesPattern(%q, %q) = %v, want %v", tt.pattern, tt.toolName, got, tt.want)
}
})
}
}
func TestNoHooksFlag(t *testing.T) {
// This test verifies that when hooks are disabled via configuration,
// the LoadHooksConfig function is not called. The actual implementation
// of this is in cmd/root.go where viper.GetBool("no-hooks") is checked.
// This test documents the expected behavior.
// Create a test hooks file
tmpDir := t.TempDir()
hooksFile := filepath.Join(tmpDir, "hooks.yml")
content := `
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "echo 'This should not run'"
`
if err := os.WriteFile(hooksFile, []byte(content), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
// Load the hooks config normally
config, err := LoadHooksConfig(hooksFile)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Verify hooks are loaded
if len(config.Hooks) == 0 {
t.Error("expected hooks to be loaded")
}
// The actual --no-hooks flag implementation is in cmd/root.go
// where it checks viper.GetBool("no-hooks") before calling LoadHooksConfig
}
+32
View File
@@ -0,0 +1,32 @@
package hooks
// HookEvent represents a point in MCPHost's lifecycle where hooks can be executed
type HookEvent string
const (
// PreToolUse fires before any tool execution
PreToolUse HookEvent = "PreToolUse"
// PostToolUse fires after tool execution completes
PostToolUse HookEvent = "PostToolUse"
// UserPromptSubmit fires when user submits a prompt
UserPromptSubmit HookEvent = "UserPromptSubmit"
// Stop fires when the main agent finishes responding
Stop HookEvent = "Stop"
)
// IsValid returns true if the event is a valid hook event
func (e HookEvent) IsValid() bool {
switch e {
case PreToolUse, PostToolUse, UserPromptSubmit, Stop:
return true
}
return false
}
// RequiresMatcher returns true if the event uses tool matchers
func (e HookEvent) RequiresMatcher() bool {
return e == PreToolUse || e == PostToolUse
}
+254
View File
@@ -0,0 +1,254 @@
package hooks
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"regexp"
"sync"
"time"
)
// Executor handles hook execution
type Executor struct {
config *HookConfig
sessionID string
transcript string
model string
interactive bool
mu sync.RWMutex
}
// NewExecutor creates a new hook executor
func NewExecutor(config *HookConfig, sessionID, transcriptPath string) *Executor {
return &Executor{
config: config,
sessionID: sessionID,
transcript: transcriptPath,
}
}
// SetModel sets the model name for hook context
func (e *Executor) SetModel(model string) {
e.mu.Lock()
defer e.mu.Unlock()
e.model = model
}
// SetInteractive sets whether we're in interactive mode
func (e *Executor) SetInteractive(interactive bool) {
e.mu.Lock()
defer e.mu.Unlock()
e.interactive = interactive
}
// PopulateCommonFields fills in the common fields for any hook input
func (e *Executor) PopulateCommonFields(event HookEvent) CommonInput {
e.mu.RLock()
defer e.mu.RUnlock()
cwd, _ := os.Getwd()
return CommonInput{
SessionID: e.sessionID,
TranscriptPath: e.transcript,
CWD: cwd,
HookEventName: event,
Timestamp: time.Now().Unix(),
Model: e.model,
Interactive: e.interactive,
}
}
// ExecuteHooks runs all matching hooks for an event
func (e *Executor) ExecuteHooks(ctx context.Context, event HookEvent, input interface{}) (*HookOutput, error) {
matchers, ok := e.config.Hooks[event]
if !ok || len(matchers) == 0 {
return nil, nil
}
// Get tool name if applicable
toolName := ""
if event.RequiresMatcher() {
toolName = extractToolName(input)
}
// Find matching hooks
var hooksToRun []HookEntry
for _, matcher := range matchers {
if matchesPattern(matcher.Matcher, toolName) {
hooksToRun = append(hooksToRun, matcher.Hooks...)
}
}
if len(hooksToRun) == 0 {
return nil, nil
}
// Execute hooks in parallel
results := make(chan *hookResult, len(hooksToRun))
var wg sync.WaitGroup
for _, hook := range hooksToRun {
wg.Add(1)
go func(h HookEntry) {
defer wg.Done()
result := e.executeHook(ctx, h, input)
results <- result
}(hook)
}
wg.Wait()
close(results)
// Process results
return e.processResults(results)
}
// executeHook runs a single hook command
func (e *Executor) executeHook(ctx context.Context, hook HookEntry, input interface{}) *hookResult {
// Prepare input JSON
inputJSON, err := json.Marshal(input)
if err != nil {
return &hookResult{err: fmt.Errorf("marshaling input: %w", err)}
}
// Set timeout
timeout := time.Duration(hook.Timeout) * time.Second
if timeout == 0 {
timeout = 60 * time.Second
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// Create command
cmd := exec.CommandContext(ctx, "sh", "-c", hook.Command)
cmd.Stdin = bytes.NewReader(inputJSON)
cmd.Dir = getCurrentWorkingDir()
// Capture output
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// Execute
err = cmd.Run()
exitCode := 0
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
exitCode = exitErr.ExitCode()
} else {
exitCode = -1
}
}
return &hookResult{
exitCode: exitCode,
stdout: stdout.String(),
stderr: stderr.String(),
err: err,
}
}
// matchesPattern checks if a tool name matches a pattern
func matchesPattern(pattern, toolName string) bool {
if pattern == "" {
return true // Empty pattern matches all
}
// Try exact match first
if pattern == toolName {
return true
}
// Try regex match
matched, err := regexp.MatchString(pattern, toolName)
if err != nil {
// Invalid regex pattern, return false
return false
}
return matched
}
// extractToolName gets the tool name from various input types
func extractToolName(input interface{}) string {
switch v := input.(type) {
case *PreToolUseInput:
return v.ToolName
case *PostToolUseInput:
return v.ToolName
default:
return ""
}
}
type hookResult struct {
exitCode int
stdout string
stderr string
err error
}
// processResults combines results from multiple hooks
func (e *Executor) processResults(results <-chan *hookResult) (*HookOutput, error) {
var finalOutput HookOutput
for result := range results {
if result.err != nil && result.exitCode != 2 {
// Hook execution failed, skip this result
continue
}
// Handle exit code 2 (blocking error)
if result.exitCode == 2 {
finalOutput.Decision = "block"
finalOutput.Reason = result.stderr
continueVal := false
finalOutput.Continue = &continueVal
return &finalOutput, nil
}
// Try to parse JSON output
if result.stdout != "" {
var output HookOutput
if err := json.Unmarshal([]byte(result.stdout), &output); err == nil {
// Merge outputs (later hooks can override)
mergeHookOutputs(&finalOutput, &output)
}
}
}
return &finalOutput, nil
}
// mergeHookOutputs combines two hook outputs
func mergeHookOutputs(dst, src *HookOutput) {
if src.Continue != nil {
dst.Continue = src.Continue
}
if src.StopReason != "" {
dst.StopReason = src.StopReason
}
if src.Decision != "" {
dst.Decision = src.Decision
}
if src.Reason != "" {
dst.Reason = src.Reason
}
if src.SuppressOutput {
dst.SuppressOutput = true
}
}
func getCurrentWorkingDir() string {
cwd, err := os.Getwd()
if err != nil {
return "/"
}
return cwd
}
+193
View File
@@ -0,0 +1,193 @@
package hooks
import (
"context"
"encoding/json"
"os"
"path/filepath"
"testing"
"time"
)
func TestExecuteHooks(t *testing.T) {
// Create test scripts
tmpDir := t.TempDir()
// Simple echo script
echoScript := filepath.Join(tmpDir, "echo.sh")
if err := os.WriteFile(echoScript, []byte(`#!/bin/bash
cat
`), 0755); err != nil {
t.Fatalf("failed to create echo script: %v", err)
}
// Blocking script (exit code 2)
blockScript := filepath.Join(tmpDir, "block.sh")
if err := os.WriteFile(blockScript, []byte(`#!/bin/bash
echo "Blocked by policy" >&2
exit 2
`), 0755); err != nil {
t.Fatalf("failed to create block script: %v", err)
}
// JSON output script
jsonScript := filepath.Join(tmpDir, "json.sh")
if err := os.WriteFile(jsonScript, []byte(`#!/bin/bash
echo '{"decision": "approve", "reason": "Approved by test"}'
`), 0755); err != nil {
t.Fatalf("failed to create json script: %v", err)
}
tests := []struct {
name string
config *HookConfig
event HookEvent
input interface{}
expected *HookOutput
wantErr bool
}{
{
name: "simple command execution",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {{
Matcher: "bash",
Hooks: []HookEntry{{
Type: "command",
Command: echoScript,
}},
}},
},
},
event: PreToolUse,
input: &PreToolUseInput{
CommonInput: CommonInput{HookEventName: PreToolUse},
ToolName: "bash",
},
expected: &HookOutput{},
},
{
name: "blocking hook",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {{
Matcher: "bash",
Hooks: []HookEntry{{
Type: "command",
Command: blockScript,
}},
}},
},
},
event: PreToolUse,
input: &PreToolUseInput{
CommonInput: CommonInput{HookEventName: PreToolUse},
ToolName: "bash",
},
expected: &HookOutput{
Decision: "block",
Reason: "Blocked by policy\n",
Continue: boolPtr(false),
},
},
{
name: "JSON output parsing",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {{
Matcher: "bash",
Hooks: []HookEntry{{
Type: "command",
Command: jsonScript,
}},
}},
},
},
event: PreToolUse,
input: &PreToolUseInput{
CommonInput: CommonInput{HookEventName: PreToolUse},
ToolName: "bash",
},
expected: &HookOutput{
Decision: "approve",
Reason: "Approved by test",
},
},
{
name: "timeout handling",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {{
Matcher: "bash",
Hooks: []HookEntry{{
Type: "command",
Command: "sleep 10",
Timeout: 1,
}},
}},
},
},
event: PreToolUse,
input: &PreToolUseInput{
CommonInput: CommonInput{HookEventName: PreToolUse},
ToolName: "bash",
},
expected: &HookOutput{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
executor := NewExecutor(tt.config, "test-session", "/tmp/test.jsonl")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
got, err := executor.ExecuteHooks(ctx, tt.event, tt.input)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Compare outputs
if !compareHookOutputs(got, tt.expected) {
gotJSON, _ := json.MarshalIndent(got, "", " ")
expectedJSON, _ := json.MarshalIndent(tt.expected, "", " ")
t.Errorf("ExecuteHooks() output mismatch:\ngot:\n%s\nwant:\n%s", gotJSON, expectedJSON)
}
})
}
}
func boolPtr(b bool) *bool {
return &b
}
func compareHookOutputs(a, b *HookOutput) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
// Compare Continue pointers
if (a.Continue == nil) != (b.Continue == nil) {
return false
}
if a.Continue != nil && *a.Continue != *b.Continue {
return false
}
return a.StopReason == b.StopReason &&
a.SuppressOutput == b.SuppressOutput &&
a.Decision == b.Decision &&
a.Reason == b.Reason
}
+55
View File
@@ -0,0 +1,55 @@
package hooks
import (
"encoding/json"
)
// CommonInput contains fields common to all hook inputs
type CommonInput struct {
SessionID string `json:"session_id"` // Unique session identifier
TranscriptPath string `json:"transcript_path"` // Path to transcript file (if enabled)
CWD string `json:"cwd"` // Current working directory
HookEventName HookEvent `json:"hook_event_name"` // The hook event type
Timestamp int64 `json:"timestamp"` // Unix timestamp when hook fired
Model string `json:"model"` // AI model being used
Interactive bool `json:"interactive"` // Whether in interactive mode
}
// PreToolUseInput is passed to PreToolUse hooks
type PreToolUseInput struct {
CommonInput
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
}
// PostToolUseInput is passed to PostToolUse hooks
type PostToolUseInput struct {
CommonInput
ToolName string `json:"tool_name"`
ToolInput json.RawMessage `json:"tool_input"`
ToolResponse json.RawMessage `json:"tool_response"`
}
// UserPromptSubmitInput is passed to UserPromptSubmit hooks
type UserPromptSubmitInput struct {
CommonInput
Prompt string `json:"prompt"`
}
// StopInput is passed to Stop hooks
type StopInput struct {
CommonInput
StopHookActive bool `json:"stop_hook_active"`
Response string `json:"response"` // The agent's final response
StopReason string `json:"stop_reason"` // "completed", "cancelled", "error"
Meta json.RawMessage `json:"meta,omitempty"` // Additional metadata (e.g., token usage, model info)
}
// HookOutput represents the JSON output from a hook
type HookOutput struct {
Continue *bool `json:"continue,omitempty"`
StopReason string `json:"stopReason,omitempty"`
SuppressOutput bool `json:"suppressOutput,omitempty"`
Decision string `json:"decision,omitempty"` // "approve", "block", or ""
Reason string `json:"reason,omitempty"`
}
+10
View File
@@ -0,0 +1,10 @@
hooks:
InvalidEvent:
- hooks:
- type: command
command: "echo test"
PreToolUse:
- matcher: "[invalid regex"
hooks:
- type: command
command: "echo test"
+21
View File
@@ -0,0 +1,21 @@
hooks:
PreToolUse:
- matcher: "bash"
hooks:
- type: command
command: "echo 'Executing bash command'"
timeout: 5
- matcher: "fetch"
hooks:
- type: command
command: "echo 'Fetching URL'"
timeout: 10
UserPromptSubmit:
- hooks:
- type: command
command: "date >> /tmp/mcphost-prompts.log"
PostToolUse:
- matcher: ".*"
hooks:
- type: command
command: "echo 'Tool execution completed'"
+130
View File
@@ -0,0 +1,130 @@
package hooks
import (
"fmt"
"regexp"
"strings"
)
// Security patterns to detect potentially dangerous commands
var (
commandInjectionPattern = regexp.MustCompile(`[;&|]|\$\(|` + "`")
pathTraversalPattern = regexp.MustCompile(`\.\.\/`)
commandSubstitutionPattern = regexp.MustCompile(`\$\([^)]+\)|` + "`" + `[^` + "`" + `]+` + "`")
)
// validateHookCommand validates a hook command for security issues
func validateHookCommand(command string) error {
if command == "" {
return fmt.Errorf("empty command")
}
// Check for command injection attempts
if commandInjectionPattern.MatchString(command) {
// Allow simple pipes and redirects, but check for dangerous patterns
if containsDangerousPattern(command) {
return fmt.Errorf("potential command injection detected")
}
}
// Check for path traversal
if pathTraversalPattern.MatchString(command) {
return fmt.Errorf("path traversal detected")
}
// Check for command substitution
if commandSubstitutionPattern.MatchString(command) {
return fmt.Errorf("command substitution detected")
}
return nil
}
// containsDangerousPattern checks for specific dangerous command patterns
func containsDangerousPattern(command string) bool {
dangerousPatterns := []string{
"; rm ",
"&& rm ",
"| rm ",
"; dd ",
"&& dd ",
"| dd ",
"/dev/null 2>&1",
}
for _, pattern := range dangerousPatterns {
if strings.Contains(command, pattern) {
return true
}
}
// Check for multiple command separators which might indicate injection
separatorCount := 0
for _, sep := range []string{";", "&&", "||", "|"} {
separatorCount += strings.Count(command, sep)
}
// Allow up to 2 separators for reasonable command chaining
return separatorCount > 2
}
// ValidateHookConfig validates the entire hook configuration
func ValidateHookConfig(config *HookConfig) error {
if config == nil {
return fmt.Errorf("nil configuration")
}
for event, matchers := range config.Hooks {
if !event.IsValid() {
return fmt.Errorf("invalid event: %s", event)
}
for i, matcher := range matchers {
// Validate regex pattern if provided
if matcher.Matcher != "" {
if _, err := regexp.Compile(matcher.Matcher); err != nil {
return fmt.Errorf("invalid regex pattern in matcher %d for event %s: %w", i, event, err)
}
}
// Validate hooks
if len(matcher.Hooks) == 0 {
return fmt.Errorf("no hooks defined for matcher %d in event %s", i, event)
}
for j, hook := range matcher.Hooks {
if err := validateHookEntry(hook); err != nil {
return fmt.Errorf("invalid hook %d in matcher %d for event %s: %w", j, i, event, err)
}
}
}
}
return nil
}
// validateHookEntry validates a single hook entry
func validateHookEntry(hook HookEntry) error {
if hook.Type != "command" {
return fmt.Errorf("invalid hook type: %s (only 'command' is supported)", hook.Type)
}
if hook.Command == "" {
return fmt.Errorf("empty command")
}
// Basic security validation
if err := validateHookCommand(hook.Command); err != nil {
return fmt.Errorf("command validation failed: %w", err)
}
if hook.Timeout < 0 {
return fmt.Errorf("negative timeout: %d", hook.Timeout)
}
if hook.Timeout > 600 { // 10 minutes max
return fmt.Errorf("timeout too large: %d (max 600 seconds)", hook.Timeout)
}
return nil
}
+251
View File
@@ -0,0 +1,251 @@
package hooks
import (
"strings"
"testing"
)
func TestValidateHookCommand(t *testing.T) {
tests := []struct {
name string
command string
wantErr bool
errMsg string
}{
{
name: "simple command",
command: "echo hello",
wantErr: false,
},
{
name: "absolute path",
command: "/usr/local/bin/validator.py",
wantErr: false,
},
{
name: "command injection attempt",
command: "echo test; rm -rf /",
wantErr: true,
errMsg: "potential command injection",
},
{
name: "path traversal",
command: "cat ../../../etc/passwd",
wantErr: true,
errMsg: "path traversal detected",
},
{
name: "command substitution",
command: "echo $(/bin/sh -c 'malicious')",
wantErr: true,
errMsg: "command substitution detected",
},
{
name: "backtick substitution",
command: "echo `whoami`",
wantErr: true,
errMsg: "command substitution detected",
},
{
name: "empty command",
command: "",
wantErr: true,
errMsg: "empty command",
},
{
name: "simple pipe allowed",
command: "ps aux | grep process",
wantErr: false,
},
{
name: "too many command separators",
command: "cmd1 | cmd2 && cmd3 ; cmd4",
wantErr: true,
errMsg: "potential command injection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateHookCommand(tt.command)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
func TestValidateHookConfig(t *testing.T) {
tests := []struct {
name string
config *HookConfig
wantErr bool
errMsg string
}{
{
name: "valid config",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "bash",
Hooks: []HookEntry{
{Type: "command", Command: "echo test"},
},
},
},
},
},
wantErr: false,
},
{
name: "nil config",
config: nil,
wantErr: true,
errMsg: "nil configuration",
},
{
name: "invalid event",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
"InvalidEvent": {
{
Hooks: []HookEntry{
{Type: "command", Command: "echo test"},
},
},
},
},
},
wantErr: true,
errMsg: "invalid event",
},
{
name: "invalid regex pattern",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "[invalid",
Hooks: []HookEntry{
{Type: "command", Command: "echo test"},
},
},
},
},
},
wantErr: true,
errMsg: "invalid regex pattern",
},
{
name: "no hooks defined",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Matcher: "bash",
Hooks: []HookEntry{},
},
},
},
},
wantErr: true,
errMsg: "no hooks defined",
},
{
name: "invalid hook type",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Hooks: []HookEntry{
{Type: "invalid", Command: "echo test"},
},
},
},
},
},
wantErr: true,
errMsg: "invalid hook type",
},
{
name: "empty command",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Hooks: []HookEntry{
{Type: "command", Command: ""},
},
},
},
},
},
wantErr: true,
errMsg: "empty command",
},
{
name: "negative timeout",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Hooks: []HookEntry{
{Type: "command", Command: "echo test", Timeout: -1},
},
},
},
},
},
wantErr: true,
errMsg: "negative timeout",
},
{
name: "timeout too large",
config: &HookConfig{
Hooks: map[HookEvent][]HookMatcher{
PreToolUse: {
{
Hooks: []HookEntry{
{Type: "command", Command: "echo test", Timeout: 700},
},
},
},
},
},
wantErr: true,
errMsg: "timeout too large",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateHookConfig(tt.config)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
return
}
if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
t.Errorf("error message %q does not contain %q", err.Error(), tt.errMsg)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}