mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Feat: Add option to require approval before tool execution (#140)
Adds a new CLI option, `--approve-tool-run` (or via config setting), that when enabled, prompts the user to approve a tool's execution before it runs. This option is disabled by default to maintain existing behavior.
This commit is contained in:
+33
-4
@@ -38,6 +38,7 @@ var (
|
||||
streamFlag bool // Enable streaming output
|
||||
compactMode bool // Enable compact output mode
|
||||
scriptMCPConfig *config.Config // Used to override config in script mode
|
||||
approveToolRun bool
|
||||
|
||||
// Session management
|
||||
saveSessionPath string
|
||||
@@ -302,6 +303,8 @@ func init() {
|
||||
BoolVar(&compactMode, "compact", false, "enable compact output mode without fancy styling")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noHooks, "no-hooks", false, "disable all hooks execution")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&approveToolRun, "approve-tool-run", false, "enable requiring user approval for every tool call")
|
||||
|
||||
// Session management flags
|
||||
rootCmd.PersistentFlags().
|
||||
@@ -347,6 +350,7 @@ func init() {
|
||||
viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
|
||||
viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
|
||||
viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
|
||||
viper.BindPFlag("approve-tool-run", rootCmd.PersistentFlags().Lookup("approve-tool-run"))
|
||||
|
||||
// Defaults are already set in flag definitions, no need to duplicate in viper
|
||||
|
||||
@@ -445,7 +449,8 @@ func runNormalMode(ctx context.Context) error {
|
||||
debugLogger = bufferedLogger
|
||||
}
|
||||
|
||||
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ModelConfig: modelConfig,
|
||||
mcpAgent, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: mcpConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
@@ -743,7 +748,8 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet flag can only be used with --prompt/-p")
|
||||
}
|
||||
|
||||
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor)
|
||||
approveToolRun := viper.GetBool("approve-tool-run")
|
||||
return runInteractiveMode(ctx, mcpAgent, cli, serverNames, toolNames, modelName, messages, sessionManager, hookExecutor, approveToolRun)
|
||||
}
|
||||
|
||||
// AgenticLoopConfig configures the behavior of the unified agentic loop.
|
||||
@@ -754,6 +760,7 @@ type AgenticLoopConfig struct {
|
||||
IsInteractive bool // true for interactive mode, false for non-interactive
|
||||
InitialPrompt string // initial prompt for non-interactive mode
|
||||
ContinueAfterRun bool // true to continue to interactive mode after initial run (--no-exit)
|
||||
ApproveToolRun bool // only used in interactive mode
|
||||
|
||||
// UI configuration
|
||||
Quiet bool // suppress all output except final response
|
||||
@@ -1103,7 +1110,27 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
currentSpinner.Start()
|
||||
}
|
||||
},
|
||||
streamingCallback, // Add streaming callback as the last parameter
|
||||
// Add streaming callback handler
|
||||
streamingCallback,
|
||||
// Tool call approval handler - called before tool execution to get user approval
|
||||
func(toolName, toolArgs string) (bool, error) {
|
||||
if !config.IsInteractive || !config.ApproveToolRun {
|
||||
return true, nil
|
||||
}
|
||||
if currentSpinner != nil {
|
||||
currentSpinner.Stop()
|
||||
currentSpinner = nil
|
||||
}
|
||||
allow, err := cli.GetToolApproval(toolName, toolArgs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
// Start spinner again for tool calls
|
||||
currentSpinner = ui.NewSpinner("Thinking...")
|
||||
currentSpinner.Start()
|
||||
|
||||
return allow, nil
|
||||
},
|
||||
)
|
||||
|
||||
// Make sure spinner is stopped if still running
|
||||
@@ -1306,6 +1333,7 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
|
||||
IsInteractive: false,
|
||||
InitialPrompt: prompt,
|
||||
ContinueAfterRun: noExit,
|
||||
ApproveToolRun: false,
|
||||
Quiet: quiet,
|
||||
ServerNames: serverNames,
|
||||
ToolNames: toolNames,
|
||||
@@ -1318,12 +1346,13 @@ func runNonInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.C
|
||||
}
|
||||
|
||||
// runInteractiveMode handles the interactive mode execution
|
||||
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor) error {
|
||||
func runInteractiveMode(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, serverNames, toolNames []string, modelName string, messages []*schema.Message, sessionManager *session.Manager, hookExecutor *hooks.Executor, approveToolRun bool) error {
|
||||
// Configure and run unified agentic loop
|
||||
config := AgenticLoopConfig{
|
||||
IsInteractive: true,
|
||||
InitialPrompt: "",
|
||||
ContinueAfterRun: false,
|
||||
ApproveToolRun: approveToolRun,
|
||||
Quiet: false,
|
||||
ServerNames: serverNames,
|
||||
ToolNames: toolNames,
|
||||
|
||||
+25
-7
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
@@ -12,8 +15,6 @@ import (
|
||||
"github.com/mark3labs/mcphost/internal/config"
|
||||
"github.com/mark3labs/mcphost/internal/models"
|
||||
"github.com/mark3labs/mcphost/internal/tools"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AgentConfig holds configuration options for creating a new Agent.
|
||||
@@ -57,6 +58,10 @@ type StreamingResponseHandler func(content string)
|
||||
// It receives any text content that the model generates alongside tool calls.
|
||||
type ToolCallContentHandler func(content string)
|
||||
|
||||
// ToolApprovalHandler is a function type for handling user approval of tool calls.
|
||||
// It receives the tool name and arguments, and returns true if the user approves.
|
||||
type ToolApprovalHandler func(toolName, toolArgs string) (bool, error)
|
||||
|
||||
// Agent represents an AI agent with MCP tool integration and real-time tool call display.
|
||||
// It manages the interaction between an LLM and various tools through the MCP protocol.
|
||||
type Agent struct {
|
||||
@@ -128,17 +133,17 @@ type GenerateWithLoopResult struct {
|
||||
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events.
|
||||
// This method does not support streaming responses; use GenerateWithLoopAndStreaming for streaming support.
|
||||
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
|
||||
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onToolApproval ToolApprovalHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil, onToolApproval)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks.
|
||||
// It handles the conversation flow, executing tools as needed and invoking callbacks for various events including streaming chunks.
|
||||
// The onStreamingResponse callback is invoked for each content chunk during streaming if streaming is enabled.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {
|
||||
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler, onToolApproval ToolApprovalHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
// Create a copy of messages to avoid modifying the original
|
||||
workingMessages := make([]*schema.Message, len(messages))
|
||||
copy(workingMessages, messages)
|
||||
@@ -200,6 +205,19 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc
|
||||
|
||||
// Handle tool calls
|
||||
for _, toolCall := range response.ToolCalls {
|
||||
if onToolApproval != nil {
|
||||
approved, err := onToolApproval(toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !approved {
|
||||
rejectedMsg := fmt.Sprintf("The user did not allow tool call %s. Reason: User cancelled.", toolCall.Function.Name)
|
||||
toolMessage := schema.ToolMessage(rejectedMsg, toolCall.ID)
|
||||
workingMessages = append(workingMessages, toolMessage)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Notify about tool call
|
||||
if onToolCall != nil {
|
||||
onToolCall(toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
|
||||
@@ -166,6 +166,7 @@ type Config struct {
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
MarkdownTheme any `json:"markdown-theme" yaml:"markdown-theme"`
|
||||
ApproveToolRun bool `json:"approve-tool-run" yaml:"approve-tool-run"`
|
||||
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
|
||||
+17
-3
@@ -13,9 +13,7 @@ import (
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
var (
|
||||
promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
|
||||
)
|
||||
var promptStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("12"))
|
||||
|
||||
// CLI manages the command-line interface for MCPHost, providing message rendering,
|
||||
// user input handling, and display management. It supports both standard and compact
|
||||
@@ -377,6 +375,22 @@ func (c *CLI) IsSlashCommand(input string) bool {
|
||||
return strings.HasPrefix(input, "/")
|
||||
}
|
||||
|
||||
// GetToolApproval asks the user for permission to execute the tool with the given
|
||||
// arguments. Returns true if the user approves.
|
||||
func (c *CLI) GetToolApproval(toolName, toolArgs string) (bool, error) {
|
||||
input := NewToolApprovalInput(toolName, toolArgs, c.width)
|
||||
p := tea.NewProgram(input)
|
||||
finalModel, err := p.Run()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if finalInput, ok := finalModel.(*ToolApprovalInput); ok {
|
||||
return finalInput.approved, nil
|
||||
}
|
||||
return false, fmt.Errorf("GetToolApproval: unexpected error type")
|
||||
}
|
||||
|
||||
// SlashCommandResult encapsulates the outcome of processing a slash command,
|
||||
// indicating whether the command was recognized and handled, and whether the
|
||||
// conversation history should be cleared as a result of the command.
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textarea"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
type ToolApprovalInput struct {
|
||||
textarea textarea.Model
|
||||
toolName string
|
||||
toolArgs string
|
||||
width int
|
||||
selected bool // true when "yes" is highlighted and false when "no" is
|
||||
approved bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 1000
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Style the textarea to match huh theme
|
||||
ta.FocusedStyle.Base = lipgloss.NewStyle()
|
||||
ta.FocusedStyle.Placeholder = lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
|
||||
ta.FocusedStyle.Text = lipgloss.NewStyle().Foreground(lipgloss.Color("252"))
|
||||
ta.FocusedStyle.Prompt = lipgloss.NewStyle()
|
||||
ta.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
||||
ta.Cursor.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("39"))
|
||||
|
||||
return &ToolApprovalInput{
|
||||
textarea: ta,
|
||||
toolName: toolName,
|
||||
toolArgs: toolArgs,
|
||||
width: width,
|
||||
selected: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
t.approved = true
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "n", "N":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "left":
|
||||
t.selected = true
|
||||
return t, nil
|
||||
case "right":
|
||||
t.selected = false
|
||||
return t, nil
|
||||
case "enter":
|
||||
t.approved = t.selected
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "esc", "ctrl+c":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) View() string {
|
||||
if t.done {
|
||||
return "we are done"
|
||||
}
|
||||
// Add left padding to entire component (2 spaces like other UI elements)
|
||||
containerStyle := lipgloss.NewStyle().PaddingLeft(2)
|
||||
|
||||
// Title
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("252")).
|
||||
MarginBottom(1)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(lipgloss.Color("39")).
|
||||
PaddingLeft(1).
|
||||
Width(t.width - 2) // Account for container padding
|
||||
|
||||
// Style for the currently selected/highlighted option
|
||||
selectedStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("42")). // Bright green
|
||||
Bold(true).
|
||||
Underline(true)
|
||||
|
||||
// Style for the unselected/unhighlighted option
|
||||
unselectedStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("240")) // Dark gray
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render("Allow tool execution"))
|
||||
view.WriteString("\n")
|
||||
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
|
||||
view.WriteString(details)
|
||||
view.WriteString("Allow tool execution: ")
|
||||
|
||||
var yesText, noText string
|
||||
if t.selected {
|
||||
yesText = selectedStyle.Render("[y]es")
|
||||
noText = unselectedStyle.Render("[n]o")
|
||||
} else {
|
||||
yesText = unselectedStyle.Render("[y]es")
|
||||
noText = selectedStyle.Render("[n]o")
|
||||
}
|
||||
view.WriteString(yesText + "/" + noText + "\n")
|
||||
|
||||
return containerStyle.Render(inputBoxStyle.Render(view.String()))
|
||||
}
|
||||
@@ -142,6 +142,7 @@ func (m *MCPHost) Prompt(ctx context.Context, message string) (string, error) {
|
||||
nil, // onToolResult
|
||||
nil, // onResponse
|
||||
nil, // onToolCallContent
|
||||
nil, // onToolApproval
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -181,6 +182,7 @@ func (m *MCPHost) PromptWithCallbacks(
|
||||
nil, // onResponse
|
||||
nil, // onToolCallContent
|
||||
onStreaming,
|
||||
nil, // onToolApproval
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
Reference in New Issue
Block a user