mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
fix: update token counting when switching models mid-session
When switching models (e.g., via /model command or ctx.SetModel), the usage tracker now updates its model info to reflect the new model's: - Pricing for cost calculations - Context limits for percentage display - OAuth status (to show bash costs when using OAuth creds) Previously, token costs and context percentages continued using the old model's settings after a switch, causing incorrect display for: - Users switching from paid to free/OAuth models - Users switching between models with different pricing Changes: - Add UpdateModelInfo() method to UsageTracker - Call UpdateModelInfo() in both SetModel callbacks (extension and UI) - Add auth import for OAuth detection in root.go
This commit is contained in:
+37
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -955,6 +956,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
@@ -1152,6 +1171,24 @@ func runNormalMode(ctx context.Context) error {
|
||||
// this callback runs synchronously inside BubbleTea's Update(), and
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
|
||||
@@ -266,3 +266,14 @@ func (ut *UsageTracker) SetWidth(width int) {
|
||||
defer ut.mu.Unlock()
|
||||
ut.width = width
|
||||
}
|
||||
|
||||
// UpdateModelInfo updates the model information and OAuth status when the model
|
||||
// is switched mid-session. This ensures token costs and context limits are
|
||||
// calculated correctly for the new model.
|
||||
func (ut *UsageTracker) UpdateModelInfo(modelInfo *models.ModelInfo, provider string, isOAuth bool) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
ut.modelInfo = modelInfo
|
||||
ut.provider = provider
|
||||
ut.isOAuth = isOAuth
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user