mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
fix(extensions): address review feedback on state store and llmUsageMeta
- Serialize SetState/DeleteState saver invocations through a new saverMu so overlapping atomic-rename writes can no longer race on the shared .tmp file and persist an older snapshot after a newer one. - LoadStateFromFile now clears the in-memory store when the sidecar is missing or empty, matching the documented "replace … with its contents" contract. This makes session-switching safe by preventing keys from a prior session leaking into a new one. Tests updated to cover both the missing-file and empty-file cases. - llmUsageMeta now detects Anthropic OAuth credentials and returns Cost=0, matching the comment and the existing usage_tracker behavior for OAuth users. Mirrors the OAuth detection already used in cmd/extension_context.go. - Document the single-in-flight-turn assumption baked into the per-turn aggregator with a clear migration path (per-turn ID) for if concurrent turns ever become a supported use case.
This commit is contained in:
@@ -104,6 +104,7 @@ type Runner struct {
|
||||
configStore *viper.Viper // per-instance config store (nil = global)
|
||||
state map[string]string // session-scoped extension state (last-write-wins)
|
||||
stateMu sync.RWMutex // guards state independently of mu
|
||||
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
|
||||
stateSaver func() // optional persistence hook invoked after each state mutation
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -770,6 +771,11 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
|
||||
// SetState records a key-value pair in the runner's session-scoped extension
|
||||
// state store. The store is in-memory; callers wire SetStateSaver to persist
|
||||
// changes to a sidecar file. Thread-safe.
|
||||
//
|
||||
// When a saver is installed, concurrent SetState/DeleteState invocations are
|
||||
// serialized through saverMu so that overlapping snapshot-and-rename writes
|
||||
// cannot interleave (which would otherwise race on the shared tmp file and
|
||||
// risk persisting an older snapshot after a newer one).
|
||||
func (r *Runner) SetState(key, value string) {
|
||||
r.stateMu.Lock()
|
||||
if r.state == nil {
|
||||
@@ -779,7 +785,9 @@ func (r *Runner) SetState(key, value string) {
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
if saver != nil {
|
||||
r.saverMu.Lock()
|
||||
saver()
|
||||
r.saverMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -793,7 +801,8 @@ func (r *Runner) GetState(key string) (string, bool) {
|
||||
}
|
||||
|
||||
// DeleteState removes a key from the state store. No-op if the key is
|
||||
// missing. Thread-safe.
|
||||
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
|
||||
// SetState for the rationale.
|
||||
func (r *Runner) DeleteState(key string) {
|
||||
r.stateMu.Lock()
|
||||
_, existed := r.state[key]
|
||||
@@ -803,7 +812,9 @@ func (r *Runner) DeleteState(key string) {
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
if existed && saver != nil {
|
||||
r.saverMu.Lock()
|
||||
saver()
|
||||
r.saverMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -847,17 +858,25 @@ func (r *Runner) SnapshotState() map[string]string {
|
||||
|
||||
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
|
||||
// state store with its contents. Missing or empty files are treated as
|
||||
// "no prior state" (not an error). Malformed JSON returns the parse error
|
||||
// without touching the existing store. Thread-safe.
|
||||
// "no prior state": the in-memory store is replaced with an empty map so
|
||||
// callers can safely switch sessions without leaking keys from a prior
|
||||
// session into a new one. Malformed JSON returns the parse error without
|
||||
// touching the existing store. Thread-safe.
|
||||
func (r *Runner) LoadStateFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading extension state: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var loaded map[string]string
|
||||
|
||||
@@ -101,15 +101,37 @@ func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMissingFileIsNoop(t *testing.T) {
|
||||
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
|
||||
// LoadStateFromFile is documented to "replace the in-memory state store
|
||||
// with its contents"; for a missing file that means clearing the store.
|
||||
// This is what makes session-switching safe: a new session that has not
|
||||
// yet written a sidecar must not inherit keys from a prior session.
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
// Existing in-memory state is left alone when file doesn't exist.
|
||||
if v, ok := r.GetState("a"); !ok || v != "1" {
|
||||
t.Errorf("expected pre-existing state preserved, got (%q,%v)", v, ok)
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is missing")
|
||||
}
|
||||
if keys := r.ListState(); keys != nil {
|
||||
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.json")
|
||||
if err := os.WriteFile(path, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(path); err != nil {
|
||||
t.Errorf("expected nil error for empty file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
)
|
||||
@@ -24,6 +25,15 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Per-turn aggregator: collects tool/LLM/usage signals between AgentStart
|
||||
// and AgentEnd so the enriched AgentEndEvent can be populated without
|
||||
// requiring extensions to maintain parallel bookkeeping.
|
||||
//
|
||||
// NOTE: this aggregator assumes a single in-flight turn per *Kit instance,
|
||||
// which is the current contract — runTurn does not serialize callers and
|
||||
// the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two
|
||||
// concurrent Prompt() calls on the same *Kit would clobber the counters.
|
||||
// All current callers (TUI app layer, CLI runner, SDK examples) serialize
|
||||
// turns above this layer. If concurrent turns become a supported use case,
|
||||
// extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per
|
||||
// turn instead.
|
||||
turnAgg := &turnAggregator{kit: m}
|
||||
m.Subscribe(func(e Event) {
|
||||
switch ev := e.(type) {
|
||||
@@ -530,8 +540,12 @@ func (a *turnAggregator) consume() turnSnapshot {
|
||||
|
||||
// llmUsageMeta returns the current provider, model id, and computed cost for
|
||||
// the given usage values using the Kit instance's active model. Cost is zero
|
||||
// when the model is not in the registry (e.g. custom fine-tunes, unknown
|
||||
// providers) or pricing fields are unset.
|
||||
// in any of the following cases:
|
||||
// - the *Kit pointer is nil or has no active model;
|
||||
// - the model is not in the registry (custom fine-tunes, unknown providers);
|
||||
// - the model has no pricing fields set;
|
||||
// - the active credential is an Anthropic OAuth token (matches the
|
||||
// existing usage_tracker behavior of suppressing cost for OAuth users).
|
||||
func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) {
|
||||
if m == nil {
|
||||
return "", "", 0
|
||||
@@ -549,6 +563,9 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
|
||||
if info == nil {
|
||||
return provider, modelID, 0
|
||||
}
|
||||
if isAnthropicOAuth(m, provider) {
|
||||
return provider, modelID, 0
|
||||
}
|
||||
cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000
|
||||
cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000
|
||||
if info.Cost.CacheRead != nil {
|
||||
@@ -560,6 +577,21 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
|
||||
return provider, modelID, cost
|
||||
}
|
||||
|
||||
// isAnthropicOAuth reports whether the current Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token).
|
||||
// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker
|
||||
// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
|
||||
func isAnthropicOAuth(m *Kit, provider string) bool {
|
||||
if m == nil || provider != "anthropic" {
|
||||
return false
|
||||
}
|
||||
_, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(source, "stored OAuth")
|
||||
}
|
||||
|
||||
// llmToContextMessages converts a slice of LLM messages to extension
|
||||
// ContextMessage values, extracting plain text from each message.
|
||||
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
|
||||
|
||||
@@ -108,6 +108,16 @@ func TestLLMUsageMeta_NilKit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any
|
||||
// provider other than "anthropic" without touching the credential store.
|
||||
func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) {
|
||||
for _, provider := range []string{"openai", "google", "openrouter", ""} {
|
||||
if isAnthropicOAuth(nil, provider) {
|
||||
t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtStateSidecarPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
Reference in New Issue
Block a user