diff --git a/cmd/root.go b/cmd/root.go index 6e540afa..3871456f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -384,6 +384,11 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes // Process the initial prompt with tool calls response, err := runAgenticStep(ctx, mcpAgent, cli, messages, config) if err != nil { + // Check if this was a user cancellation + if err.Error() == "generation cancelled by user" && cli != nil { + cli.DisplayCancellation() + return nil // Don't treat cancellation as an error for exit code + } return err } @@ -536,17 +541,23 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, // Display user message cli.DisplayUserMessage(prompt) - // Add user message to history - messages = append(messages, schema.UserMessage(prompt)) + // Create temporary messages with user input for processing + tempMessages := append(messages, schema.UserMessage(prompt)) // Process the user input with tool calls - response, err := runAgenticStep(ctx, mcpAgent, cli, messages, config) + response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config) if err != nil { - cli.DisplayError(fmt.Errorf("agent error: %v", err)) + // Check if this was a user cancellation + if err.Error() == "generation cancelled by user" { + cli.DisplayCancellation() + } else { + cli.DisplayError(fmt.Errorf("agent error: %v", err)) + } continue } - // Add assistant response to history + // Only add to history after successful completion + messages = append(messages, schema.UserMessage(prompt)) messages = append(messages, response) } } diff --git a/internal/agent/agent.go b/internal/agent/agent.go index ba79433c..135533d7 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "fmt" + "time" + tea "github.com/charmbracelet/bubbletea" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/components/tool" "github.com/cloudwego/eino/schema" @@ -104,10 +106,17 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message // Main loop for step := 0; a.maxSteps == 0 || step < a.maxSteps; step++ { - // Call the LLM - response, err := a.model.Generate(ctx, workingMessages, model.WithTools(toolInfos)) + // Check if context was cancelled before making LLM call + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + // Call the LLM with cancellation support + response, err := a.generateWithCancellation(ctx, workingMessages, toolInfos) if err != nil { - return nil, fmt.Errorf("failed to generate response: %v", err) + return nil, err } // Add response to working messages @@ -194,6 +203,133 @@ func (a *Agent) GetTools() []tool.BaseTool { return a.toolManager.GetTools() } +// generateWithCancellation calls the LLM with ESC key cancellation support +func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) { + // Create a cancellable context for just this LLM call + llmCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Channel to receive the LLM result + resultChan := make(chan struct { + message *schema.Message + err error + }, 1) + + // Start the LLM generation in a goroutine + go func() { + message, err := a.model.Generate(llmCtx, messages, model.WithTools(toolInfos)) + if err != nil { + err = fmt.Errorf("failed to generate response: %v", err) + } + resultChan <- struct { + message *schema.Message + err error + }{message, err} + }() + + // Start ESC key listener (Bubble Tea handles all the complexity) + escChan := make(chan bool, 1) + stopListening := make(chan bool, 1) + + go func() { + if a.listenForESC(stopListening) { + escChan <- true + } else { + escChan <- false + } + }() + + // Wait for either LLM completion or ESC key + select { + case result := <-resultChan: + // Stop the ESC listener + close(stopListening) + return result.message, result.err + case escPressed := <-escChan: + if escPressed { + cancel() // Cancel the LLM context + return nil, fmt.Errorf("generation cancelled by user") + } + // ESC listener stopped normally, wait for LLM result + result := <-resultChan + return result.message, result.err + case <-ctx.Done(): + // Stop the ESC listener + close(stopListening) + return nil, ctx.Err() + } +} + +// escListenerModel is a simple Bubble Tea model for ESC key detection +type escListenerModel struct { + escPressed chan bool +} + +func (m escListenerModel) Init() tea.Cmd { + return nil +} + +func (m escListenerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + if msg.Type == tea.KeyEsc { + // Signal ESC was pressed + select { + case m.escPressed <- true: + default: + } + return m, tea.Quit + } + } + return m, nil +} + +func (m escListenerModel) View() string { + return "" // No visual output needed +} + +// listenForESC listens for ESC key press using Bubble Tea and returns true if detected +func (a *Agent) listenForESC(stopChan chan bool) bool { + escPressed := make(chan bool, 1) + + model := escListenerModel{ + escPressed: escPressed, + } + + // Create a Bubble Tea program + p := tea.NewProgram(model, tea.WithoutRenderer()) + + // Start the program in a goroutine + go func() { + if _, err := p.Run(); err != nil { + // Program failed, try to signal completion + select { + case escPressed <- false: + default: + } + } + }() + + // Wait for either ESC key or stop signal + select { + case <-stopChan: + p.Kill() + // Give the program time to fully terminate + time.Sleep(50 * time.Millisecond) + return false + case pressed := <-escPressed: + p.Kill() + // Give the program time to fully terminate + time.Sleep(50 * time.Millisecond) + return pressed + case <-time.After(30 * time.Second): + // Timeout after 30 seconds to prevent hanging + p.Kill() + time.Sleep(50 * time.Millisecond) + return false + } +} + // Close closes the agent and cleans up resources func (a *Agent) Close() error { return a.toolManager.Close() diff --git a/internal/config/config.go b/internal/config/config.go index 1778183e..238a4bf9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,8 +48,6 @@ func (c *Config) Validate() error { return nil } - - // LoadSystemPrompt loads system prompt from file or returns the string directly func LoadSystemPrompt(input string) (string, error) { if input == "" { diff --git a/internal/ui/cli.go b/internal/ui/cli.go index bb547c82..fe8c29df 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -52,7 +52,7 @@ func (c *CLI) GetPrompt() (string, error) { var prompt string err := huh.NewForm(huh.NewGroup(huh.NewText(). - Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)"). + Title("Enter your prompt (Type /help for commands, Ctrl+C to quit, ESC to cancel generation)"). Value(&prompt). CharLimit(5000)), ).WithWidth(c.width). @@ -152,6 +152,13 @@ func (c *CLI) DisplayInfo(message string) { c.displayContainer() } +// DisplayCancellation displays a cancellation message +func (c *CLI) DisplayCancellation() { + msg := c.messageRenderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()) + c.messageContainer.AddMessage(msg) + c.displayContainer() +} + // DisplayDebugConfig displays configuration settings in debug mode using tool response block styling func (c *CLI) DisplayDebugConfig(config map[string]any) { msg := c.messageRenderer.RenderDebugConfigMessage(config, time.Now()) @@ -169,6 +176,7 @@ func (c *CLI) DisplayHelp() { - ` + "`/history`" + `: Display conversation history - ` + "`/quit`" + `: Exit the application - ` + "`Ctrl+C`" + `: Exit at any time +- ` + "`ESC`" + `: Cancel ongoing LLM generation You can also just type your message to chat with the AI assistant.`