Streaming (#87)

* basic streaming

* fix output

* update

* update

* fix display issue

* update readme

* fmt

* cleanup

* cleanup

* cleanup

* fix
This commit is contained in:
Ed Zynda
2025-06-26 18:15:17 +03:00
committed by GitHub
parent a66c55e175
commit ec620e4e88
6 changed files with 330 additions and 45 deletions
+4
View File
@@ -573,6 +573,7 @@ mcphost -p "Generate a random UUID" --quiet | tr '[:lower:]' '[:upper:]'
- `-m, --model string`: Model to use (format: provider:model) (default "anthropic:claude-sonnet-4-20250514")
- `-p, --prompt string`: **Run in non-interactive mode with the given prompt**
- `--quiet`: **Suppress all output except the AI response (only works with --prompt)**
- `--stream`: Enable streaming responses (default: true, use `--stream=false` to disable)
### Authentication Subcommands
- `mcphost auth login anthropic`: Authenticate with Anthropic using OAuth (alternative to API keys)
@@ -625,6 +626,9 @@ top-p: 0.95
top-k: 40
stop-sequences: ["Human:", "Assistant:"]
# Streaming configuration
stream: false # Disable streaming (default: true)
# API Configuration
provider-api-key: "your-api-key" # For OpenAI, Anthropic, or Google
provider-url: "https://api.openai.com/v1" # Custom base URL
+64 -10
View File
@@ -32,6 +32,7 @@ var (
quietFlag bool
noExitFlag bool
maxSteps int
streamFlag bool // Enable streaming output
scriptMCPConfig *config.Config // Used to override config in script mode
// Session management
@@ -163,6 +164,8 @@ func init() {
BoolVar(&noExitFlag, "no-exit", false, "prevent non-interactive mode from exiting, show input prompt instead")
rootCmd.PersistentFlags().
IntVar(&maxSteps, "max-steps", 0, "maximum number of agent steps (0 for unlimited)")
rootCmd.PersistentFlags().
BoolVar(&streamFlag, "stream", true, "enable streaming output for faster response display")
// Session management flags
rootCmd.PersistentFlags().
@@ -192,6 +195,7 @@ func init() {
viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))
viper.BindPFlag("prompt", rootCmd.PersistentFlags().Lookup("prompt"))
viper.BindPFlag("max-steps", rootCmd.PersistentFlags().Lookup("max-steps"))
viper.BindPFlag("stream", rootCmd.PersistentFlags().Lookup("stream"))
viper.BindPFlag("provider-url", rootCmd.PersistentFlags().Lookup("provider-url"))
viper.BindPFlag("provider-api-key", rootCmd.PersistentFlags().Lookup("provider-api-key"))
viper.BindPFlag("max-tokens", rootCmd.PersistentFlags().Lookup("max-tokens"))
@@ -285,10 +289,11 @@ func runNormalMode(ctx context.Context) error {
// Create agent configuration
agentConfig := &agent.AgentConfig{
ModelConfig: modelConfig,
MCPConfig: mcpConfig,
SystemPrompt: systemPrompt,
MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it
ModelConfig: modelConfig,
MCPConfig: mcpConfig,
SystemPrompt: systemPrompt,
MaxSteps: viper.GetInt("max-steps"), // Pass 0 for infinite, agent will handle it
StreamingEnabled: viper.GetBool("stream"),
}
// Create the agent with spinner for Ollama models
@@ -635,7 +640,40 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
currentSpinner.Start()
}
result, err := mcpAgent.GenerateWithLoop(ctx, messages,
// Create streaming callback for real-time display
var streamingCallback agent.StreamingResponseHandler
var responseWasStreamed bool
var lastDisplayedContent string
var streamingContent strings.Builder
var streamingStarted bool
if cli != nil && !config.Quiet {
streamingCallback = func(chunk string) {
// Stop spinner before first chunk if still running
if currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
// Mark that this response is being streamed
responseWasStreamed = true
// Start streaming message on first chunk
if !streamingStarted {
cli.StartStreamingMessage(config.ModelName)
streamingStarted = true
}
// Accumulate content and update message
streamingContent.WriteString(chunk)
cli.UpdateStreamingMessage(streamingContent.String())
}
}
// Reset streaming state before agent execution
responseWasStreamed = false
streamingStarted = false
streamingContent.Reset()
result, err := mcpAgent.GenerateWithLoopAndStreaming(ctx, messages,
// Tool call handler - called when a tool is about to be executed
func(toolName, toolArgs string) {
if !config.Quiet && cli != nil {
@@ -713,18 +751,31 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
},
// Tool call content handler - called when content accompanies tool calls
func(content string) {
if !config.Quiet && cli != nil {
if !config.Quiet && cli != nil && !responseWasStreamed {
// Only display if content wasn't already streamed
// Stop spinner before displaying content
if currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
cli.DisplayAssistantMessageWithModel(content, config.ModelName)
lastDisplayedContent = content
// Start spinner again for tool calls
currentSpinner = ui.NewSpinner("Thinking...")
currentSpinner.Start()
} else if responseWasStreamed {
// Content was already streamed, just track it and manage spinner
lastDisplayedContent = content
if currentSpinner != nil {
currentSpinner.Stop()
currentSpinner = nil
}
// Start spinner again for tool calls
currentSpinner = ui.NewSpinner("Thinking...")
currentSpinner.Start()
}
},
streamingCallback, // Add streaming callback as the last parameter
)
// Make sure spinner is stopped if still running
@@ -743,14 +794,17 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
response := result.FinalResponse
conversationMessages := result.ConversationMessages
// Display assistant response with model name (skip if quiet)
if !config.Quiet && cli != nil {
// Display assistant response with model name
// Skip if: quiet mode, same content already displayed, or if streaming completed the full response
streamedFullResponse := responseWasStreamed && streamingContent.String() == response.Content
if !config.Quiet && cli != nil && response.Content != lastDisplayedContent && response.Content != "" && !streamedFullResponse {
if err := cli.DisplayAssistantMessageWithModel(response.Content, config.ModelName); err != nil {
cli.DisplayError(fmt.Errorf("display error: %v", err))
return nil, nil, err
}
// Update usage tracking with the last user message and response
} else if streamedFullResponse {
// Streaming was used - the message is already displayed in the message component
// Just update usage tracking with the last user message and response
if len(messages) > 0 {
lastUserMessage := ""
// Find the last user message
+119 -20
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
tea "github.com/charmbracelet/bubbletea"
@@ -18,10 +19,11 @@ import (
// AgentConfig is the config for agent.
type AgentConfig struct {
ModelConfig *models.ProviderConfig
MCPConfig *config.Config
SystemPrompt string
MaxSteps int
ModelConfig *models.ProviderConfig
MCPConfig *config.Config
SystemPrompt string
MaxSteps int
StreamingEnabled bool
}
// ToolCallHandler is a function type for handling tool calls as they happen
@@ -36,16 +38,21 @@ type ToolResultHandler func(toolName, toolArgs, result string, isError bool)
// ResponseHandler is a function type for handling LLM responses
type ResponseHandler func(content string)
// StreamingResponseHandler is a function type for handling streaming LLM responses
type StreamingResponseHandler func(content string)
// ToolCallContentHandler is a function type for handling content that accompanies tool calls
type ToolCallContentHandler func(content string)
// Agent is the agent with real-time tool call display.
type Agent struct {
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
loadingMessage string // Message from provider loading (e.g., GPU fallback info)
toolManager *tools.MCPToolManager
model model.ToolCallingChatModel
maxSteps int
systemPrompt string
loadingMessage string // Message from provider loading (e.g., GPU fallback info)
providerType string // Provider type for streaming behavior
streamingEnabled bool // Whether streaming is enabled
}
// NewAgent creates an agent with MCP tool integration and real-time tool call display
@@ -62,12 +69,23 @@ func NewAgent(ctx context.Context, config *AgentConfig) (*Agent, error) {
return nil, fmt.Errorf("failed to load MCP tools: %v", err)
}
// Determine provider type from model string
providerType := "default"
if config.ModelConfig != nil && config.ModelConfig.ModelString != "" {
parts := strings.SplitN(config.ModelConfig.ModelString, ":", 2)
if len(parts) >= 1 {
providerType = parts[0]
}
}
return &Agent{
toolManager: toolManager,
model: providerResult.Model,
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
systemPrompt: config.SystemPrompt,
loadingMessage: providerResult.Message,
toolManager: toolManager,
model: providerResult.Model,
maxSteps: config.MaxSteps, // Keep 0 for infinite, handle in loop
systemPrompt: config.SystemPrompt,
loadingMessage: providerResult.Message,
providerType: providerType,
streamingEnabled: config.StreamingEnabled,
}, nil
}
@@ -81,6 +99,13 @@ type GenerateWithLoopResult struct {
func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult, onResponse, onToolCallContent, nil)
}
// GenerateWithLoopAndStreaming processes messages with a custom loop that displays tool calls in real-time and supports streaming callbacks
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*schema.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, onStreamingResponse StreamingResponseHandler) (*GenerateWithLoopResult, error) {
// Create a copy of messages to avoid modifying the original
workingMessages := make([]*schema.Message, len(messages))
copy(workingMessages, messages)
@@ -125,7 +150,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
}
// Call the LLM with cancellation support
response, err := a.generateWithCancellation(ctx, workingMessages, toolInfos)
response, err := a.generateWithCancellationAndStreaming(ctx, workingMessages, toolInfos, onStreamingResponse)
if err != nil {
return nil, err
}
@@ -133,9 +158,8 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []*schema.Message
// Add response to working messages
workingMessages = append(workingMessages, response)
// Check if this is a tool call or final response
if len(response.ToolCalls) > 0 {
// Check if this is a tool call or final response
if len(response.ToolCalls) > 0 {
// Display any content that accompanies the tool calls
if response.Content != "" && onToolCallContent != nil {
onToolCallContent(response.Content)
@@ -227,8 +251,83 @@ func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
}
// generateWithCancellation calls the LLM with ESC key cancellation support
func (a *Agent) generateWithCancellation(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) {
// generateWithCancellationAndStreaming calls the LLM with ESC key cancellation support and streaming callbacks
func (a *Agent) generateWithCancellationAndStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, streamingCallback StreamingResponseHandler) (*schema.Message, error) {
// Check if streaming is enabled
if !a.streamingEnabled {
// Use traditional non-streaming approach
return a.generateWithoutStreaming(ctx, messages, toolInfos)
}
// Try streaming first if no tools are expected or if we can detect tool calls early
if len(toolInfos) == 0 {
// No tools available, use streaming directly
return a.generateWithStreamingAndCallback(ctx, messages, toolInfos, streamingCallback)
}
// Try streaming with tool call detection
return a.generateWithStreamingFirstAndCallback(ctx, messages, toolInfos, streamingCallback)
}
// generateWithStreamingAndCallback uses streaming for responses without tool calls with real-time callbacks
func (a *Agent) generateWithStreamingAndCallback(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, callback StreamingResponseHandler) (*schema.Message, error) {
// Try streaming first
reader, err := a.model.Stream(ctx, messages, model.WithTools(toolInfos))
if err != nil {
// Fallback to non-streaming if streaming fails
return a.model.Generate(ctx, messages, model.WithTools(toolInfos))
}
// Use streaming with callback for real-time display
response, err := StreamWithCallback(ctx, reader, func(chunk string) {
if callback != nil {
callback(chunk)
}
})
if err != nil {
// Fallback to non-streaming on error
return a.model.Generate(ctx, messages, model.WithTools(toolInfos))
}
// Return the complete streamed response (with tool calls if any)
return response, nil
}
// generateWithStreamingFirstAndCallback attempts streaming first with provider-aware tool call detection and callbacks
func (a *Agent) generateWithStreamingFirstAndCallback(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, callback StreamingResponseHandler) (*schema.Message, error) {
// Try streaming first
reader, err := a.model.Stream(ctx, messages, model.WithTools(toolInfos))
if err != nil {
// Fallback to non-streaming if streaming fails
return a.model.Generate(ctx, messages, model.WithTools(toolInfos))
}
// Use streaming with callback for real-time display
response, err := StreamWithCallback(ctx, reader, func(chunk string) {
if callback != nil {
callback(chunk)
}
})
if err != nil {
// Fallback to non-streaming on error
return a.model.Generate(ctx, messages, model.WithTools(toolInfos))
}
// Return the complete streamed response (with tool calls if any)
// No need to restart - we have everything we need!
return response, nil
}
// generateWithoutStreaming uses the traditional non-streaming approach
func (a *Agent) generateWithoutStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) {
// Create a cancellable context for just this LLM call
llmCtx, cancel := context.WithCancel(ctx)
defer cancel()
+110
View File
@@ -0,0 +1,110 @@
package agent
import (
"context"
"io"
"strings"
"github.com/cloudwego/eino/schema"
)
// StreamWithCallback streams content with real-time callbacks and returns complete response
// IMPORTANT: Tool calls are only processed after EOF is reached to ensure we have the complete
// and final tool call information. This prevents premature tool execution on partial data.
// Handles different provider streaming patterns:
// - Anthropic: Text content first, then tool calls streamed incrementally
// - OpenAI/Others: Tool calls first or alone
// - Mixed: Tool calls and content interleaved
func StreamWithCallback(ctx context.Context, reader *schema.StreamReader[*schema.Message], callback func(string)) (*schema.Message, error) {
defer reader.Close()
var content strings.Builder
var accumulatedToolCalls map[string]*schema.ToolCall // Track tool calls by ID to handle incremental updates
var streamComplete bool
accumulatedToolCalls = make(map[string]*schema.ToolCall)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
msg, err := reader.Recv()
if err == io.EOF {
// Stream is complete - now we can safely process tool calls
streamComplete = true
break
}
if err != nil {
return nil, err
}
// Call callback for each chunk if provided (for real-time display)
if callback != nil && msg.Content != "" {
callback(msg.Content)
}
// Accumulate content from all chunks
content.WriteString(msg.Content)
// Accumulate tool calls incrementally - Anthropic streams them piece by piece
// NOTE: We don't process these tool calls until EOF is reached
if len(msg.ToolCalls) > 0 {
for _, toolCall := range msg.ToolCalls {
// Use tool call ID as key, but handle cases where ID might be empty in partial chunks
key := toolCall.ID
if key == "" {
// For chunks without ID, try to find existing tool call or create a temporary key
if len(accumulatedToolCalls) == 1 {
// If we have exactly one tool call being built, assume this chunk belongs to it
for existingKey := range accumulatedToolCalls {
key = existingKey
break
}
} else {
// Create a temporary key for this tool call
key = "temp_" + toolCall.Function.Name
}
}
existing := accumulatedToolCalls[key]
if existing == nil {
// First time seeing this tool call
accumulatedToolCalls[key] = &schema.ToolCall{
ID: toolCall.ID,
Function: toolCall.Function,
}
} else {
// Update existing tool call with new information
// Preserve non-empty values, accumulate arguments
if toolCall.ID != "" {
existing.ID = toolCall.ID
}
if toolCall.Function.Name != "" {
existing.Function.Name = toolCall.Function.Name
}
// Accumulate arguments (they come in pieces)
existing.Function.Arguments += toolCall.Function.Arguments
}
}
}
}
// Only process tool calls after EOF - ensures we have complete information
var finalToolCalls []schema.ToolCall
if streamComplete && len(accumulatedToolCalls) > 0 {
finalToolCalls = make([]schema.ToolCall, 0, len(accumulatedToolCalls))
for _, toolCall := range accumulatedToolCalls {
finalToolCalls = append(finalToolCalls, *toolCall)
}
}
// Return complete message with all content and final tool calls
return &schema.Message{
Role: schema.Assistant,
Content: content.String(),
ToolCalls: finalToolCalls,
}, nil
}
+15 -15
View File
@@ -143,25 +143,25 @@ func (c *CLI) DisplayToolMessage(toolName, toolArgs, toolResult string, isError
c.displayContainer()
}
// DisplayStreamingMessage displays streaming content
func (c *CLI) DisplayStreamingMessage(reader *schema.StreamReader[*schema.Message]) error {
// For streaming, we'll collect the content and then display it
var content strings.Builder
for {
msg, err := reader.Recv()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("stream receive error: %v", err)
}
content.WriteString(msg.Content)
}
return c.DisplayAssistantMessage(content.String())
// StartStreamingMessage starts a streaming assistant message
func (c *CLI) StartStreamingMessage(modelName string) {
// Add an empty assistant message that we'll update during streaming
msg := c.messageRenderer.RenderAssistantMessage("", time.Now(), modelName)
c.messageContainer.AddMessage(msg)
c.displayContainer()
}
// UpdateStreamingMessage updates the streaming message with new content
func (c *CLI) UpdateStreamingMessage(content string) {
// Update the last message (which should be the streaming assistant message)
c.messageContainer.UpdateLastMessage(content)
c.displayContainer()
}
// DisplayError displays an error message using the message component
func (c *CLI) DisplayError(err error) {
msg := c.messageRenderer.RenderErrorMessage(err.Error(), time.Now())
+18
View File
@@ -486,6 +486,24 @@ func (c *MessageContainer) AddMessage(msg UIMessage) {
c.messages = append(c.messages, msg)
}
// UpdateLastMessage updates the content of the last message efficiently
func (c *MessageContainer) UpdateLastMessage(content string) {
if len(c.messages) == 0 {
return
}
lastIdx := len(c.messages) - 1
lastMsg := &c.messages[lastIdx]
// Only re-render if content actually changed and it's an assistant message
if lastMsg.Type == AssistantMessage {
// Create a new renderer to update the message
renderer := NewMessageRenderer(c.width, false)
newMsg := renderer.RenderAssistantMessage(content, lastMsg.Timestamp, "")
c.messages[lastIdx] = newMsg
}
}
// Clear clears all messages from the container
func (c *MessageContainer) Clear() {
c.messages = make([]UIMessage, 0)