diff --git a/internal/app/app.go b/internal/app/app.go index f5530c6a..ea50881c 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "sync" + "sync/atomic" tea "charm.land/bubbletea/v2" "charm.land/fantasy" @@ -598,9 +599,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M } } - // Subscribe to SDK events for TUI rendering. The subscription is - // temporary — it lives only for the duration of this step. - unsub := a.subscribeSDKEvents(sendFn) + // Subscribe to SDK events for TUI rendering and per-step usage updates. + // The subscription is temporary — it lives only for the duration of this step. + var sawStepUsage atomic.Bool + unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage) defer unsub() // Show spinner while the agent works. @@ -620,8 +622,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M // Sync in-memory store with the SDK's authoritative conversation. a.store.Replace(result.Messages) - // Update usage tracker. - a.updateUsageFromTurnResult(result, prompt) + // Update usage tracker. If per-step usage was already recorded from + // StepUsageEvent callbacks, avoid double-counting totals. + a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load()) return result, nil } @@ -645,9 +648,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func( } } - // Subscribe to SDK events for TUI rendering. The subscription is - // temporary — it lives only for the duration of this step. - unsub := a.subscribeSDKEvents(sendFn) + // Subscribe to SDK events for TUI rendering and per-step usage updates. + // The subscription is temporary — it lives only for the duration of this step. + var sawStepUsage atomic.Bool + unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage) defer unsub() // Show spinner while the agent works. 
@@ -702,8 +706,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func( // Sync in-memory store with the SDK's authoritative conversation. a.store.Replace(result.Messages) - // Update usage tracker (using last item's prompt for tracking). - a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt) + // Update usage tracker (using last item's prompt for fallback estimation). + // If per-step usage was already recorded from StepUsageEvent callbacks, + // avoid double-counting totals. + a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load()) return result, nil } @@ -720,9 +726,10 @@ func (a *App) sendEvent(msg tea.Msg) { } // subscribeSDKEvents registers temporary SDK event subscribers that convert -// SDK events to tea.Msg events and dispatch them via sendFn. Returns an -// unsubscribe function that removes all listeners. -func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() { +// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen +// is provided, it is set to true after any non-zero StepUsageEvent is observed. +// Returns an unsubscribe function that removes all listeners. +func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() { k := a.opts.Kit var unsubs []func() @@ -756,6 +763,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() { }) case kit.SteerConsumedEvent: sendFn(SteerConsumedEvent{}) + case kit.StepUsageEvent: + a.recordStepUsage(ev, stepUsageSeen) } })) @@ -925,32 +934,56 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) { } } +// recordStepUsage applies token/cost usage reported for a completed step. +// Step usage events arrive even when a turn is later cancelled, so this keeps +// the usage widget accurate on all stop paths. 
+func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) { + hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0 + if !hasUsage { + return + } + if stepUsageSeen != nil { + stepUsageSeen.Store(true) + } + if a.opts.UsageTracker == nil { + return + } + a.opts.UsageTracker.UpdateUsage( + int(ev.InputTokens), + int(ev.OutputTokens), + int(ev.CacheReadTokens), + int(ev.CacheWriteTokens), + ) + // Keep context fill reasonably fresh during long/partial turns. + a.opts.UsageTracker.SetContextTokens(int(ev.InputTokens + ev.OutputTokens)) +} + // updateUsageFromTurnResult records token usage from an SDK TurnResult into the // configured UsageTracker. Called once per turn after the turn completes. // -// Cost/token accumulation uses TotalUsage (sum across all tool-calling steps in -// the turn). Context-window fill uses FinalUsage.InputTokens only — that is the -// number of tokens sent to the model on the last API call, which equals the -// actual context window occupation (all accumulated messages + tool results). -// OutputTokens are not added here because they are the response length, not -// context fill. -func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) { +// When sawStepUsage is true, totals were already accumulated incrementally via +// StepUsageEvent callbacks; in that case this method only updates context fill. +// Otherwise it falls back to TotalUsage (or estimation) to keep costs/tokens +// visible for providers/modes that don't emit per-step usage. 
+func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) { if a.opts.UsageTracker == nil || result == nil { return } // --- Accumulate cost/token totals for the session --- - if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 { - a.opts.UsageTracker.UpdateUsage( - int(result.TotalUsage.InputTokens), - int(result.TotalUsage.OutputTokens), - int(result.TotalUsage.CacheReadTokens), - int(result.TotalUsage.CacheCreationTokens), - ) - } else { - // Provider didn't report token counts — fall back to character-based - // estimates so the footer shows something rather than nothing. - a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response) + if !sawStepUsage { + if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 { + a.opts.UsageTracker.UpdateUsage( + int(result.TotalUsage.InputTokens), + int(result.TotalUsage.OutputTokens), + int(result.TotalUsage.CacheReadTokens), + int(result.TotalUsage.CacheCreationTokens), + ) + } else { + // Provider didn't report token counts — fall back to character-based + // estimates so the footer shows something rather than nothing. 
+ a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response) + } } // --- Context window fill (drives the % bar) --- diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 246be5f0..48cddee4 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "charm.land/fantasy" + kit "github.com/mark3labs/kit/pkg/kit" ) @@ -14,6 +16,47 @@ import ( // Helpers // -------------------------------------------------------------------------- +type usageUpdaterStub struct { + mu sync.Mutex + + updateCalls int + estimateCalls int + contextCalls int + + lastUpdateInput int + lastUpdateOutput int + lastUpdateCacheRead int + lastUpdateCacheWrite int + lastContextTokens int + lastEstimateInput string + lastEstimateOutput string +} + +func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) { + s.mu.Lock() + defer s.mu.Unlock() + s.updateCalls++ + s.lastUpdateInput = inputTokens + s.lastUpdateOutput = outputTokens + s.lastUpdateCacheRead = cacheReadTokens + s.lastUpdateCacheWrite = cacheWriteTokens +} + +func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) { + s.mu.Lock() + defer s.mu.Unlock() + s.estimateCalls++ + s.lastEstimateInput = inputText + s.lastEstimateOutput = outputText +} + +func (s *usageUpdaterStub) SetContextTokens(tokens int) { + s.mu.Lock() + defer s.mu.Unlock() + s.contextCalls++ + s.lastContextTokens = tokens +} + // turnResult builds a minimal TurnResult with response text t. func turnResult(t string) *kit.TurnResult { return &kit.TurnResult{Response: t} @@ -489,3 +532,67 @@ func TestQueueLength_reflects(t *testing.T) { t.Fatalf("expected 3, got %d", got) } } + +// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are +// recorded immediately (including context tokens) for stop-path correctness. 
+func TestRecordStepUsage_updatesTracker(t *testing.T) { + usage := &usageUpdaterStub{} + app := New(Options{UsageTracker: usage}, nil) + defer app.Close() + + app.recordStepUsage(kit.StepUsageEvent{ + InputTokens: 120, + OutputTokens: 45, + CacheReadTokens: 5, + CacheWriteTokens: 2, + }, nil) + + usage.mu.Lock() + defer usage.mu.Unlock() + + if usage.updateCalls != 1 { + t.Fatalf("expected 1 update call, got %d", usage.updateCalls) + } + if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 { + t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d", + usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite) + } + if usage.contextCalls != 1 { + t.Fatalf("expected 1 context token update, got %d", usage.contextCalls) + } + if usage.lastContextTokens != 165 { + t.Fatalf("expected context tokens 165, got %d", usage.lastContextTokens) + } +} + +// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid +// double-counting totals once StepUsageEvent-based updates were already applied. 
+func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) { + usage := &usageUpdaterStub{} + app := New(Options{UsageTracker: usage}, nil) + defer app.Close() + + app.updateUsageFromTurnResult(&kit.TurnResult{ + Response: "ok", + TotalUsage: &fantasy.Usage{ + InputTokens: 999, + OutputTokens: 111, + CacheReadTokens: 7, + CacheCreationTokens: 3, + }, + FinalUsage: &fantasy.Usage{InputTokens: 456}, + }, "prompt", true) + + usage.mu.Lock() + defer usage.mu.Unlock() + + if usage.updateCalls != 0 { + t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls) + } + if usage.estimateCalls != 0 { + t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls) + } + if usage.contextCalls != 1 || usage.lastContextTokens != 456 { + t.Fatalf("expected final context tokens=456, got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens) + } +} diff --git a/internal/ui/usage_tracker.go b/internal/ui/usage_tracker.go index 1e6ef24f..456a0fe6 100644 --- a/internal/ui/usage_tracker.go +++ b/internal/ui/usage_tracker.go @@ -151,10 +151,6 @@ func (ut *UsageTracker) RenderUsageInfo() string { ut.mu.RLock() defer ut.mu.RUnlock() - if ut.sessionStats.RequestCount == 0 { - return "" - } - baseStyle := lipgloss.NewStyle() // Display the current context window token count (from the last API call), diff --git a/internal/ui/usage_tracker_render_test.go b/internal/ui/usage_tracker_render_test.go index b405ae38..805c3b6f 100644 --- a/internal/ui/usage_tracker_render_test.go +++ b/internal/ui/usage_tracker_render_test.go @@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) { t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered) } } + +func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) { + // Create a mock model info with costs and context limit + modelInfo := &models.ModelInfo{ + ID: "claude-3-5-sonnet-20241022", + Name: 
"Claude 3.5 Sonnet v2", + Cost: models.Cost{ + Input: 3.0, + Output: 15.0, + }, + Limit: models.Limit{ + Context: 200000, + Output: 8192, + }, + } + + // Test startup state (no requests made yet) - Regular API key + regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false) + rendered := stripAnsi(regularTracker.RenderUsageInfo()) + + // Should NOT return empty string on startup + if rendered == "" { + t.Errorf("Expected non-empty output on startup, got empty string") + } + + // Should show 0 tokens + if !strings.Contains(rendered, "Tokens: 0") { + t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered) + } + + // Should NOT show percentage when tokens are 0 (any rendered percentage contains "%") + if strings.Contains(rendered, "%") { + t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered) + } + + // Should show $0.0000 cost for regular API key + if !strings.Contains(rendered, "Cost: $0.0000") { + t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered) + } + + // Test startup state (no requests made yet) - OAuth + oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true) + oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo()) + + // Should NOT return empty string on startup + if oauthRendered == "" { + t.Errorf("Expected non-empty output on startup for OAuth, got empty string") + } + + // Should show 0 tokens for OAuth + if !strings.Contains(oauthRendered, "Tokens: 0") { + t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered) + } + + // Should show $0.00 cost for OAuth + if !strings.Contains(oauthRendered, "Cost: $0.00") { + t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered) + } +}