diff --git a/internal/extensions/runner.go b/internal/extensions/runner.go index 023c5112..1564e0b4 100644 --- a/internal/extensions/runner.go +++ b/internal/extensions/runner.go @@ -104,6 +104,7 @@ type Runner struct { configStore *viper.Viper // per-instance config store (nil = global) state map[string]string // session-scoped extension state (last-write-wins) stateMu sync.RWMutex // guards state independently of mu + saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave stateSaver func() // optional persistence hook invoked after each state mutation mu sync.RWMutex } @@ -770,6 +771,11 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig { // SetState records a key-value pair in the runner's session-scoped extension // state store. The store is in-memory; callers wire SetStateSaver to persist // changes to a sidecar file. Thread-safe. +// +// When a saver is installed, concurrent SetState/DeleteState invocations are +// serialized through saverMu so that overlapping snapshot-and-rename writes +// cannot interleave (which would otherwise race on the shared tmp file and +// risk persisting an older snapshot after a newer one). func (r *Runner) SetState(key, value string) { r.stateMu.Lock() if r.state == nil { @@ -779,7 +785,9 @@ func (r *Runner) SetState(key, value string) { saver := r.stateSaver r.stateMu.Unlock() if saver != nil { + r.saverMu.Lock() saver() + r.saverMu.Unlock() } } @@ -793,7 +801,8 @@ func (r *Runner) GetState(key string) (string, bool) { } // DeleteState removes a key from the state store. No-op if the key is -// missing. Thread-safe. +// missing. Thread-safe. Saver invocations are serialized via saverMu — see +// SetState for the rationale. func (r *Runner) DeleteState(key string) { r.stateMu.Lock() _, existed := r.state[key] @@ -803,7 +812,9 @@ func (r *Runner) DeleteState(key string) { saver := r.stateSaver r.stateMu.Unlock() if existed && saver != nil { + r.saverMu.Lock() saver() + r.saverMu.Unlock() } } @@ -847,17 +858,25 @@ func (r *Runner) SnapshotState() map[string]string { // LoadStateFromFile reads a JSON map from path and replaces the in-memory // state store with its contents. Missing or empty files are treated as -// "no prior state" (not an error). Malformed JSON returns the parse error -// without touching the existing store. Thread-safe. +// "no prior state": the in-memory store is replaced with an empty map so +// callers can safely switch sessions without leaking keys from a prior +// session into a new one. Malformed JSON returns the parse error without +// touching the existing store. Thread-safe. func (r *Runner) LoadStateFromFile(path string) error { data, err := os.ReadFile(path) if err != nil { if os.IsNotExist(err) { + r.stateMu.Lock() + r.state = map[string]string{} + r.stateMu.Unlock() return nil } return fmt.Errorf("reading extension state: %w", err) } if len(data) == 0 { + r.stateMu.Lock() + r.state = map[string]string{} + r.stateMu.Unlock() return nil } var loaded map[string]string diff --git a/internal/extensions/state_test.go b/internal/extensions/state_test.go index 59c8ee1c..c570848e 100644 --- a/internal/extensions/state_test.go +++ b/internal/extensions/state_test.go @@ -101,15 +101,37 @@ func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) { } } -func TestRunner_State_LoadMissingFileIsNoop(t *testing.T) { +func TestRunner_State_LoadMissingFileClearsState(t *testing.T) { + // LoadStateFromFile is documented to "replace the in-memory state store + // with its contents"; for a missing file that means clearing the store. + // This is what makes session-switching safe: a new session that has not + // yet written a sidecar must not inherit keys from a prior session. r := NewRunner(nil) r.SetState("a", "1") if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil { t.Errorf("expected nil error for missing file, got %v", err) } - // Existing in-memory state is left alone when file doesn't exist. - if v, ok := r.GetState("a"); !ok || v != "1" { - t.Errorf("expected pre-existing state preserved, got (%q,%v)", v, ok) + if _, ok := r.GetState("a"); ok { + t.Error("expected pre-existing state to be cleared when target file is missing") + } + if keys := r.ListState(); keys != nil { + t.Errorf("expected ListState() to be nil after clearing, got %v", keys) + } +} + +func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.json") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatal(err) + } + r := NewRunner(nil) + r.SetState("a", "1") + if err := r.LoadStateFromFile(path); err != nil { + t.Errorf("expected nil error for empty file, got %v", err) + } + if _, ok := r.GetState("a"); ok { + t.Error("expected pre-existing state to be cleared when target file is empty") } } diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go index 4e999674..36e069bb 100644 --- a/pkg/kit/extensions_bridge.go +++ b/pkg/kit/extensions_bridge.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/mark3labs/kit/internal/auth" "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/models" ) @@ -24,6 +25,15 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { // Per-turn aggregator: collects tool/LLM/usage signals between AgentStart // and AgentEnd so the enriched AgentEndEvent can be populated without // requiring extensions to maintain parallel bookkeeping. + // + // NOTE: this aggregator assumes a single in-flight turn per *Kit instance, + // which is the current contract — runTurn does not serialize callers and + // the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two + // concurrent Prompt() calls on the same *Kit would clobber the counters. + // All current callers (TUI app layer, CLI runner, SDK examples) serialize + // turns above this layer. If concurrent turns become a supported use case, + // extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per + // turn instead. turnAgg := &turnAggregator{kit: m} m.Subscribe(func(e Event) { switch ev := e.(type) { @@ -530,8 +540,12 @@ func (a *turnAggregator) consume() turnSnapshot { // llmUsageMeta returns the current provider, model id, and computed cost for // the given usage values using the Kit instance's active model. Cost is zero -// when the model is not in the registry (e.g. custom fine-tunes, unknown -// providers) or pricing fields are unset. +// in any of the following cases: +// - the *Kit pointer is nil or has no active model; +// - the model is not in the registry (custom fine-tunes, unknown providers); +// - the model has no pricing fields set; +// - the active credential is an Anthropic OAuth token (matches the +// existing usage_tracker behavior of suppressing cost for OAuth users). func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) { if m == nil { return "", "", 0 @@ -549,6 +563,9 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6 if info == nil { return provider, modelID, 0 } + if isAnthropicOAuth(m, provider) { + return provider, modelID, 0 + } cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000 cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000 if info.Cost.CacheRead != nil { @@ -560,6 +577,21 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6 return provider, modelID, cost } +// isAnthropicOAuth reports whether the current Anthropic credential resolves +// to a stored OAuth token (in which case the user is not billed per-token). +// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker +// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage(). +func isAnthropicOAuth(m *Kit, provider string) bool { + if m == nil || provider != "anthropic" { + return false + } + _, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key")) + if err != nil { + return false + } + return strings.HasPrefix(source, "stored OAuth") +} + // llmToContextMessages converts a slice of LLM messages to extension // ContextMessage values, extracting plain text from each message. func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage { diff --git a/pkg/kit/extensions_bridge_test.go b/pkg/kit/extensions_bridge_test.go index 29a5c18d..cf0dd976 100644 --- a/pkg/kit/extensions_bridge_test.go +++ b/pkg/kit/extensions_bridge_test.go @@ -108,6 +108,16 @@ func TestLLMUsageMeta_NilKit(t *testing.T) { } } +// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any +// provider other than "anthropic" without touching the credential store. +func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) { + for _, provider := range []string{"openai", "google", "openrouter", ""} { + if isAnthropicOAuth(nil, provider) { + t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider) + } + } +} + func TestExtStateSidecarPath(t *testing.T) { tests := []struct { name string