diff --git a/cmd/root.go b/cmd/root.go index a3860f90..7d91f7ea 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -38,6 +38,7 @@ var ( streamFlag bool // Enable streaming output compactMode bool // Enable compact output mode scriptMCPConfig *config.Config // Used to override config in script mode + approveToolRun bool // Session management saveSessionPath string @@ -302,6 +303,8 @@ func init() { BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling") rootCmd.PersistentFlags(). BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution") + rootCmd.PersistentFlags(). + BoolVar(&approveToolRun, "approve-tool-run", false, "enable requiring user approval for every tool call") // Session management flags rootCmd.PersistentFlags(). @@ -347,6 +350,7 @@ func init() { viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers")) viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu")) viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify")) + viper.BindPFlag("approve-tool-run", rootCmd.PersistentFlags().Lookup("approve-tool-run")) // Defaults are already set in flag definitions, no need to duplicate in viper @@ -445,7 +449,8 @@ func runNormalMode(ctx context.Context) error { debugLogger = bufferedLogger } - mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ModelConfig: modelConfig, + mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ + ModelConfig: modelConfig, MCPConfig: mcpConfig, SystemPrompt: systemPrompt, MaxSteps: viper.GetInt("max-steps"), @@ -743,7 +748,8 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet flag can only be used with --prompt/-p") } - return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor) + approveToolRun := viper.GetBool("approve-tool-run") + return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, 
modelName, messages, sessionManager, hookExecutor, approveToolRun) } // AgenticLoopConfig configures the behavior of the unified agentic loop. @@ -754,6 +760,7 @@ type AgenticLoopConfig struct { IsInteractive bool // true for interactive mode, false for non-interactive InitialPrompt string // initial prompt for non-interactive mode ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit) + ApproveToolRun bool // only used in interactive mode // UI configuration Quiet bool // suppress all output except final response @@ -1103,7 +1110,27 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes currentSpinner.Start() } }, - streamingCallback, // Add streaming callback as the last parameter + // Add streaming callback handler + streamingCallback, + // Tool call approval handler - called before tool execution to get user approval + func(toolName, toolArgs string) (bool, error) { + if !config.IsInteractive || !config.ApproveToolRun { + return true, nil + } + if currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil + } + allow, err := cli.GetToolApproval(toolName, toolArgs) + if err != nil { + return false, err + } + // Start spinner again for tool calls + currentSpinner = ui.NewSpinner("Thinking...") + currentSpinner.Start() + + return allow, nil + }, ) // Make sure spinner is stopped if still running @@ -1306,6 +1333,7 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C IsInteractive: false, InitialPrompt: prompt, ContinueAfterRun: noExit, + ApproveToolRun: false, Quiet: quiet, ServerNames: serverNames, ToolNames: toolNames, @@ -1318,12 +1346,13 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C } // runInteractiveMode handles the interactive mode execution -func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, 
sessionManager *session.Manager, hookExecutor *hooks.Executor) error { +func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor, approveToolRun bool) error { // Configure and run unified agentic loop config := AgenticLoopConfig{ IsInteractive: true, InitialPrompt: "", ContinueAfterRun: false, + ApproveToolRun: approveToolRun, Quiet: false, ServerNames: serverNames, ToolNames: toolNames, diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 8a0e3cfd..4e810a3d 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -4,6 +4,9 @@ import ( "context" "encoding/json" "fmt" + "strings" + "time" + tea "github.com/charmbracelet/bubbletea" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" @@ -12,8 +15,6 @@ import ( "github.com/mark3labs/mcphost/internal/config" "github.com/mark3labs/mcphost/internal/models" "github.com/mark3labs/mcphost/internal/tools" - "strings" - "time" ) // AgentConfig holds configuration options for creating a new Agent. @@ -57,6 +58,10 @@ type StreamingResponseHandler func(content string) // It receives any text content that the model generates alongside tool calls. type ToolCallContentHandler func(content string) +// ToolApprovalHandler is a function type for handling user approval of tool calls. +// It receives the tool name and arguments, and returns true if the user approves. +type ToolApprovalHandler func(toolName, toolArgs string) (bool, error) + // Agent represents an AI agent with MCP tool integration and real-time tool call display. // It manages the interaction between an LLM and various tools through the MCP protocol. type Agent struct { @@ -128,17 +133,17 @@ type GenerateWithLoopResult struct { // It handles the conversation flow, executing tools as needed and invoking callbacks for various events. 
// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support. func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message, - onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) { - - return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil) + onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onToolApproval ToolApprovalHandler, +) (*GenerateWithLoopResult, error) { + return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil, onToolApproval) } // GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks. // It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks. // The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled. 
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message, - onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) { - + onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler, onToolApproval ToolApprovalHandler, +) (*GenerateWithLoopResult, error) { // Create a copy of messages to avoid modifying the original workingMessages := make([]*schema.Message, len(messages)) copy(workingMessages, messages) @@ -200,6 +205,19 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc // Handle tool calls for _, toolCall := range response.ToolCalls { + if onToolApproval != nil { + approved, err := onToolApproval(toolCall.Function.Name, toolCall.Function.Arguments) + if err != nil { + return nil, err + } + if !approved { + rejectedMsg := fmt.Sprintf("The user did not allow tool call %s. 
Reason: User cancelled.", toolCall.Function.Name) + toolMessage := schema.ToolMessage(rejectedMsg, toolCall.ID) + workingMessages = append(workingMessages, toolMessage) + continue + } + } + // Notify about tool call if onToolCall != nil { onToolCall(toolCall.Function.Name, toolCall.Function.Arguments) diff --git a/internal/config/config.go b/internal/config/config.go index 55a32a76..c3e365e5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -166,6 +166,7 @@ type Config struct { Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"` Theme any `json:"theme" yaml:"theme"` MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"` + ApproveToolRun bool `json:"approve-tool-run" yaml:"approve-tool-run"` // Model generation parameters MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"` diff --git a/internal/ui/cli.go b/internal/ui/cli.go index 2bfb451b..89bef174 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -13,9 +13,7 @@ import ( "golang.org/x/term" ) -var ( - promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) -) +var promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12")) // CLI manages the command-line interface for MCPHost, providing message rendering, // user input handling, and display management. It supports both standard and compact @@ -377,6 +375,22 @@ func (c *CLI) IsSlashCommand(input string) bool { return strings.HasPrefix(input, "/") } +// GetToolApproval asks the user for permission to execute the tool with the given +// arguments. Returns true if the user approves. 
+func (c *CLI) GetToolApproval(toolName, toolArgs string) (bool, error) {
+	input := NewToolApprovalInput(toolName, toolArgs, c.width)
+	p := tea.NewProgram(input)
+	finalModel, err := p.Run()
+	if err != nil {
+		return false, fmt.Errorf("running tool approval prompt: %w", err)
+	}
+
+	if finalInput, ok := finalModel.(*ToolApprovalInput); ok {
+		return finalInput.approved, nil
+	}
+	return false, fmt.Errorf("GetToolApproval: unexpected model type %T", finalModel)
+}
+
 // SlashCommandResult encapsulates the outcome of processing a slash command,
 // indicating whether the command was recognized and handled, and whether the
 // conversation history should be cleared as a result of the command.
diff --git a/internal/ui/tool_approval_input.go b/internal/ui/tool_approval_input.go
new file mode 100644
index 00000000..01970ccf
--- /dev/null
+++ b/internal/ui/tool_approval_input.go
@@ -0,0 +1,128 @@
+package ui
+
+import (
+	"fmt"
+	"strings"
+
+	tea "github.com/charmbracelet/bubbletea"
+	"github.com/charmbracelet/lipgloss"
+)
+
+// ToolApprovalInput is a Bubble Tea model that asks the user to allow or deny
+// a single tool call. It renders the tool name and arguments together with a
+// yes/no selector and records the user's decision.
+type ToolApprovalInput struct {
+	toolName string
+	toolArgs string
+	width    int
+	selected bool // true when "yes" is highlighted and false when "no" is
+	approved bool // the user's decision; only meaningful once done is true
+	done     bool // set once the user has answered or cancelled
+}
+
+// NewToolApprovalInput returns an approval prompt for the given tool call.
+// width is the width used to lay out the prompt. "yes" starts highlighted, so
+// pressing Enter without moving the selection approves the call.
+func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
+	return &ToolApprovalInput{
+		toolName: toolName,
+		toolArgs: toolArgs,
+		width:    width,
+		selected: true,
+	}
+}
+
+// Init implements tea.Model. The prompt has no cursor or animation, so no
+// initial command is needed.
+func (t *ToolApprovalInput) Init() tea.Cmd {
+	return nil
+}
+
+// Update implements tea.Model. "y"/"n" answer immediately, left/right move the
+// highlight, Enter confirms the highlighted option, and Esc or Ctrl+C deny the
+// call. All other messages are ignored.
+func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+	key, ok := msg.(tea.KeyMsg)
+	if !ok {
+		return t, nil
+	}
+
+	switch key.String() {
+	case "y", "Y":
+		return t.finish(true)
+	case "n", "N", "esc", "ctrl+c":
+		return t.finish(false)
+	case "enter":
+		return t.finish(t.selected)
+	case "left":
+		t.selected = true
+	case "right":
+		t.selected = false
+	}
+	return t, nil
+}
+
+// finish records the user's decision and tells Bubble Tea to quit.
+func (t *ToolApprovalInput) finish(approved bool) (tea.Model, tea.Cmd) {
+	t.approved = approved
+	t.done = true
+	return t, tea.Quit
+}
+
+// View implements tea.Model and renders the approval prompt.
+func (t *ToolApprovalInput) View() string {
+	if t.done {
+		// Nothing to show once a decision is made. Returning an empty string
+		// keeps leftover placeholder text out of the rendered output.
+		return ""
+	}
+
+	// Add left padding to entire component (2 spaces like other UI elements)
+	containerStyle := lipgloss.NewStyle().PaddingLeft(2)
+
+	// Title
+	titleStyle := lipgloss.NewStyle().
+		Foreground(lipgloss.Color("252")).
+		MarginBottom(1)
+
+	// Prompt box: only a thick left border is drawn
+	inputBoxStyle := lipgloss.NewStyle().
+		Border(lipgloss.ThickBorder()).
+		BorderLeft(true).
+		BorderRight(false).
+		BorderTop(false).
+		BorderBottom(false).
+		BorderForeground(lipgloss.Color("39")).
+		PaddingLeft(1).
+		Width(t.width - 2) // Account for container padding
+
+	// Style for the currently selected/highlighted option
+	selectedStyle := lipgloss.NewStyle().
+		Foreground(lipgloss.Color("42")). // Bright green
+		Bold(true).
+		Underline(true)
+
+	// Style for the unselected/unhighlighted option
+	unselectedStyle := lipgloss.NewStyle().
+		Foreground(lipgloss.Color("240")) // Dark gray
+
+	// Build the view
+	var view strings.Builder
+	view.WriteString(titleStyle.Render("Allow tool execution"))
+	view.WriteString("\n")
+	details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
+	view.WriteString(details)
+	view.WriteString("Allow tool execution: ")
+
+	var yesText, noText string
+	if t.selected {
+		yesText = selectedStyle.Render("[y]es")
+		noText = unselectedStyle.Render("[n]o")
+	} else {
+		yesText = unselectedStyle.Render("[y]es")
+		noText = selectedStyle.Render("[n]o")
+	}
+	view.WriteString(yesText + "/" + noText + "\n")
+
+	return containerStyle.Render(inputBoxStyle.Render(view.String()))
+}
diff --git a/sdk/mcphost.go b/sdk/mcphost.go
index 95c01db2..b598e5ca 100644
--- a/sdk/mcphost.go
+++ b/sdk/mcphost.go
@@ -142,6 +142,7 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
 		nil, // onToolResult
 		nil, // onResponse
 		nil, // onToolCallContent
+		nil, // onToolApproval
 	)
 	if err != nil {
 		return "", err
@@ -181,6 +182,7 @@ func (m *MCPHost) PromptWithCallbacks(
 		nil, // onResponse
 		nil, // onToolCallContent
 		onStreaming,
+		nil, // onToolApproval
 	)
 	if err != nil {
 		return "", err