mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
add cancellation
This commit is contained in:
+16
-5
@@ -384,6 +384,11 @@ func runAgenticLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
// Process the initial prompt with tool calls
|
||||
response, err := runAgenticStep(ctx, mcpAgent, cli, messages, config)
|
||||
if err != nil {
|
||||
// Check if this was a user cancellation
|
||||
if err.Error() == "generation cancelled by user" && cli != nil {
|
||||
cli.DisplayCancellation()
|
||||
return nil // Don't treat cancellation as an error for exit code
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -536,17 +541,23 @@ func runInteractiveLoop(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI,
|
||||
// Display user message
|
||||
cli.DisplayUserMessage(prompt)
|
||||
|
||||
// Add user message to history
|
||||
messages = append(messages, schema.UserMessage(prompt))
|
||||
// Create temporary messages with user input for processing
|
||||
tempMessages := append(messages, schema.UserMessage(prompt))
|
||||
|
||||
// Process the user input with tool calls
|
||||
response, err := runAgenticStep(ctx, mcpAgent, cli, messages, config)
|
||||
response, err := runAgenticStep(ctx, mcpAgent, cli, tempMessages, config)
|
||||
if err != nil {
|
||||
cli.DisplayError(fmt.Errorf("agent error: %v", err))
|
||||
// Check if this was a user cancellation
|
||||
if err.Error() == "generation cancelled by user" {
|
||||
cli.DisplayCancellation()
|
||||
} else {
|
||||
cli.DisplayError(fmt.Errorf("agent error: %v", err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Add assistant response to history
|
||||
// Only add to history after successful completion
|
||||
messages = append(messages, schema.UserMessage(prompt))
|
||||
messages = append(messages, response)
|
||||
}
|
||||
}
|
||||
|
||||
+139
-3
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -104,10 +106,17 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
|
||||
|
||||
// Main loop
|
||||
for step := 0; a.maxSteps == 0 || step < a.maxSteps; step++ {
|
||||
// Call the LLM
|
||||
response, err := a.model.Generate(ctx, workingMessages, model.WithTools(toolInfos))
|
||||
// Check if context was cancelled before making LLM call
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Call the LLM with cancellation support
|
||||
response, err := a.generateWithCancellation(ctx, workingMessages, toolInfos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate response: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add response to working messages
|
||||
@@ -194,6 +203,133 @@ func (a *Agent) GetTools() []tool.BaseTool {
|
||||
return a.toolManager.GetTools()
|
||||
}
|
||||
|
||||
// generateWithCancellation calls the LLM with ESC key cancellation support
|
||||
func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) {
|
||||
// Create a cancellable context for just this LLM call
|
||||
llmCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Channel to receive the LLM result
|
||||
resultChan := make(chan struct {
|
||||
message *schema.Message
|
||||
err error
|
||||
}, 1)
|
||||
|
||||
// Start the LLM generation in a goroutine
|
||||
go func() {
|
||||
message, err := a.model.Generate(llmCtx, messages, model.WithTools(toolInfos))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to generate response: %v", err)
|
||||
}
|
||||
resultChan <- struct {
|
||||
message *schema.Message
|
||||
err error
|
||||
}{message, err}
|
||||
}()
|
||||
|
||||
// Start ESC key listener (Bubble Tea handles all the complexity)
|
||||
escChan := make(chan bool, 1)
|
||||
stopListening := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
if a.listenForESC(stopListening) {
|
||||
escChan <- true
|
||||
} else {
|
||||
escChan <- false
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either LLM completion or ESC key
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
// Stop the ESC listener
|
||||
close(stopListening)
|
||||
return result.message, result.err
|
||||
case escPressed := <-escChan:
|
||||
if escPressed {
|
||||
cancel() // Cancel the LLM context
|
||||
return nil, fmt.Errorf("generation cancelled by user")
|
||||
}
|
||||
// ESC listener stopped normally, wait for LLM result
|
||||
result := <-resultChan
|
||||
return result.message, result.err
|
||||
case <-ctx.Done():
|
||||
// Stop the ESC listener
|
||||
close(stopListening)
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// escListenerModel is a simple Bubble Tea model for ESC key detection
|
||||
type escListenerModel struct {
|
||||
escPressed chan bool
|
||||
}
|
||||
|
||||
func (m escListenerModel) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m escListenerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if msg.Type == tea.KeyEsc {
|
||||
// Signal ESC was pressed
|
||||
select {
|
||||
case m.escPressed <- true:
|
||||
default:
|
||||
}
|
||||
return m, tea.Quit
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m escListenerModel) View() string {
|
||||
return "" // No visual output needed
|
||||
}
|
||||
|
||||
// listenForESC listens for ESC key press using Bubble Tea and returns true if detected
|
||||
func (a *Agent) listenForESC(stopChan chan bool) bool {
|
||||
escPressed := make(chan bool, 1)
|
||||
|
||||
model := escListenerModel{
|
||||
escPressed: escPressed,
|
||||
}
|
||||
|
||||
// Create a Bubble Tea program
|
||||
p := tea.NewProgram(model, tea.WithoutRenderer())
|
||||
|
||||
// Start the program in a goroutine
|
||||
go func() {
|
||||
if _, err := p.Run(); err != nil {
|
||||
// Program failed, try to signal completion
|
||||
select {
|
||||
case escPressed <- false:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for either ESC key or stop signal
|
||||
select {
|
||||
case <-stopChan:
|
||||
p.Kill()
|
||||
// Give the program time to fully terminate
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return false
|
||||
case pressed := <-escPressed:
|
||||
p.Kill()
|
||||
// Give the program time to fully terminate
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return pressed
|
||||
case <-time.After(30 * time.Second):
|
||||
// Timeout after 30 seconds to prevent hanging
|
||||
p.Kill()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources
|
||||
func (a *Agent) Close() error {
|
||||
return a.toolManager.Close()
|
||||
|
||||
@@ -48,8 +48,6 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// LoadSystemPrompt loads system prompt from file or returns the string directly
|
||||
func LoadSystemPrompt(input string) (string, error) {
|
||||
if input == "" {
|
||||
|
||||
+9
-1
@@ -52,7 +52,7 @@ func (c *CLI) GetPrompt() (string, error) {
|
||||
|
||||
var prompt string
|
||||
err := huh.NewForm(huh.NewGroup(huh.NewText().
|
||||
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit)").
|
||||
Title("Enter your prompt (Type /help for commands, Ctrl+C to quit, ESC to cancel generation)").
|
||||
Value(&prompt).
|
||||
CharLimit(5000)),
|
||||
).WithWidth(c.width).
|
||||
@@ -152,6 +152,13 @@ func (c *CLI) DisplayInfo(message string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a cancellation message
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
msg := c.messageRenderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now())
|
||||
c.messageContainer.AddMessage(msg)
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayDebugConfig displays configuration settings in debug mode using tool response block styling
|
||||
func (c *CLI) DisplayDebugConfig(config map[string]any) {
|
||||
msg := c.messageRenderer.RenderDebugConfigMessage(config, time.Now())
|
||||
@@ -169,6 +176,7 @@ func (c *CLI) DisplayHelp() {
|
||||
- ` + "`/history`" + `: Display conversation history
|
||||
- ` + "`/quit`" + `: Exit the application
|
||||
- ` + "`Ctrl+C`" + `: Exit at any time
|
||||
- ` + "`ESC`" + `: Cancel ongoing LLM generation
|
||||
|
||||
You can also just type your message to chat with the AI assistant.`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user