Files
kit/internal/ui/usage_tracker_test.go
T
Ed Zynda a05da5f3ab fix(auth): support OAuth credentials in ACP mode and auto-refresh tokens
Remove the early ValidateEnvironment gate from CreateProvider that only
checked env vars and --provider-api-key, blocking stored OAuth credentials
from working. Each provider creation function already handles its own auth
resolution with clear error messages.

Update ValidateEnvironment to also check stored Anthropic credentials so
the model selector UI correctly shows Anthropic models for OAuth users.

Add automatic token refresh in oauthTransport so long-lived ACP sessions
survive token renewals. Surface actionable auth error messages in ACP
session creation.

Fix pre-existing staticcheck SA5011 warnings in test files.
2026-03-15 12:38:23 +03:00

113 lines
3.5 KiB
Go

package ui
import (
"testing"
"github.com/mark3labs/kit/internal/models"
)
// TestUsageTracker_OAuthCosts verifies per-request cost accounting in two modes.
// With a regular API key (last NewUsageTracker argument false), costs are
// derived from the model's per-million-token rates. With OAuth credentials
// (last argument true), tokens are still counted but every cost is exactly $0.
func TestUsageTracker_OAuthCosts(t *testing.T) {
	// Model fixture priced at $3 per million input tokens and $15 per million
	// output tokens.
	modelInfo := &models.ModelInfo{
		ID:   "claude-3-5-sonnet-20241022",
		Name: "Claude 3.5 Sonnet v2",
		Cost: models.Cost{
			Input:  3.0,
			Output: 15.0,
		},
	}

	// costTolerance absorbs floating-point rounding differences between the
	// test's arithmetic and the tracker's (e.g. n*rate/1e6 vs n/1e6*rate).
	// Genuine pricing mistakes are many orders of magnitude larger.
	const costTolerance = 1e-12

	// checkCost reports an error when got is not within costTolerance of want.
	checkCost := func(label string, want, got float64) {
		t.Helper()
		diff := got - want
		if diff < 0 {
			diff = -diff
		}
		// Negated "<=" so that a NaN cost is reported instead of silently passing.
		if !(diff <= costTolerance) {
			t.Errorf("Expected %s %g, got %g", label, want, got)
		}
	}

	// Regular API key: costs should be calculated from the model rates.
	regularTracker := NewUsageTracker(modelInfo, "anthropic", 80, false)
	regularTracker.UpdateUsage(1000, 500, 0, 0) // 1000 input, 500 output tokens
	stats := regularTracker.GetLastRequestStats()
	if stats == nil {
		t.Fatal("Expected stats to be non-nil")
		// Explicit return keeps staticcheck SA5011 (possible nil dereference)
		// from flagging the uses of stats below.
		return
	}

	expectedInputCost := float64(1000) * 3.0 / 1000000          // $0.003
	expectedOutputCost := float64(500) * 15.0 / 1000000         // $0.0075
	expectedTotalCost := expectedInputCost + expectedOutputCost // $0.0105

	checkCost("input cost", expectedInputCost, stats.InputCost)
	checkCost("output cost", expectedOutputCost, stats.OutputCost)
	checkCost("total cost", expectedTotalCost, stats.TotalCost)

	// Token counting must not depend on the authentication mode.
	if stats.InputTokens != 1000 {
		t.Errorf("Expected input tokens to be 1000, got %d", stats.InputTokens)
	}
	if stats.OutputTokens != 500 {
		t.Errorf("Expected output tokens to be 500, got %d", stats.OutputTokens)
	}

	// OAuth credentials: identical usage, but every cost must be exactly $0.
	oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
	oauthTracker.UpdateUsage(1000, 500, 0, 0) // Same token usage
	oauthStats := oauthTracker.GetLastRequestStats()
	if oauthStats == nil {
		t.Fatal("Expected OAuth stats to be non-nil")
		// See the SA5011 note above.
		return
	}

	// Zero is exactly representable, so an exact comparison is intended here.
	if oauthStats.InputCost != 0.0 {
		t.Errorf("Expected OAuth input cost to be $0, got %f", oauthStats.InputCost)
	}
	if oauthStats.OutputCost != 0.0 {
		t.Errorf("Expected OAuth output cost to be $0, got %f", oauthStats.OutputCost)
	}
	if oauthStats.TotalCost != 0.0 {
		t.Errorf("Expected OAuth total cost to be $0, got %f", oauthStats.TotalCost)
	}

	// Token counts are still tracked for OAuth.
	if oauthStats.InputTokens != 1000 {
		t.Errorf("Expected OAuth input tokens to be 1000, got %d", oauthStats.InputTokens)
	}
	if oauthStats.OutputTokens != 500 {
		t.Errorf("Expected OAuth output tokens to be 500, got %d", oauthStats.OutputTokens)
	}
}
func TestUsageTracker_OAuthSessionStats(t *testing.T) {
// Create a mock model info with costs
modelInfo := &models.ModelInfo{
ID: "claude-3-5-sonnet-20241022",
Name: "Claude 3.5 Sonnet v2",
Cost: models.Cost{
Input: 3.0,
Output: 15.0,
},
}
// Test OAuth session stats accumulation
oauthTracker := NewUsageTracker(modelInfo, "anthropic", 80, true)
// Make multiple requests
oauthTracker.UpdateUsage(1000, 500, 0, 0)
oauthTracker.UpdateUsage(2000, 1000, 0, 0)
sessionStats := oauthTracker.GetSessionStats()
// Check that tokens are accumulated correctly
if sessionStats.TotalInputTokens != 3000 {
t.Errorf("Expected total input tokens to be 3000, got %d", sessionStats.TotalInputTokens)
}
if sessionStats.TotalOutputTokens != 1500 {
t.Errorf("Expected total output tokens to be 1500, got %d", sessionStats.TotalOutputTokens)
}
// Check that total cost remains $0 for OAuth
if sessionStats.TotalCost != 0.0 {
t.Errorf("Expected OAuth session total cost to be $0, got %f", sessionStats.TotalCost)
}
// Check request count
if sessionStats.RequestCount != 2 {
t.Errorf("Expected request count to be 2, got %d", sessionStats.RequestCount)
}
}