diff --git a/cmd/root.go b/cmd/root.go index a15f969e..4a1a7a33 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -794,6 +794,23 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes response := result.FinalResponse conversationMessages := result.ConversationMessages + // Extract the last user message for usage tracking (do this once) + lastUserMessage := "" + if len(messages) > 0 { + // Find the last user message + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == schema.User { + lastUserMessage = messages[i].Content + break + } + } + } + + // Update usage tracking for ALL responses (streaming and non-streaming) + if !config.Quiet && cli != nil { + cli.UpdateUsageFromResponse(response, lastUserMessage) + } + // Display assistant response with model name // Skip if: quiet mode, same content already displayed, or if streaming completed the full response streamedFullResponse := responseWasStreamed && streamingContent.String() == response.Content @@ -802,25 +819,16 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes cli.DisplayError(fmt.Errorf("display error: %v", err)) return nil, nil, err } - } else if streamedFullResponse { - // Streaming was used - the message is already displayed in the message component - // Just update usage tracking with the last user message and response - if len(messages) > 0 { - lastUserMessage := "" - // Find the last user message - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == schema.User { - lastUserMessage = messages[i].Content - break - } - } - cli.UpdateUsageFromResponse(response, lastUserMessage) - } } else if config.Quiet { // In quiet mode, only output the final response content to stdout fmt.Print(response.Content) } + // Display usage information immediately after the response (for both streaming and non-streaming) + if !config.Quiet && cli != nil { + cli.DisplayUsageAfterResponse() + } + // Return the final response and all conversation messages return response, conversationMessages, nil } diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 5cc87a8c..ce0cba43 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -158,8 +158,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []*sc // Add response to working messages workingMessages = append(workingMessages, response) - // Check if this is a tool call or final response - if len(response.ToolCalls) > 0 { + // Check if this is a tool call or final response + if len(response.ToolCalls) > 0 { // Display any content that accompanies the tool calls if response.Content != "" && onToolCallContent != nil { onToolCallContent(response.Content) @@ -251,8 +251,6 @@ func (a *Agent) GetLoadingMessage() string { return a.loadingMessage } - - // generateWithCancellationAndStreaming calls the LLM with ESC key cancellation support and streaming callbacks func (a *Agent) generateWithCancellationAndStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, streamingCallback StreamingResponseHandler) (*schema.Message, error) { // Check if streaming is enabled @@ -271,10 +269,6 @@ func (a *Agent) generateWithCancellationAndStreaming(ctx context.Context, messag return a.generateWithStreamingFirstAndCallback(ctx, messages, toolInfos, streamingCallback) } - - - - // generateWithStreamingAndCallback uses streaming for responses without tool calls with real-time callbacks func (a *Agent) generateWithStreamingAndCallback(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo, callback StreamingResponseHandler) (*schema.Message, error) { // Try streaming first @@ -324,8 +318,6 @@ func (a *Agent) generateWithStreamingFirstAndCallback(ctx context.Context, messa return response, nil } - - // generateWithoutStreaming uses the traditional non-streaming approach func (a *Agent) generateWithoutStreaming(ctx context.Context, messages []*schema.Message, toolInfos []*schema.ToolInfo) (*schema.Message, error) { // Create a cancellable context for just this LLM call diff --git a/internal/agent/streaming.go b/internal/agent/streaming.go index a3dc9698..07c933a5 100644 --- a/internal/agent/streaming.go +++ b/internal/agent/streaming.go @@ -21,6 +21,7 @@ func StreamWithCallback(ctx context.Context, reader *schema.StreamReader[*schema var content strings.Builder var accumulatedToolCalls map[string]*schema.ToolCall // Track tool calls by ID to handle incremental updates var streamComplete bool + var finalResponseMeta *schema.ResponseMeta // Accumulate response metadata from all chunks accumulatedToolCalls = make(map[string]*schema.ToolCall) @@ -49,6 +50,40 @@ func StreamWithCallback(ctx context.Context, reader *schema.StreamReader[*schema // Accumulate content from all chunks content.WriteString(msg.Content) + // Accumulate response metadata - merge from multiple chunks for accuracy + if msg.ResponseMeta != nil { + if finalResponseMeta == nil { + // First metadata we've seen - use as base + finalResponseMeta = &schema.ResponseMeta{} + if msg.ResponseMeta.Usage != nil { + finalResponseMeta.Usage = &schema.TokenUsage{} + } + } + + // Merge metadata intelligently to handle Anthropic's streaming behavior + if msg.ResponseMeta.Usage != nil && finalResponseMeta.Usage != nil { + usage := msg.ResponseMeta.Usage + + // Take PromptTokens from first chunk that has them (usually non-zero) + if finalResponseMeta.Usage.PromptTokens == 0 && usage.PromptTokens > 0 { + finalResponseMeta.Usage.PromptTokens = usage.PromptTokens + } + + // Always take the latest CompletionTokens (accumulates over chunks) + if usage.CompletionTokens > 0 { + finalResponseMeta.Usage.CompletionTokens = usage.CompletionTokens + } + + // Calculate TotalTokens from the components + finalResponseMeta.Usage.TotalTokens = finalResponseMeta.Usage.PromptTokens + finalResponseMeta.Usage.CompletionTokens + } + + // Preserve other metadata fields from the latest chunk + if msg.ResponseMeta.FinishReason != "" { + finalResponseMeta.FinishReason = msg.ResponseMeta.FinishReason + } + } + // Accumulate tool calls incrementally - Anthropic streams them piece by piece // NOTE: We don't process these tool calls until EOF is reached if len(msg.ToolCalls) > 0 { @@ -101,10 +136,11 @@ func StreamWithCallback(ctx context.Context, reader *schema.StreamReader[*schema } } - // Return complete message with all content and final tool calls + // Return complete message with all content, final tool calls, and preserved metadata return &schema.Message{ - Role: schema.Assistant, - Content: content.String(), - ToolCalls: finalToolCalls, + Role: schema.Assistant, + Content: content.String(), + ToolCalls: finalToolCalls, + ResponseMeta: finalResponseMeta, // Preserve usage and other metadata from streaming }, nil -} \ No newline at end of file +} diff --git a/internal/tools/mcp_test.go b/internal/tools/mcp_test.go index f0d9e1a8..9b4e72f8 100644 --- a/internal/tools/mcp_test.go +++ b/internal/tools/mcp_test.go @@ -167,4 +167,4 @@ func contains(s, substr string) bool { } } return false -} \ No newline at end of file +} diff --git a/internal/ui/cli.go b/internal/ui/cli.go index e670087a..73061cc4 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -47,16 +47,8 @@ func (c *CLI) SetUsageTracker(tracker *UsageTracker) { // GetPrompt gets user input using the huh library with divider and padding func (c *CLI) GetPrompt() (string, error) { - // Display usage info if available - if c.usageTracker != nil { - usageInfo := c.usageTracker.RenderUsageInfo() - if usageInfo != "" { - paddedUsage := lipgloss.NewStyle(). - PaddingLeft(2). - Render(usageInfo) - fmt.Print(paddedUsage) - } - } + // Usage info is now displayed immediately after responses via DisplayUsageAfterResponse() + // No need to display it here to avoid duplication // Create an enhanced divider with gradient effect theme := GetTheme() @@ -357,11 +349,18 @@ func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string inputTokens := int(usage.PromptTokens) outputTokens := int(usage.CompletionTokens) - // Handle cache tokens if available (some providers support this) - cacheReadTokens := 0 - cacheWriteTokens := 0 + // Validate that the metadata seems reasonable + // If token counts are 0 or seem unrealistic, fall back to estimation + if inputTokens > 0 && outputTokens > 0 { + // Handle cache tokens if available (some providers support this) + cacheReadTokens := 0 + cacheWriteTokens := 0 - c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens) + c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens) + } else { + // Metadata exists but seems incomplete/unreliable, use estimation + c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content) + } } else { // Fallback to estimation if no metadata is available c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content) @@ -405,6 +404,22 @@ func (c *CLI) ResetUsageStats() { c.DisplayInfo("Usage statistics have been reset.") } +// DisplayUsageAfterResponse displays usage information immediately after a response +func (c *CLI) DisplayUsageAfterResponse() { + if c.usageTracker == nil { + return + } + + usageInfo := c.usageTracker.RenderUsageInfo() + if usageInfo != "" { + paddedUsage := lipgloss.NewStyle(). + PaddingLeft(2). + PaddingTop(1). + Render(usageInfo) + fmt.Print(paddedUsage) + } +} + // updateSize updates the CLI size based on terminal dimensions func (c *CLI) updateSize() { width, height, err := term.GetSize(int(os.Stdout.Fd()))