better token tracking and support for more openai models

This commit is contained in:
Ed Zynda
2025-06-18 15:19:04 +03:00
parent cda80f1572
commit 3faf46ff44
7 changed files with 62 additions and 254 deletions
+1 -1
View File
@@ -538,7 +538,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
break
}
}
cli.UpdateUsage(lastUserMessage, response.Content)
cli.UpdateUsageFromResponse(response, lastUserMessage)
}
} else if config.Quiet {
// In quiet mode, only output the final response content to stdout
+27 -8
View File
@@ -87,9 +87,9 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
// validateModelConfig validates configuration parameters against model capabilities
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
// Check if temperature is supported
// Omit temperature if not supported by the model
if config.Temperature != nil && !modelInfo.Temperature {
return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID)
config.Temperature = nil
}
// Warn about context limits if MaxTokens is set too high
@@ -162,16 +162,35 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
openaiConfig.BaseURL = config.ProviderURL
}
// Check if this is a reasoning model to handle beta limitations
registry := GetGlobalRegistry()
isReasoningModel := false
if modelInfo, err := registry.ValidateModel("openai", modelName); err == nil && modelInfo.Reasoning {
isReasoningModel = true
}
if config.MaxTokens > 0 {
openaiConfig.MaxTokens = &config.MaxTokens
if isReasoningModel {
// For reasoning models, use MaxCompletionTokens instead of MaxTokens
if openaiConfig.ExtraFields == nil {
openaiConfig.ExtraFields = make(map[string]any)
}
openaiConfig.ExtraFields["max_completion_tokens"] = config.MaxTokens
} else {
// For non-reasoning models, use MaxTokens as usual
openaiConfig.MaxTokens = &config.MaxTokens
}
}
if config.Temperature != nil {
openaiConfig.Temperature = config.Temperature
}
// For reasoning models, skip temperature and top_p due to beta limitations
if !isReasoningModel {
if config.Temperature != nil {
openaiConfig.Temperature = config.Temperature
}
if config.TopP != nil {
openaiConfig.TopP = config.TopP
if config.TopP != nil {
openaiConfig.TopP = config.TopP
}
}
if len(config.StopSequences) > 0 {
+1 -113
View File
@@ -1,113 +1 @@
package tokens
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// AnthropicTokenCounter counts input tokens for Anthropic models by calling
// Anthropic's count_tokens HTTP endpoint.
type AnthropicTokenCounter struct {
// apiKey is sent in the x-api-key request header; CountTokens returns an
// error when it is empty.
apiKey string
// httpClient performs the count_tokens request.
httpClient *http.Client
}
// NewAnthropicTokenCounter builds a counter that authenticates with apiKey and
// owns a dedicated HTTP client whose requests time out after 10 seconds.
func NewAnthropicTokenCounter(apiKey string) *AnthropicTokenCounter {
	client := &http.Client{Timeout: 10 * time.Second}

	counter := new(AnthropicTokenCounter)
	counter.apiKey = apiKey
	counter.httpClient = client
	return counter
}
// AnthropicTokenRequest is the JSON body posted to Anthropic's count_tokens
// endpoint.
type AnthropicTokenRequest struct {
// Messages is the conversation whose tokens are to be counted.
Messages []Message `json:"messages"`
// Model is the model name; CountTokens strips any "anthropic:" prefix
// before filling it in.
Model string `json:"model"`
}
// AnthropicTokenResponse is the subset of the count_tokens JSON response that
// is decoded: the reported number of input tokens.
type AnthropicTokenResponse struct {
InputTokens int `json:"input_tokens"`
}
// CountTokens asks Anthropic's count_tokens endpoint how many input tokens
// messages occupy for model and returns that figure.
//
// An "anthropic:" prefix on model is removed before the request is sent. An
// error is returned when no API key is configured, when the request cannot be
// built, sent or read, when the endpoint answers with a status other than
// 200 OK (the response body is included in the error), or when the response
// body is not valid JSON.
func (a *AnthropicTokenCounter) CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error) {
	if a.apiKey == "" {
		return nil, fmt.Errorf("anthropic API key not provided")
	}

	// TrimPrefix leaves names without the provider prefix untouched, so no
	// separate HasPrefix check is needed.
	payload, err := json.Marshal(AnthropicTokenRequest{
		Messages: messages,
		Model:    strings.TrimPrefix(model, "anthropic:"),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	const endpoint = "https://api.anthropic.com/v1/messages/count_tokens"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-api-key", a.apiKey)
	req.Header.Set("anthropic-version", "2023-06-01")

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that a failure response
	// can be quoted in the returned error.
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(raw))
	}

	var parsed AnthropicTokenResponse
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return &TokenCount{InputTokens: parsed.InputTokens}, nil
}
// SupportsModel reports whether model carries the "anthropic:" provider
// prefix; every model with that prefix is accepted.
func (a *AnthropicTokenCounter) SupportsModel(model string) bool {
	const providerPrefix = "anthropic:"
	supported := strings.HasPrefix(model, providerPrefix)
	return supported
}
// ProviderName returns the provider key this counter is registered under,
// which is always "anthropic".
func (a *AnthropicTokenCounter) ProviderName() string {
	const provider = "anthropic"
	return provider
}
package tokens
-92
View File
@@ -1,99 +1,7 @@
package tokens
import (
"context"
)
// Message is a single conversation entry passed to token counters. It is
// serialized to JSON with the keys "role" and "content".
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
// TokenCount is the result of a token-counting call. Only the input side is
// reported.
type TokenCount struct {
InputTokens int `json:"input_tokens"`
}
// TokenCounter is implemented by provider-specific token counters that can be
// stored in a Registry.
type TokenCounter interface {
// CountTokens counts the input tokens of messages for the given model.
CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error)
// SupportsModel reports whether this counter can handle the given model.
// Registry.CountTokens consults it before delegating to CountTokens.
SupportsModel(model string) bool
// ProviderName returns the provider key; Registry.Register stores the
// counter under this name.
ProviderName() string
}
// EstimateTokens returns a coarse token estimate for text, used as a fallback
// when no provider-specific counter is available. It assumes roughly four
// bytes of text per token and rounds down, so inputs shorter than four bytes
// estimate to zero.
func EstimateTokens(text string) int {
	const charsPerToken = 4
	return len(text) / charsPerToken
}
// EstimateTokensFromMessages estimates the combined token count of messages.
//
// Each message contributes the byte length of its content and of its role,
// plus a fixed allowance of 10 characters for role and formatting overhead.
// The total is converted to tokens at the same rough ratio of four characters
// per token that EstimateTokens uses, rounding down. A nil or empty slice
// yields 0.
func EstimateTokensFromMessages(messages []Message) int {
	totalChars := 0
	for _, msg := range messages {
		totalChars += len(msg.Content)
		totalChars += len(msg.Role) + 10 // Add some overhead for role and formatting
	}

	// Divide the character count itself. Converting it with
	// string(rune(totalChars)) would produce a single Unicode character whose
	// UTF-8 encoding is at most four bytes, making every estimate 0 or 1.
	return totalChars / 4
}
// Registry holds the registered token counters.
type Registry struct {
// counters maps a provider name (as returned by TokenCounter.ProviderName)
// to the counter registered for it.
counters map[string]TokenCounter
}
// NewRegistry returns an empty registry whose counter map is already
// allocated, so counters can be registered immediately.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.counters = map[string]TokenCounter{}
	return reg
}
// Register stores counter under the name reported by its ProviderName method.
// A counter already stored under that name is replaced.
func (r *Registry) Register(counter TokenCounter) {
	name := counter.ProviderName()
	r.counters[name] = counter
}
// GetCounter looks up the counter registered for provider. The boolean result
// is false, and the counter nil, when nothing is registered under that name.
func (r *Registry) GetCounter(provider string) (TokenCounter, bool) {
	if found, ok := r.counters[provider]; ok {
		return found, true
	}
	return nil, false
}
// CountTokens counts the input tokens of messages for model.
//
// When a counter is registered for provider and reports that it supports
// model, the call is delegated to it and its result, including any error, is
// returned unchanged. Otherwise the count is estimated with
// EstimateTokensFromMessages and the error is nil.
func (r *Registry) CountTokens(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
	counter, found := r.GetCounter(provider)
	if !found || !counter.SupportsModel(model) {
		// No usable provider counter: fall back to the estimate.
		return &TokenCount{InputTokens: EstimateTokensFromMessages(messages)}, nil
	}
	return counter.CountTokens(ctx, messages, model)
}
// globalRegistry is the package-wide registry used by GetGlobalRegistry,
// RegisterCounter and CountTokensGlobal. It is created when the package's
// variables are initialized.
var globalRegistry = NewRegistry()
// GetGlobalRegistry exposes the package-wide token counter registry. Every
// call returns the same instance.
func GetGlobalRegistry() *Registry {
	shared := globalRegistry
	return shared
}
// RegisterCounter adds counter to the package-wide registry; it is shorthand
// for calling Register on that registry.
func RegisterCounter(counter TokenCounter) {
	shared := globalRegistry
	shared.Register(counter)
}
// CountTokensGlobal counts tokens through the package-wide registry; it is
// shorthand for calling CountTokens on that registry with the same arguments.
func CountTokensGlobal(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
	count, err := globalRegistry.CountTokens(ctx, provider, messages, model)
	return count, err
}
+3 -12
View File
@@ -1,20 +1,11 @@
package tokens
import (
"os"
)
// InitializeTokenCounters registers every token counter that can be
// configured from the process environment.
func InitializeTokenCounters() {
	// Anthropic: registered only when ANTHROPIC_API_KEY is set to a
	// non-empty value.
	anthropicKey := os.Getenv("ANTHROPIC_API_KEY")
	if anthropicKey != "" {
		RegisterCounter(NewAnthropicTokenCounter(anthropicKey))
	}

	// Future provider-specific counters can be registered here
}
// InitializeTokenCountersWithKeys registers token counters with provided API keys
func InitializeTokenCountersWithKeys(anthropicKey string) {
if anthropicKey != "" {
RegisterCounter(NewAnthropicTokenCounter(anthropicKey))
}
func InitializeTokenCountersWithKeys() {
// Future provider-specific counters can be registered here
}
+24 -6
View File
@@ -1,7 +1,6 @@
package ui
import (
"context"
"errors"
"fmt"
"io"
@@ -12,7 +11,6 @@ import (
"github.com/charmbracelet/huh"
"github.com/charmbracelet/lipgloss"
"github.com/cloudwego/eino/schema"
"github.com/mark3labs/mcphost/internal/tokens"
"golang.org/x/term"
)
@@ -324,10 +322,30 @@ func (c *CLI) UpdateUsage(inputText, outputText string) {
}
}
// UpdateUsageWithMessages updates the usage tracker using custom token counting for messages
func (c *CLI) UpdateUsageWithMessages(ctx context.Context, messages []tokens.Message, outputText string) {
if c.usageTracker != nil {
c.usageTracker.CountAndUpdateUsage(ctx, messages, outputText)
// UpdateUsageFromResponse records token usage for one exchange on the CLI's
// usage tracker. It does nothing when no tracker is configured.
//
// When response carries usage metadata, the prompt and completion token counts
// from that metadata are recorded; cache read and write tokens are not
// extracted and are always recorded as zero. Without metadata, usage is
// estimated from inputText and the response content. response is dereferenced
// and must not be nil.
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) {
	tracker := c.usageTracker
	if tracker == nil {
		return
	}

	meta := response.ResponseMeta
	if meta == nil || meta.Usage == nil {
		// Fallback to estimation if no metadata is available
		tracker.EstimateAndUpdateUsage(inputText, response.Content)
		return
	}

	// Use actual token counts from the response; the two trailing zeros are
	// the cache read and cache write token counts.
	promptTokens := int(meta.Usage.PromptTokens)
	completionTokens := int(meta.Usage.CompletionTokens)
	tracker.UpdateUsage(promptTokens, completionTokens, 0, 0)
}
+6 -22
View File
@@ -1,7 +1,6 @@
package ui
import (
"context"
"fmt"
"sync"
@@ -102,30 +101,15 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens,
// EstimateAndUpdateUsage estimates tokens from text and updates usage
func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
inputTokens := EstimateTokens(inputText)
outputTokens := EstimateTokens(outputText)
inputTokens := tokens.EstimateTokens(inputText)
outputTokens := tokens.EstimateTokens(outputText)
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
}
// CountAndUpdateUsage counts tokens using provider-specific counters and updates usage
func (ut *UsageTracker) CountAndUpdateUsage(ctx context.Context, messages []tokens.Message, outputText string) {
// Count input tokens using provider-specific counter
tokenCount, err := tokens.CountTokensGlobal(ctx, ut.provider, messages, ut.modelInfo.ID)
var inputTokens int
if err != nil {
// Fallback to estimation if token counting fails
var totalInput string
for _, msg := range messages {
totalInput += msg.Content
}
inputTokens = EstimateTokens(totalInput)
} else {
inputTokens = tokenCount.InputTokens
}
// Estimate output tokens (providers typically don't count output tokens separately)
outputTokens := EstimateTokens(outputText)
// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage
func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) {
inputTokens := tokens.EstimateTokens(inputText)
outputTokens := tokens.EstimateTokens(outputText)
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
}