mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Better token tracking and support for more OpenAI models
This commit is contained in:
+1
-1
@@ -538,7 +538,7 @@ func runAgenticStep(ctx context.Context, mcpAgent *agent.Agent, cli *ui.CLI, mes
|
||||
break
|
||||
}
|
||||
}
|
||||
cli.UpdateUsage(lastUserMessage, response.Content)
|
||||
cli.UpdateUsageFromResponse(response, lastUserMessage)
|
||||
}
|
||||
} else if config.Quiet {
|
||||
// In quiet mode, only output the final response content to stdout
|
||||
|
||||
@@ -87,9 +87,9 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (model.ToolCall
|
||||
|
||||
// validateModelConfig validates configuration parameters against model capabilities
|
||||
func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) error {
|
||||
// Check if temperature is supported
|
||||
// Omit temperature if not supported by the model
|
||||
if config.Temperature != nil && !modelInfo.Temperature {
|
||||
return fmt.Errorf("model %s does not support temperature parameter", modelInfo.ID)
|
||||
config.Temperature = nil
|
||||
}
|
||||
|
||||
// Warn about context limits if MaxTokens is set too high
|
||||
@@ -162,16 +162,35 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
openaiConfig.BaseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
// Check if this is a reasoning model to handle beta limitations
|
||||
registry := GetGlobalRegistry()
|
||||
isReasoningModel := false
|
||||
if modelInfo, err := registry.ValidateModel("openai", modelName); err == nil && modelInfo.Reasoning {
|
||||
isReasoningModel = true
|
||||
}
|
||||
|
||||
if config.MaxTokens > 0 {
|
||||
openaiConfig.MaxTokens = &config.MaxTokens
|
||||
if isReasoningModel {
|
||||
// For reasoning models, use MaxCompletionTokens instead of MaxTokens
|
||||
if openaiConfig.ExtraFields == nil {
|
||||
openaiConfig.ExtraFields = make(map[string]any)
|
||||
}
|
||||
openaiConfig.ExtraFields["max_completion_tokens"] = config.MaxTokens
|
||||
} else {
|
||||
// For non-reasoning models, use MaxTokens as usual
|
||||
openaiConfig.MaxTokens = &config.MaxTokens
|
||||
}
|
||||
}
|
||||
|
||||
if config.Temperature != nil {
|
||||
openaiConfig.Temperature = config.Temperature
|
||||
}
|
||||
// For reasoning models, skip temperature and top_p due to beta limitations
|
||||
if !isReasoningModel {
|
||||
if config.Temperature != nil {
|
||||
openaiConfig.Temperature = config.Temperature
|
||||
}
|
||||
|
||||
if config.TopP != nil {
|
||||
openaiConfig.TopP = config.TopP
|
||||
if config.TopP != nil {
|
||||
openaiConfig.TopP = config.TopP
|
||||
}
|
||||
}
|
||||
|
||||
if len(config.StopSequences) > 0 {
|
||||
|
||||
@@ -1,113 +1 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AnthropicTokenCounter counts input tokens for Anthropic models by calling
// Anthropic's count_tokens HTTP endpoint.
type AnthropicTokenCounter struct {
	apiKey     string       // credential sent in the x-api-key request header
	httpClient *http.Client // client used to issue the token-counting request
}
|
||||
|
||||
// NewAnthropicTokenCounter creates a new Anthropic token counter
|
||||
func NewAnthropicTokenCounter(apiKey string) *AnthropicTokenCounter {
|
||||
return &AnthropicTokenCounter{
|
||||
apiKey: apiKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// AnthropicTokenRequest represents the request payload for Anthropic token counting
|
||||
type AnthropicTokenRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// AnthropicTokenResponse is the part of the token-counting API response that
// is decoded: the number of input tokens reported by the service.
type AnthropicTokenResponse struct {
	InputTokens int `json:"input_tokens"`
}
|
||||
|
||||
// CountTokens counts tokens using Anthropic's token counting API
|
||||
func (a *AnthropicTokenCounter) CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error) {
|
||||
if a.apiKey == "" {
|
||||
return nil, fmt.Errorf("anthropic API key not provided")
|
||||
}
|
||||
|
||||
// Strip the anthropic: prefix if present
|
||||
actualModel := model
|
||||
if strings.HasPrefix(model, "anthropic:") {
|
||||
actualModel = strings.TrimPrefix(model, "anthropic:")
|
||||
}
|
||||
|
||||
// Prepare request payload
|
||||
request := AnthropicTokenRequest{
|
||||
Messages: messages,
|
||||
Model: actualModel,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewReader(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-api-key", a.apiKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
// Make the request
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Check for HTTP errors
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var tokenResponse AnthropicTokenResponse
|
||||
if err := json.Unmarshal(body, &tokenResponse); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &TokenCount{
|
||||
InputTokens: tokenResponse.InputTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SupportsModel returns true if this counter supports the given model
|
||||
func (a *AnthropicTokenCounter) SupportsModel(model string) bool {
|
||||
// Support all Anthropic models
|
||||
return strings.HasPrefix(model, "anthropic:")
|
||||
}
|
||||
|
||||
// ProviderName returns the name of the provider
|
||||
func (a *AnthropicTokenCounter) ProviderName() string {
|
||||
return "anthropic"
|
||||
}
|
||||
package tokens
|
||||
@@ -1,99 +1,7 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Message is a single chat message passed to token counting. It is encoded
// to JSON with the keys "role" and "content".
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}
|
||||
|
||||
// TokenCount is the result of a token-counting call. It carries only the
// number of input tokens.
type TokenCount struct {
	InputTokens int `json:"input_tokens"`
}
|
||||
|
||||
// TokenCounter interface for provider-specific token counting
|
||||
type TokenCounter interface {
|
||||
// CountTokens counts tokens for the given messages and model
|
||||
CountTokens(ctx context.Context, messages []Message, model string) (*TokenCount, error)
|
||||
// SupportsModel returns true if this counter supports the given model
|
||||
SupportsModel(model string) bool
|
||||
// ProviderName returns the name of the provider this counter is for
|
||||
ProviderName() string
|
||||
}
|
||||
|
||||
// EstimateTokens returns a rough token estimate for text, intended as a
// fallback when no provider-specific counter is available.
//
// The heuristic is one token per four bytes of text, rounded down. It counts
// bytes rather than characters, so it is only an approximation.
func EstimateTokens(text string) int {
	return len(text) / 4
}
|
||||
|
||||
// EstimateTokensFromMessages estimates tokens from a slice of messages
|
||||
func EstimateTokensFromMessages(messages []Message) int {
|
||||
totalChars := 0
|
||||
for _, msg := range messages {
|
||||
totalChars += len(msg.Content)
|
||||
totalChars += len(msg.Role) + 10 // Add some overhead for role and formatting
|
||||
}
|
||||
return EstimateTokens(string(rune(totalChars)))
|
||||
}
|
||||
|
||||
// Registry holds all registered token counters
|
||||
type Registry struct {
|
||||
counters map[string]TokenCounter
|
||||
}
|
||||
|
||||
// NewRegistry creates a new token counter registry
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
counters: make(map[string]TokenCounter),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a token counter to the registry
|
||||
func (r *Registry) Register(counter TokenCounter) {
|
||||
r.counters[counter.ProviderName()] = counter
|
||||
}
|
||||
|
||||
// GetCounter returns a token counter for the given provider
|
||||
func (r *Registry) GetCounter(provider string) (TokenCounter, bool) {
|
||||
counter, exists := r.counters[provider]
|
||||
return counter, exists
|
||||
}
|
||||
|
||||
// CountTokens attempts to count tokens using a provider-specific counter,
|
||||
// falling back to estimation if no counter is available
|
||||
func (r *Registry) CountTokens(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
|
||||
if counter, exists := r.GetCounter(provider); exists && counter.SupportsModel(model) {
|
||||
return counter.CountTokens(ctx, messages, model)
|
||||
}
|
||||
|
||||
// Fallback to estimation
|
||||
estimatedTokens := EstimateTokensFromMessages(messages)
|
||||
return &TokenCount{
|
||||
InputTokens: estimatedTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewRegistry()
|
||||
|
||||
// GetGlobalRegistry returns the global token counter registry
|
||||
func GetGlobalRegistry() *Registry {
|
||||
return globalRegistry
|
||||
}
|
||||
|
||||
// RegisterCounter registers a token counter with the global registry
|
||||
func RegisterCounter(counter TokenCounter) {
|
||||
globalRegistry.Register(counter)
|
||||
}
|
||||
|
||||
// CountTokensGlobal counts tokens using the global registry
|
||||
func CountTokensGlobal(ctx context.Context, provider string, messages []Message, model string) (*TokenCount, error) {
|
||||
return globalRegistry.CountTokens(ctx, provider, messages, model)
|
||||
}
|
||||
+3
-12
@@ -1,20 +1,11 @@
|
||||
package tokens
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// InitializeTokenCounters registers all available token counters
|
||||
func InitializeTokenCounters() {
|
||||
// Register Anthropic token counter if API key is available
|
||||
if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" {
|
||||
RegisterCounter(NewAnthropicTokenCounter(apiKey))
|
||||
}
|
||||
// Future provider-specific counters can be registered here
|
||||
}
|
||||
|
||||
// InitializeTokenCountersWithKeys registers token counters with provided API keys
|
||||
func InitializeTokenCountersWithKeys(anthropicKey string) {
|
||||
if anthropicKey != "" {
|
||||
RegisterCounter(NewAnthropicTokenCounter(anthropicKey))
|
||||
}
|
||||
func InitializeTokenCountersWithKeys() {
|
||||
// Future provider-specific counters can be registered here
|
||||
}
|
||||
+24
-6
@@ -1,7 +1,6 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,7 +11,6 @@ import (
|
||||
"github.com/charmbracelet/huh"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/mark3labs/mcphost/internal/tokens"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
@@ -324,10 +322,30 @@ func (c *CLI) UpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateUsageWithMessages updates the usage tracker using custom token counting for messages
|
||||
func (c *CLI) UpdateUsageWithMessages(ctx context.Context, messages []tokens.Message, outputText string) {
|
||||
if c.usageTracker != nil {
|
||||
c.usageTracker.CountAndUpdateUsage(ctx, messages, outputText)
|
||||
|
||||
|
||||
// UpdateUsageFromResponse updates the usage tracker using token usage from response metadata
|
||||
func (c *CLI) UpdateUsageFromResponse(response *schema.Message, inputText string) {
|
||||
if c.usageTracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to extract token usage from response metadata
|
||||
if response.ResponseMeta != nil && response.ResponseMeta.Usage != nil {
|
||||
usage := response.ResponseMeta.Usage
|
||||
|
||||
// Use actual token counts from the response
|
||||
inputTokens := int(usage.PromptTokens)
|
||||
outputTokens := int(usage.CompletionTokens)
|
||||
|
||||
// Handle cache tokens if available (some providers support this)
|
||||
cacheReadTokens := 0
|
||||
cacheWriteTokens := 0
|
||||
|
||||
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
|
||||
} else {
|
||||
// Fallback to estimation if no metadata is available
|
||||
c.usageTracker.EstimateAndUpdateUsage(inputText, response.Content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
@@ -102,30 +101,15 @@ func (ut *UsageTracker) UpdateUsage(inputTokens, outputTokens, cacheReadTokens,
|
||||
|
||||
// EstimateAndUpdateUsage estimates tokens from text and updates usage
|
||||
func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
inputTokens := EstimateTokens(inputText)
|
||||
outputTokens := EstimateTokens(outputText)
|
||||
inputTokens := tokens.EstimateTokens(inputText)
|
||||
outputTokens := tokens.EstimateTokens(outputText)
|
||||
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
|
||||
}
|
||||
|
||||
// CountAndUpdateUsage counts tokens using provider-specific counters and updates usage
|
||||
func (ut *UsageTracker) CountAndUpdateUsage(ctx context.Context, messages []tokens.Message, outputText string) {
|
||||
// Count input tokens using provider-specific counter
|
||||
tokenCount, err := tokens.CountTokensGlobal(ctx, ut.provider, messages, ut.modelInfo.ID)
|
||||
var inputTokens int
|
||||
if err != nil {
|
||||
// Fallback to estimation if token counting fails
|
||||
var totalInput string
|
||||
for _, msg := range messages {
|
||||
totalInput += msg.Content
|
||||
}
|
||||
inputTokens = EstimateTokens(totalInput)
|
||||
} else {
|
||||
inputTokens = tokenCount.InputTokens
|
||||
}
|
||||
|
||||
// Estimate output tokens (providers typically don't count output tokens separately)
|
||||
outputTokens := EstimateTokens(outputText)
|
||||
|
||||
// EstimateAndUpdateUsageFromText estimates tokens from text and updates usage
|
||||
func (ut *UsageTracker) EstimateAndUpdateUsageFromText(inputText, outputText string) {
|
||||
inputTokens := tokens.EstimateTokens(inputText)
|
||||
outputTokens := tokens.EstimateTokens(outputText)
|
||||
ut.UpdateUsage(inputTokens, outputTokens, 0, 0)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user