fix(extensions): address review feedback on state store and llmUsageMeta

- Serialize SetState/DeleteState saver invocations through a new saverMu
  so overlapping atomic-rename writes can no longer race on the shared
  .tmp file and persist an older snapshot after a newer one.
- LoadStateFromFile now clears the in-memory store when the sidecar is
  missing or empty, matching the documented "replace … with its
  contents" contract. This makes session-switching safe by preventing
  keys from a prior session leaking into a new one. Tests updated to
  cover both the missing-file and empty-file cases.
- llmUsageMeta now detects Anthropic OAuth credentials and returns
  Cost=0, matching the comment and the existing usage_tracker behavior
  for OAuth users. Mirrors the OAuth detection already used in
  cmd/extension_context.go.
- Document the single-in-flight-turn assumption baked into the
  per-turn aggregator with a clear migration path (per-turn ID) for if
  concurrent turns ever become a supported use case.
This commit is contained in:
Ed Zynda
2026-06-09 14:51:45 +03:00
parent 8b442d2cbc
commit 4a7e4223e0
4 changed files with 92 additions and 9 deletions
+22 -3
View File
@@ -104,6 +104,7 @@ type Runner struct {
configStore *viper.Viper // per-instance config store (nil = global)
state map[string]string // session-scoped extension state (last-write-wins)
stateMu sync.RWMutex // guards state independently of mu
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
stateSaver func() // optional persistence hook invoked after each state mutation
mu sync.RWMutex
}
@@ -770,6 +771,11 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
// SetState records a key-value pair in the runner's session-scoped extension
// state store. The store is in-memory; callers wire SetStateSaver to persist
// changes to a sidecar file. Thread-safe.
//
// When a saver is installed, concurrent SetState/DeleteState invocations are
// serialized through saverMu so that overlapping snapshot-and-rename writes
// cannot interleave (which would otherwise race on the shared tmp file and
// risk persisting an older snapshot after a newer one).
func (r *Runner) SetState(key, value string) {
r.stateMu.Lock()
if r.state == nil {
@@ -779,7 +785,9 @@ func (r *Runner) SetState(key, value string) {
saver := r.stateSaver
r.stateMu.Unlock()
if saver != nil {
r.saverMu.Lock()
saver()
r.saverMu.Unlock()
}
}
@@ -793,7 +801,8 @@ func (r *Runner) GetState(key string) (string, bool) {
}
// DeleteState removes a key from the state store. No-op if the key is
// missing. Thread-safe.
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
// SetState for the rationale.
func (r *Runner) DeleteState(key string) {
r.stateMu.Lock()
_, existed := r.state[key]
@@ -803,7 +812,9 @@ func (r *Runner) DeleteState(key string) {
saver := r.stateSaver
r.stateMu.Unlock()
if existed && saver != nil {
r.saverMu.Lock()
saver()
r.saverMu.Unlock()
}
}
@@ -847,17 +858,25 @@ func (r *Runner) SnapshotState() map[string]string {
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
// state store with its contents. Missing or empty files are treated as
// "no prior state" (not an error). Malformed JSON returns the parse error
// without touching the existing store. Thread-safe.
// "no prior state": the in-memory store is replaced with an empty map so
// callers can safely switch sessions without leaking keys from a prior
// session into a new one. Malformed JSON returns the parse error without
// touching the existing store. Thread-safe.
func (r *Runner) LoadStateFromFile(path string) error {
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
r.stateMu.Lock()
r.state = map[string]string{}
r.stateMu.Unlock()
return nil
}
return fmt.Errorf("reading extension state: %w", err)
}
if len(data) == 0 {
r.stateMu.Lock()
r.state = map[string]string{}
r.stateMu.Unlock()
return nil
}
var loaded map[string]string
+26 -4
View File
@@ -101,15 +101,37 @@ func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
}
}
func TestRunner_State_LoadMissingFileIsNoop(t *testing.T) {
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
// LoadStateFromFile is documented to "replace the in-memory state store
// with its contents"; for a missing file that means clearing the store.
// This is what makes session-switching safe: a new session that has not
// yet written a sidecar must not inherit keys from a prior session.
r := NewRunner(nil)
r.SetState("a", "1")
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
t.Errorf("expected nil error for missing file, got %v", err)
}
// Existing in-memory state is left alone when file doesn't exist.
if v, ok := r.GetState("a"); !ok || v != "1" {
t.Errorf("expected pre-existing state preserved, got (%q,%v)", v, ok)
if _, ok := r.GetState("a"); ok {
t.Error("expected pre-existing state to be cleared when target file is missing")
}
if keys := r.ListState(); keys != nil {
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
}
}
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "empty.json")
if err := os.WriteFile(path, nil, 0o644); err != nil {
t.Fatal(err)
}
r := NewRunner(nil)
r.SetState("a", "1")
if err := r.LoadStateFromFile(path); err != nil {
t.Errorf("expected nil error for empty file, got %v", err)
}
if _, ok := r.GetState("a"); ok {
t.Error("expected pre-existing state to be cleared when target file is empty")
}
}
+34 -2
View File
@@ -5,6 +5,7 @@ import (
"sync"
"time"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
)
@@ -24,6 +25,15 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// Per-turn aggregator: collects tool/LLM/usage signals between AgentStart
// and AgentEnd so the enriched AgentEndEvent can be populated without
// requiring extensions to maintain parallel bookkeeping.
//
// NOTE: this aggregator assumes a single in-flight turn per *Kit instance,
// which is the current contract — runTurn does not serialize callers and
// the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two
// concurrent Prompt() calls on the same *Kit would clobber the counters.
// All current callers (TUI app layer, CLI runner, SDK examples) serialize
// turns above this layer. If concurrent turns become a supported use case,
// extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per
// turn instead.
turnAgg := &turnAggregator{kit: m}
m.Subscribe(func(e Event) {
switch ev := e.(type) {
@@ -530,8 +540,12 @@ func (a *turnAggregator) consume() turnSnapshot {
// llmUsageMeta returns the current provider, model id, and computed cost for
// the given usage values using the Kit instance's active model. Cost is zero
// when the model is not in the registry (e.g. custom fine-tunes, unknown
// providers) or pricing fields are unset.
// in any of the following cases:
// - the *Kit pointer is nil or has no active model;
// - the model is not in the registry (custom fine-tunes, unknown providers);
// - the model has no pricing fields set;
// - the active credential is an Anthropic OAuth token (matches the
// existing usage_tracker behavior of suppressing cost for OAuth users).
func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) {
if m == nil {
return "", "", 0
@@ -549,6 +563,9 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
if info == nil {
return provider, modelID, 0
}
if isAnthropicOAuth(m, provider) {
return provider, modelID, 0
}
cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000
cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000
if info.Cost.CacheRead != nil {
@@ -560,6 +577,21 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
return provider, modelID, cost
}
// isAnthropicOAuth reports whether the current Anthropic credential resolves
// to a stored OAuth token (in which case the user is not billed per-token).
// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker
// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
func isAnthropicOAuth(m *Kit, provider string) bool {
if m == nil || provider != "anthropic" {
return false
}
_, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key"))
if err != nil {
return false
}
return strings.HasPrefix(source, "stored OAuth")
}
// llmToContextMessages converts a slice of LLM messages to extension
// ContextMessage values, extracting plain text from each message.
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
+10
View File
@@ -108,6 +108,16 @@ func TestLLMUsageMeta_NilKit(t *testing.T) {
}
}
// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any
// provider other than "anthropic" without touching the credential store.
func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) {
for _, provider := range []string{"openai", "google", "openrouter", ""} {
if isAnthropicOAuth(nil, provider) {
t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider)
}
}
}
func TestExtStateSidecarPath(t *testing.T) {
tests := []struct {
name string