From ec620e4e88fe4e1430b9e5119329960878ec11b3 Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 26 Jun 2025 18:15:17 +0300 Subject: [PATCH] Streaming (#87) * basic streaming * fix output * update * update * fix display issue * update readme * fmt * cleanup * cleanup * cleanup * fix --- README.md | 4 ++ cmd/root.go | 74 ++++++++++++++++--- internal/agent/agent.go | 139 ++++++++++++++++++++++++++++++------ internal/agent/streaming.go | 110 ++++++++++++++++++++++++++++ internal/ui/cli.go | 30 ++++---- internal/ui/messages.go | 18 +++++ 6 files changed, 330 insertions(+), 45 deletions(-) create mode 100644 internal/agent/streaming.go diff --git a/README.md b/README.md index 1e95cc9a..7f5807bc 100644 --- a/README.md +++ b/README.md @@ -573,6 +573,7 @@ mcphost -p "Generate a random UUID" --quiet | tr '[:lower:]' '[:upper:]' - `-m, --model string`: Model to use (format: provider:model) (default "anthropic:claude-sonnet-4-20250514") - `-p, --prompt string`: **Run in non-interactive mode with the given prompt** - `--quiet`: **Suppress all output except the AI response (only works with --prompt)** +- `--stream`: Enable streaming responses (default: true, use `--stream=false` to disable) ### Authentication Subcommands - `mcphost auth login anthropic`: Authenticate with Anthropic using OAuth (alternative to API keys) @@ -625,6 +626,9 @@ top-p: 0.95 top-k: 40 stop-sequences: ["Human:", "Assistant:"] +# Streaming configuration +stream: false # Disable streaming (default: true) + # API Configuration provider-api-key: "your-api-key" # For OpenAI, Anthropic, or Google provider-url: "https://api.openai.com/v1" # Custom base URL diff --git a/cmd/root.go b/cmd/root.go index 9566602b..a15f969e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,6 +32,7 @@ var ( quietFlag bool noExitFlag bool maxSteps int + streamFlag bool // Enable streaming output scriptMCPConfig *config.Config // Used to override config in script mode // Session management @@ -163,6 +164,8 @@ func init() { BoolVar(&noExitFlag, "no-exit", false, "prevent non-interactive mode from exiting, show input prompt instead") rootCmd.PersistentFlags(). IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)") + rootCmd.PersistentFlags(). + BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display") // Session management flags rootCmd.PersistentFlags(). @@ -192,6 +195,7 @@ func init() { viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug")) viper.BindPFlag("prompt", rootCmd.PersistentFlags().Lookup("prompt")) viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps")) + viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream")) viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url")) viper.BindPFlag("provider-api-key", rootCmd.PersistentFlags().Lookup("provider-api-key")) viper.BindPFlag("max-tokens", rootCmd.PersistentFlags().Lookup("max-tokens")) @@ -285,10 +289,11 @@ func runNormalMode(ctx context.Context) error { // Create agent configuration agentConfig := &agent.AgentConfig{ - ModelConfig: modelConfig, - MCPConfig: mcpConfig, - SystemPrompt: systemPrompt, - MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it + ModelConfig: modelConfig, + MCPConfig: mcpConfig, + SystemPrompt: systemPrompt, + MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it + StreamingEnabled: viper.GetBool("stream"), } // Create the agent with spinner for Ollama models @@ -635,7 +640,40 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes currentSpinner.Start() } - result, err := mcpAgent.GenerateWithLoop(ctx, messages, + // Create streaming callback for real-time display + var streamingCallback agent.StreamingResponseHandler + var responseWasStreamed bool + var lastDisplayedContent string + var streamingContent strings.Builder + var streamingStarted bool + if cli != nil && !config.Quiet { + streamingCallback = func(chunk string) { + // Stop spinner before first chunk if still running + if currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil + } + // Mark that this response is being streamed + responseWasStreamed = true + + // Start streaming message on first chunk + if !streamingStarted { + cli.StartStreamingMessage(config.ModelName) + streamingStarted = true + } + + // Accumulate content and update message + streamingContent.WriteString(chunk) + cli.UpdateStreamingMessage(streamingContent.String()) + } + } + + // Reset streaming state before agent execution + responseWasStreamed = false + streamingStarted = false + streamingContent.Reset() + + result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages, // Tool call handler - called when a tool is about to be executed func(toolName, toolArgs string) { if !config.Quiet && cli != nil { @@ -713,18 +751,31 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes }, // Tool call content handler - called when content accompanies tool calls func(content string) { - if !config.Quiet && cli != nil { + if !config.Quiet && cli != nil && !responseWasStreamed { + // Only display if content wasn't already streamed // Stop spinner before displaying content if currentSpinner != nil { currentSpinner.Stop() currentSpinner = nil } cli.DisplayAssistantMessageWithModel(content, config.ModelName) + lastDisplayedContent = content + // Start spinner again for tool calls + currentSpinner = ui.NewSpinner("Thinking...") + currentSpinner.Start() + } else if responseWasStreamed { + // Content was already streamed, just track it and manage spinner + lastDisplayedContent = content + if currentSpinner != nil { + currentSpinner.Stop() + currentSpinner = nil + } // Start spinner again for tool calls currentSpinner = ui.NewSpinner("Thinking...") currentSpinner.Start() } }, + streamingCallback, // Add streaming callback as the last parameter ) // Make sure spinner is stopped if still running @@ -743,14 +794,17 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes response := result.FinalResponse conversationMessages := result.ConversationMessages - // Display assistant response with model name (skip if quiet) - if !config.Quiet && cli != nil { + // Display assistant response with model name + // Skip if: quiet mode, same content already displayed, or if streaming completed the full response + streamedFullResponse := responseWasStreamed && streamingContent.String() == response.Content + if !config.Quiet && cli != nil && response.Content != lastDisplayedContent && response.Content != "" && !streamedFullResponse { if err := cli.DisplayAssistantMessageWithModel(response.Content, config.ModelName); err != nil { cli.DisplayError(fmt.Errorf("display error: %v", err)) return nil, nil, err } - - // Update usage tracking with the last user message and response + } else if streamedFullResponse { + // Streaming was used - the message is already displayed in the message component + // Just update usage tracking with the last user message and response if len(messages) > 0 { lastUserMessage := "" // Find the last user message diff --git a/internal/agent/agent.go b/internal/agent/agent.go index cfcc076a..5cc87a8c 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "time" tea "github.com/charmbracelet/bubbletea" @@ -18,10 +19,11 @@ import ( // AgentConfig is the config for agent. type AgentConfig struct { - ModelConfig *models.ProviderConfig - MCPConfig *config.Config - SystemPrompt string - MaxSteps int + ModelConfig *models.ProviderConfig + MCPConfig *config.Config + SystemPrompt string + MaxSteps int + StreamingEnabled bool } // ToolCallHandler is a function type for handling tool calls as they happen @@ -36,16 +38,21 @@ type ToolResultHandler func(toolName, toolArgs, result string, isError bool) // ResponseHandler is a function type for handling LLM responses type ResponseHandler func(content string) +// StreamingResponseHandler is a function type for handling streaming LLM responses +type StreamingResponseHandler func(content string) + // ToolCallContentHandler is a function type for handling content that accompanies tool calls type ToolCallContentHandler func(content string) // Agent is the agent with real-time tool call display. type Agent struct { - toolManager *tools.MCPToolManager - model model.ToolCallingChatModel - maxSteps int - systemPrompt string - loadingMessage string // Message from provider loading (e.g., GPU fallback info) + toolManager *tools.MCPToolManager + model model.ToolCallingChatModel + maxSteps int + systemPrompt string + loadingMessage string // Message from provider loading (e.g., GPU fallback info) + providerType string // Provider type for streaming behavior + streamingEnabled bool // Whether streaming is enabled } // NewAgent creates an agent with MCP tool integration and real-time tool call display @@ -62,12 +69,23 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) { return nil, fmt.Errorf("failed to load MCP tools: %v", err) } + // Determine provider type from model string + providerType := "default" + if config.ModelConfig != nil && config.ModelConfig.ModelString != "" { + parts := strings.SplitN(config.ModelConfig.ModelString, ":", 2) + if len(parts) >= 1 { + providerType = parts[0] + } + } + return &Agent{ - toolManager: toolManager, - model: providerResult.Model, - maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop - systemPrompt: config.SystemPrompt, - loadingMessage: providerResult.Message, + toolManager: toolManager, + model: providerResult.Model, + maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop + systemPrompt: config.SystemPrompt, + loadingMessage: providerResult.Message, + providerType: providerType, + streamingEnabled: config.StreamingEnabled, }, nil } @@ -81,6 +99,13 @@ type GenerateWithLoopResult struct { func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message, onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) { + return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil) +} + +// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks +func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message, + onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) { + // Create a copy of messages to avoid modifying the original workingMessages := make([]*schema.Message, len(messages)) copy(workingMessages, messages) @@ -125,7 +150,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message } // Call the LLM with cancellation support - response, err := a.generateWithCancellation(ctx, workingMessages, toolInfos) + response, err := a.generateWithCancellationAndStreaming(ctx, workingMessages, toolInfos, onStreamingResponse) if err != nil { return nil, err } @@ -133,9 +158,8 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message // Add response to working messages workingMessages = append(workingMessages, response) - // Check if this is a tool call or final response - if len(response.ToolCalls) > 0 { - + // Check if this is a tool call or final response + if len(response.ToolCalls) > 0 { // Display any content that accompanies the tool calls if response.Content != "" && onToolCallContent != nil { onToolCallContent(response.Content) @@ -227,8 +251,83 @@ func (a *Agent) GetLoadingMessage() string { return a.loadingMessage } -// generateWithCancellation calls the LLM with ESC key cancellation support -func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) { + + +// generateWithCancellationAndStreaming calls the LLM with ESC key cancellation support and streaming callbacks +func (a *Agent) generateWithCancellationAndStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, streamingCallback StreamingResponseHandler) (*schema.Message, error) { + // Check if streaming is enabled + if !a.streamingEnabled { + // Use traditional non-streaming approach + return a.generateWithoutStreaming(ctx, messages, toolInfos) + } + + // Try streaming first if no tools are expected or if we can detect tool calls early + if len(toolInfos) == 0 { + // No tools available, use streaming directly + return a.generateWithStreamingAndCallback(ctx, messages, toolInfos, streamingCallback) + } + + // Try streaming with tool call detection + return a.generateWithStreamingFirstAndCallback(ctx, messages, toolInfos, streamingCallback) +} + + + + + +// generateWithStreamingAndCallback uses streaming for responses without tool calls with real-time callbacks +func (a *Agent) generateWithStreamingAndCallback(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, callback StreamingResponseHandler) (*schema.Message, error) { + // Try streaming first + reader, err := a.model.Stream(ctx, messages, model.WithTools(toolInfos)) + if err != nil { + // Fallback to non-streaming if streaming fails + return a.model.Generate(ctx, messages, model.WithTools(toolInfos)) + } + + // Use streaming with callback for real-time display + response, err := StreamWithCallback(ctx, reader, func(chunk string) { + if callback != nil { + callback(chunk) + } + }) + if err != nil { + // Fallback to non-streaming on error + return a.model.Generate(ctx, messages, model.WithTools(toolInfos)) + } + + // Return the complete streamed response (with tool calls if any) + return response, nil +} + +// generateWithStreamingFirstAndCallback attempts streaming first with provider-aware tool call detection and callbacks +func (a *Agent) generateWithStreamingFirstAndCallback(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, callback StreamingResponseHandler) (*schema.Message, error) { + // Try streaming first + reader, err := a.model.Stream(ctx, messages, model.WithTools(toolInfos)) + if err != nil { + // Fallback to non-streaming if streaming fails + return a.model.Generate(ctx, messages, model.WithTools(toolInfos)) + } + + // Use streaming with callback for real-time display + response, err := StreamWithCallback(ctx, reader, func(chunk string) { + if callback != nil { + callback(chunk) + } + }) + if err != nil { + // Fallback to non-streaming on error + return a.model.Generate(ctx, messages, model.WithTools(toolInfos)) + } + + // Return the complete streamed response (with tool calls if any) + // No need to restart - we have everything we need! + return response, nil +} + + + +// generateWithoutStreaming uses the traditional non-streaming approach +func (a *Agent) generateWithoutStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) { // Create a cancellable context for just this LLM call llmCtx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/internal/agent/streaming.go b/internal/agent/streaming.go new file mode 100644 index 00000000..a3dc9698 --- /dev/null +++ b/internal/agent/streaming.go @@ -0,0 +1,110 @@ +package agent + +import ( + "context" + "io" + "strings" + + "github.com/cloudwego/eino/schema" +) + +// StreamWithCallback streams content with real-time callbacks and returns complete response +// IMPORTANT: Tool calls are only processed after EOF is reached to ensure we have the complete +// and final tool call information. This prevents premature tool execution on partial data. +// Handles different provider streaming patterns: +// - Anthropic: Text content first, then tool calls streamed incrementally +// - OpenAI/Others: Tool calls first or alone +// - Mixed: Tool calls and content interleaved +func StreamWithCallback(ctx context.Context, reader *schema.StreamReader[*schema.Message], callback func(string)) (*schema.Message, error) { + defer reader.Close() + + var content strings.Builder + var accumulatedToolCalls map[string]*schema.ToolCall // Track tool calls by ID to handle incremental updates + var streamComplete bool + + accumulatedToolCalls = make(map[string]*schema.ToolCall) + + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + msg, err := reader.Recv() + if err == io.EOF { + // Stream is complete - now we can safely process tool calls + streamComplete = true + break + } + if err != nil { + return nil, err + } + + // Call callback for each chunk if provided (for real-time display) + if callback != nil && msg.Content != "" { + callback(msg.Content) + } + + // Accumulate content from all chunks + content.WriteString(msg.Content) + + // Accumulate tool calls incrementally - Anthropic streams them piece by piece + // NOTE: We don't process these tool calls until EOF is reached + if len(msg.ToolCalls) > 0 { + for _, toolCall := range msg.ToolCalls { + // Use tool call ID as key, but handle cases where ID might be empty in partial chunks + key := toolCall.ID + if key == "" { + // For chunks without ID, try to find existing tool call or create a temporary key + if len(accumulatedToolCalls) == 1 { + // If we have exactly one tool call being built, assume this chunk belongs to it + for existingKey := range accumulatedToolCalls { + key = existingKey + break + } + } else { + // Create a temporary key for this tool call + key = "temp_" + toolCall.Function.Name + } + } + + existing := accumulatedToolCalls[key] + if existing == nil { + // First time seeing this tool call + accumulatedToolCalls[key] = &schema.ToolCall{ + ID: toolCall.ID, + Function: toolCall.Function, + } + } else { + // Update existing tool call with new information + // Preserve non-empty values, accumulate arguments + if toolCall.ID != "" { + existing.ID = toolCall.ID + } + if toolCall.Function.Name != "" { + existing.Function.Name = toolCall.Function.Name + } + // Accumulate arguments (they come in pieces) + existing.Function.Arguments += toolCall.Function.Arguments + } + } + } + } + + // Only process tool calls after EOF - ensures we have complete information + var finalToolCalls []schema.ToolCall + if streamComplete && len(accumulatedToolCalls) > 0 { + finalToolCalls = make([]schema.ToolCall, 0, len(accumulatedToolCalls)) + for _, toolCall := range accumulatedToolCalls { + finalToolCalls = append(finalToolCalls, *toolCall) + } + } + + // Return complete message with all content and final tool calls + return &schema.Message{ + Role: schema.Assistant, + Content: content.String(), + ToolCalls: finalToolCalls, + }, nil +} \ No newline at end of file diff --git a/internal/ui/cli.go b/internal/ui/cli.go index 6d9bfd9a..e670087a 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -143,25 +143,25 @@ func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError c.displayContainer() } -// DisplayStreamingMessage displays streaming content -func (c *CLI) DisplayStreamingMessage(reader *schema.StreamReader[*schema.Message]) error { - // For streaming, we'll collect the content and then display it - var content strings.Builder - for { - msg, err := reader.Recv() - if err == io.EOF { - break - } - if err != nil { - return fmt.Errorf("stream receive error: %v", err) - } - content.WriteString(msg.Content) - } - return c.DisplayAssistantMessage(content.String()) +// StartStreamingMessage starts a streaming assistant message +func (c *CLI) StartStreamingMessage(modelName string) { + // Add an empty assistant message that we'll update during streaming + msg := c.messageRenderer.RenderAssistantMessage("", time.Now(), modelName) + c.messageContainer.AddMessage(msg) + c.displayContainer() } +// UpdateStreamingMessage updates the streaming message with new content +func (c *CLI) UpdateStreamingMessage(content string) { + // Update the last message (which should be the streaming assistant message) + c.messageContainer.UpdateLastMessage(content) + c.displayContainer() +} + + + // DisplayError displays an error message using the message component func (c *CLI) DisplayError(err error) { msg := c.messageRenderer.RenderErrorMessage(err.Error(), time.Now()) diff --git a/internal/ui/messages.go b/internal/ui/messages.go index ec1ba422..29581eab 100644 --- a/internal/ui/messages.go +++ b/internal/ui/messages.go @@ -486,6 +486,24 @@ func (c *MessageContainer) AddMessage(msg UIMessage) { c.messages = append(c.messages, msg) } +// UpdateLastMessage updates the content of the last message efficiently +func (c *MessageContainer) UpdateLastMessage(content string) { + if len(c.messages) == 0 { + return + } + + lastIdx := len(c.messages) - 1 + lastMsg := &c.messages[lastIdx] + + // Only re-render if content actually changed and it's an assistant message + if lastMsg.Type == AssistantMessage { + // Create a new renderer to update the message + renderer := NewMessageRenderer(c.width, false) + newMsg := renderer.RenderAssistantMessage(content, lastMsg.Timestamp, "") + c.messages[lastIdx] = newMsg + } +} + // Clear clears all messages from the container func (c *MessageContainer) Clear() { c.messages = make([]UIMessage, 0)