mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Fix usage widget startup visibility and stop-path updates
This commit is contained in:
+64
-31
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -598,9 +599,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -620,8 +622,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -645,9 +648,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -702,8 +706,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -720,9 +726,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -756,6 +763,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -925,32 +934,56 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// Keep context fill reasonably fresh during long/partial turns.
|
||||
a.opts.UsageTracker.SetContextTokens(int(ev.InputTokens + ev.OutputTokens))
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// Cost/token accumulation uses TotalUsage (sum across all tool-calling steps in
|
||||
// the turn). Context-window fill uses FinalUsage.InputTokens only — that is the
|
||||
// number of tokens sent to the model on the last API call, which equals the
|
||||
// actual context window occupation (all accumulated messages + tool results).
|
||||
// OutputTokens are not added here because they are the response length, not
|
||||
// context fill.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage (or estimation) to keep costs/tokens
|
||||
// visible for providers/modes that don't emit per-step usage.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
if !sawStepUsage {
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -14,6 +16,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +532,67 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately (including context tokens) for stop-path correctness.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
if usage.contextCalls != 1 {
|
||||
t.Fatalf("expected 1 context token update, got %d", usage.contextCalls)
|
||||
}
|
||||
if usage.lastContextTokens != 165 {
|
||||
t.Fatalf("expected context tokens 165, got %d", usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456, got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,10 +151,6 @@ func (ut *UsageTracker) RenderUsageInfo() string {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if ut.sessionStats.RequestCount == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle()
|
||||
|
||||
// Display the current context window token count (from the last API call),
|
||||
|
||||
@@ -67,3 +67,62 @@ func TestUsageTracker_RenderUsageInfo_OAuth(t *testing.T) {
|
||||
t.Errorf("Expected regular rendered output to show actual cost, got: %s", regularRendered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTracker_RenderUsageInfo_StartupState(t *testing.T) {
|
||||
// Create a mock model info with costs and context limit
|
||||
modelInfo := &models.ModelInfo{
|
||||
ID: "claude-3-5-sonnet-20241022",
|
||||
Name: "Claude 3.5 Sonnet v2",
|
||||
Cost: models.Cost{
|
||||
Input: 3.0,
|
||||
Output: 15.0,
|
||||
},
|
||||
Limit: models.Limit{
|
||||
Context: 200000,
|
||||
Output: 8192,
|
||||
},
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - Regular API key
|
||||
regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
|
||||
rendered := stripAnsi(regularTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if rendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens
|
||||
if !strings.Contains(rendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should NOT show percentage when tokens are 0
|
||||
if strings.Contains(rendered, "(%") {
|
||||
t.Errorf("Expected no percentage on startup with 0 tokens, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Should show $0.0000 cost for regular API key
|
||||
if !strings.Contains(rendered, "Cost: $0.0000") {
|
||||
t.Errorf("Expected 'Cost: $0.0000' on startup, got: %s", rendered)
|
||||
}
|
||||
|
||||
// Test startup state (no requests made yet) - OAuth
|
||||
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
|
||||
oauthRendered := stripAnsi(oauthTracker.RenderUsageInfo())
|
||||
|
||||
// Should NOT return empty string on startup
|
||||
if oauthRendered == "" {
|
||||
t.Errorf("Expected non-empty output on startup for OAuth, got empty string")
|
||||
}
|
||||
|
||||
// Should show 0 tokens for OAuth
|
||||
if !strings.Contains(oauthRendered, "Tokens: 0") {
|
||||
t.Errorf("Expected 'Tokens: 0' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
|
||||
// Should show $0.00 cost for OAuth
|
||||
if !strings.Contains(oauthRendered, "Cost: $0.00") {
|
||||
t.Errorf("Expected 'Cost: $0.00' on startup for OAuth, got: %s", oauthRendered)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user