diff --git a/cmd/root.go b/cmd/root.go index 982b6232..3f978c76 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -538,7 +538,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes break } } - cli.UpdateUsage(lastUserMessage, response.Content) + cli.UpdateUsageFromResponse(response, lastUserMessage) } } else if config.Quiet { // In quiet mode, only output the final response content to stdout diff --git a/internal/models/providers.go b/internal/models/providers.go index bd63db24..8f4f2aa9 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -87,9 +87,9 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall // validateModelConfig validates configuration parameters against model capabilities func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error { - // Check if temperature is supported + // Omit temperature if not supported by the model if config.Temperature != nil && !modelInfo.Temperature { - return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID) + config.Temperature = nil } // Warn about context limits if MaxTokens is set too high @@ -162,16 +162,35 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName openaiConfig.BaseURL = config.ProviderURL } + // Check if this is a reasoning model to handle beta limitations + registry := GetGlobalRegistry() + isReasoningModel := false + if modelInfo, err := registry.ValidateModel("openai", modelName); err == nil && modelInfo.Reasoning { + isReasoningModel = true + } + if config.MaxTokens > 0 { - openaiConfig.MaxTokens = &config.MaxTokens + if isReasoningModel { + // For reasoning models, use MaxCompletionTokens instead of MaxTokens + if openaiConfig.ExtraFields == nil { + openaiConfig.ExtraFields = make(map[string]any) + } + openaiConfig.ExtraFields["max_completion_tokens"] = config.MaxTokens + } else { + // For non-reasoning models, use MaxTokens as usual + openaiConfig.MaxTokens = &config.MaxTokens + } } - if config.Temperature != nil { - openaiConfig.Temperature = config.Temperature - } + // For reasoning models, skip temperature and top_p due to beta limitations + if !isReasoningModel { + if config.Temperature != nil { + openaiConfig.Temperature = config.Temperature + } - if config.TopP != nil { - openaiConfig.TopP = config.TopP + if config.TopP != nil { + openaiConfig.TopP = config.TopP + } } if len(config.StopSequences) > 0 { diff --git a/internal/tokens/anthropic.go b/internal/tokens/anthropic.go index c7190954..92d07b8c 100644 --- a/internal/tokens/anthropic.go +++ b/internal/tokens/anthropic.go @@ -1,113 +1 @@ -package tokens - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" -) - -// AnthropicTokenCounter implements token counting for Anthropic models -type AnthropicTokenCounter struct { - apiKey string - httpClient *http.Client -} - -// NewAnthropicTokenCounter creates a new Anthropic token counter -func NewAnthropicTokenCounter(apiKey string) *AnthropicTokenCounter { - return &AnthropicTokenCounter{ - apiKey: apiKey, - httpClient: &http.Client{ - Timeout: 10 * time.Second, - }, - } -} - -// AnthropicTokenRequest represents the request payload for Anthropic token counting -type AnthropicTokenRequest struct { - Messages []Message `json:"messages"` - Model string `json:"model"` -} - -// AnthropicTokenResponse represents the response from Anthropic token counting API -type AnthropicTokenResponse struct { - InputTokens int `json:"input_tokens"` -} - -// CountTokens counts tokens using Anthropic's token counting API -func (a *AnthropicTokenCounter) CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error) { - if a.apiKey == "" { - return nil, fmt.Errorf("anthropic API key not provided") - } - - // Strip the anthropic: prefix if present - actualModel := model - if strings.HasPrefix(model, "anthropic:") { - actualModel = strings.TrimPrefix(model, "anthropic:") - } - - // Prepare request payload - request := AnthropicTokenRequest{ - Messages: messages, - Model: actualModel, - } - - jsonData, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("x-api-key", a.apiKey) - req.Header.Set("anthropic-version", "2023-06-01") - - // Make the request - resp, err := a.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - defer resp.Body.Close() - - // Read response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } - - // Check for HTTP errors - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Parse response - var tokenResponse AnthropicTokenResponse - if err := json.Unmarshal(body, &tokenResponse); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - return &TokenCount{ - InputTokens: tokenResponse.InputTokens, - }, nil -} - -// SupportsModel returns true if this counter supports the given model -func (a *AnthropicTokenCounter) SupportsModel(model string) bool { - // Support all Anthropic models - return strings.HasPrefix(model, "anthropic:") -} - -// ProviderName returns the name of the provider -func (a *AnthropicTokenCounter) ProviderName() string { - return "anthropic" -} \ No newline at end of file +package tokens \ No newline at end of file diff --git a/internal/tokens/counter.go b/internal/tokens/counter.go index 2539da2f..c2f2c109 100644 --- a/internal/tokens/counter.go +++ b/internal/tokens/counter.go @@ -1,99 +1,7 @@ package tokens -import ( - "context" -) - -// Message represents a message for token counting -type Message struct { - Role string `json:"role"` - Content string `json:"content"` -} - -// TokenCount represents the result of token counting -type TokenCount struct { - InputTokens int `json:"input_tokens"` -} - -// TokenCounter interface for provider-specific token counting -type TokenCounter interface { - // CountTokens counts tokens for the given messages and model - CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error) - // SupportsModel returns true if this counter supports the given model - SupportsModel(model string) bool - // ProviderName returns the name of the provider this counter is for - ProviderName() string -} - // EstimateTokens provides a rough estimate of tokens in text -// This is a fallback when no provider-specific counter is available func EstimateTokens(text string) int { // Rough approximation: ~4 characters per token for most models - // This is not accurate but gives a reasonable estimate return len(text) / 4 -} - -// EstimateTokensFromMessages estimates tokens from a slice of messages -func EstimateTokensFromMessages(messages []Message) int { - totalChars := 0 - for _, msg := range messages { - totalChars += len(msg.Content) - totalChars += len(msg.Role) + 10 // Add some overhead for role and formatting - } - return EstimateTokens(string(rune(totalChars))) -} - -// Registry holds all registered token counters -type Registry struct { - counters map[string]TokenCounter -} - -// NewRegistry creates a new token counter registry -func NewRegistry() *Registry { - return &Registry{ - counters: make(map[string]TokenCounter), - } -} - -// Register adds a token counter to the registry -func (r *Registry) Register(counter TokenCounter) { - r.counters[counter.ProviderName()] = counter -} - -// GetCounter returns a token counter for the given provider -func (r *Registry) GetCounter(provider string) (TokenCounter, bool) { - counter, exists := r.counters[provider] - return counter, exists -} - -// CountTokens attempts to count tokens using a provider-specific counter, -// falling back to estimation if no counter is available -func (r *Registry) CountTokens(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) { - if counter, exists := r.GetCounter(provider); exists && counter.SupportsModel(model) { - return counter.CountTokens(ctx, messages, model) - } - - // Fallback to estimation - estimatedTokens := EstimateTokensFromMessages(messages) - return &TokenCount{ - InputTokens: estimatedTokens, - }, nil -} - -// Global registry instance -var globalRegistry = NewRegistry() - -// GetGlobalRegistry returns the global token counter registry -func GetGlobalRegistry() *Registry { - return globalRegistry -} - -// RegisterCounter registers a token counter with the global registry -func RegisterCounter(counter TokenCounter) { - globalRegistry.Register(counter) -} - -// CountTokensGlobal counts tokens using the global registry -func CountTokensGlobal(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) { - return globalRegistry.CountTokens(ctx, provider, messages, model) } \ No newline at end of file diff --git a/internal/tokens/init.go b/internal/tokens/init.go index dc6700b5..e79b29a4 100644 --- a/internal/tokens/init.go +++ b/internal/tokens/init.go @@ -1,20 +1,11 @@ package tokens -import ( - "os" -) - // InitializeTokenCounters registers all available token counters func InitializeTokenCounters() { - // Register Anthropic token counter if API key is available - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - RegisterCounter(NewAnthropicTokenCounter(apiKey)) - } + // Future provider-specific counters can be registered here } // InitializeTokenCountersWithKeys registers token counters with provided API keys -func InitializeTokenCountersWithKeys(anthropicKey string) { - if anthropicKey != "" { - RegisterCounter(NewAnthropicTokenCounter(anthropicKey)) - } +func InitializeTokenCountersWithKeys() { + // Future provider-specific counters can be registered here } \ No newline at end of file diff --git a/internal/ui/cli.go b/internal/ui/cli.go index f3a52d93..1434261c 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -1,7 +1,6 @@ package ui import ( - "context" "errors" "fmt" "io" @@ -12,7 +11,6 @@ import ( "github.com/charmbracelet/huh" "github.com/charmbracelet/lipgloss" "github.com/cloudwego/eino/schema" - "github.com/mark3labs/mcphost/internal/tokens" "golang.org/x/term" ) @@ -324,10 +322,30 @@ func (c *CLI) UpdateUsage(inputText, outputText string) { } } -// UpdateUsageWithMessages updates the usage tracker using custom token counting for messages -func (c *CLI) UpdateUsageWithMessages(ctx context.Context, messages []tokens.Message, outputText string) { - if c.usageTracker != nil { - c.usageTracker.CountAndUpdateUsage(ctx, messages, outputText) + + +// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata +func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) { + if c.usageTracker == nil { + return + } + + // Try to extract token usage from response metadata + if response.ResponseMeta != nil && response.ResponseMeta.Usage != nil { + usage := response.ResponseMeta.Usage + + // Use actual token counts from the response + inputTokens := int(usage.PromptTokens) + outputTokens := int(usage.CompletionTokens) + + // Handle cache tokens if available (some providers support this) + cacheReadTokens := 0 + cacheWriteTokens := 0 + + c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens) + } else { + // Fallback to estimation if no metadata is available + c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content) } } diff --git a/internal/ui/usage_tracker.go b/internal/ui/usage_tracker.go index 4d2aa7fb..c905fae9 100644 --- a/internal/ui/usage_tracker.go +++ b/internal/ui/usage_tracker.go @@ -1,7 +1,6 @@ package ui import ( - "context" "fmt" "sync" @@ -102,30 +101,15 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, // EstimateAndUpdateUsage estimates tokens from text and updates usage func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) { - inputTokens := EstimateTokens(inputText) - outputTokens := EstimateTokens(outputText) + inputTokens := tokens.EstimateTokens(inputText) + outputTokens := tokens.EstimateTokens(outputText) ut.UpdateUsage(inputTokens, outputTokens, 0, 0) } -// CountAndUpdateUsage counts tokens using provider-specific counters and updates usage -func (ut *UsageTracker) CountAndUpdateUsage(ctx context.Context, messages []tokens.Message, outputText string) { - // Count input tokens using provider-specific counter - tokenCount, err := tokens.CountTokensGlobal(ctx, ut.provider, messages, ut.modelInfo.ID) - var inputTokens int - if err != nil { - // Fallback to estimation if token counting fails - var totalInput string - for _, msg := range messages { - totalInput += msg.Content - } - inputTokens = EstimateTokens(totalInput) - } else { - inputTokens = tokenCount.InputTokens - } - - // Estimate output tokens (providers typically don't count output tokens separately) - outputTokens := EstimateTokens(outputText) - +// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage +func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) { + inputTokens := tokens.EstimateTokens(inputText) + outputTokens := tokens.EstimateTokens(outputText) ut.UpdateUsage(inputTokens, outputTokens, 0, 0) }