mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
* feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53) Three additive primitives to the extension API: - OnLLMUsage event: per-LLM-call token + cost deltas attributed to the specific model/provider used for each round-trip. Derived from the SDK StepFinishEvent in the extension bridge. Enables accurate budget enforcement between calls instead of only at turn boundaries. - ctx.SetState / GetState / DeleteState / ListState: session-scoped, last-write-wins key-value store backed by a sidecar file (<session>.ext-state.json) outside the conversation tree. Reads are O(1), writes don't grow the JSONL, and the store is not duplicated on fork. State is preserved across hot-reloads. - Enriched AgentEndEvent: ToolCallCount, ToolNames, LLMCallCount, token deltas (input/output/cache-read/cache-write), CostDelta, and DurationMs populated by a per-turn aggregator. Existing handlers reading only Response/StopReason are unaffected. Includes unit tests for the state store, LLMUsage registration, enriched AgentEndEvent, turn aggregator, llmUsageMeta, and sidecar path derivation. Adds examples/extensions/usage-budget.go demoing all three primitives together. Documents the additions in README, the docs site (extensions overview, capabilities, examples), and the kit-extensions and kit-sdk skill guides. Fixes #53 * fix(extensions): address review feedback on state store and llmUsageMeta - Serialize SetState/DeleteState saver invocations through a new saverMu so overlapping atomic-rename writes can no longer race on the shared .tmp file and persist an older snapshot after a newer one. - LoadStateFromFile now clears the in-memory store when the sidecar is missing or empty, matching the documented "replace … with its contents" contract. This makes session-switching safe by preventing keys from a prior session leaking into a new one. Tests updated to cover both the missing-file and empty-file cases. 
- llmUsageMeta now detects Anthropic OAuth credentials and returns Cost=0, matching the comment and the existing usage_tracker behavior for OAuth users. Mirrors the OAuth detection already used in cmd/extension_context.go. - Document the single-in-flight-turn assumption baked into the per-turn aggregator with a clear migration path (per-turn ID) in case concurrent turns ever become a supported use case. * fix(extensions): release saverMu on panic in state store Extract a runSaver helper that locks saverMu and defers Unlock before invoking the persistence callback. Without the deferred Unlock, a panic inside the saver (e.g. disk full mid-write) would leave saverMu held forever and deadlock the next SetState/DeleteState. Both SetState and DeleteState now route through the helper. New TestRunner_State_SaverPanicReleasesSaverMu reproduces the deadlock window with a 2s deadline and proves the mutex is released after a panic.
This commit is contained in:
+135
-1
@@ -341,6 +341,13 @@ type Context struct {
|
||||
// The data survives across session restarts and can be retrieved via
|
||||
// GetEntries. Use entryType to namespace your data (e.g. "myext:state").
|
||||
//
|
||||
// AppendEntry is append-only and lives in the conversation tree, which
|
||||
// makes it the right tool for audit logs and event histories. For
|
||||
// last-write-wins snapshot state — "what's the current value of X?" —
|
||||
// prefer SetState / GetState instead. Those primitives store data in a
|
||||
// sidecar file outside the conversation tree, are O(1) to read/write,
|
||||
// and do not bloat branch reads or duplicate on fork.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := json.Marshal(myState)
|
||||
@@ -360,6 +367,45 @@ type Context struct {
|
||||
// }
|
||||
GetEntries func(entryType string) []ExtensionEntry
|
||||
|
||||
// SetState stores a key-value pair in session-scoped, last-write-wins
|
||||
// extension state. Unlike AppendEntry the value is kept in a sidecar
|
||||
// file outside the conversation tree, so:
|
||||
// - reads are O(1) (no branch walk)
|
||||
// - writes don't bloat the session JSONL
|
||||
// - state is not duplicated on fork (branches share the sidecar)
|
||||
// - state is invisible to the LLM
|
||||
//
|
||||
// Use SetState for snapshot state ("current value of X"); use
|
||||
// AppendEntry for audit logs and event histories. Namespace keys with
|
||||
// your extension name to avoid collisions (e.g. "myext:budget-cap").
|
||||
//
|
||||
// State persists for the lifetime of the session. For ephemeral or
|
||||
// in-memory sessions the state lives only in memory.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.SetState("myext:budget-cap", "10.00")
|
||||
SetState func(key string, value string)
|
||||
|
||||
// GetState returns the value previously stored via SetState. The bool
|
||||
// is false when the key was never written. Returns ("", false) when
|
||||
// state is unavailable.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if cap, ok := ctx.GetState("myext:budget-cap"); ok {
|
||||
// fmt.Println("current cap:", cap)
|
||||
// }
|
||||
GetState func(key string) (string, bool)
|
||||
|
||||
// DeleteState removes a key from session-scoped extension state.
|
||||
// No-op when the key is missing.
|
||||
DeleteState func(key string)
|
||||
|
||||
// ListState returns all keys currently stored in session-scoped
|
||||
// extension state, in unspecified order.
|
||||
ListState func() []string
|
||||
|
||||
// SetEditorText sets the text content of the input editor. This can
|
||||
// be used to pre-fill the editor with suggested text (e.g. extracted
|
||||
// questions, handoff prompts). The cursor is moved to the end.
|
||||
@@ -1102,6 +1148,7 @@ type API struct {
|
||||
onError func(func(ErrorEvent, Context))
|
||||
onRetry func(func(RetryEvent, Context))
|
||||
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
|
||||
onLLMUsage func(func(LLMUsageEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -1359,6 +1406,19 @@ func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStep
|
||||
a.onPrepareStep(handler)
|
||||
}
|
||||
|
||||
// OnLLMUsage registers a handler that fires after each LLM provider call
|
||||
// with the token and cost deltas for that single call. Use this for
|
||||
// per-call usage attribution, real-time budget enforcement, and cost
|
||||
// dashboards that need to react between calls within a single agent turn.
|
||||
//
|
||||
// Handlers receive an LLMUsageEvent describing the call's input/output
|
||||
// tokens, cache tokens, computed cost, model, and provider. A single agent
|
||||
// turn typically fires multiple LLMUsageEvents (one per tool-loop
|
||||
// iteration).
|
||||
func (a *API) OnLLMUsage(handler func(LLMUsageEvent, Context)) {
|
||||
a.onLLMUsage(handler)
|
||||
}
|
||||
|
||||
// RegisterToolRenderer registers a custom renderer for a specific tool's
|
||||
// display in the TUI. The renderer controls the header (parameter summary)
|
||||
// and/or body (result display) of the tool's output block. If multiple
|
||||
@@ -2091,10 +2151,47 @@ type AgentStartEvent struct {
|
||||
|
||||
// Type identifies an AgentStartEvent as the AgentStart event type.
func (e AgentStartEvent) Type() EventType {
	return AgentStart
}
|
||||
|
||||
// AgentEndEvent fires when the agent finishes responding. Besides the final
// response and stop reason it carries per-turn aggregates, so observer-style
// extensions do not need to keep parallel bookkeeping in OnToolResult /
// OnStepFinish handlers.
type AgentEndEvent struct {
	// Response is the agent's final response text for the turn.
	Response string
	// StopReason is one of "completed", "cancelled", or "error".
	StopReason string

	// ToolCallCount is the total number of tool invocations observed during
	// this turn, summed across all steps.
	ToolCallCount int

	// ToolNames holds the names of the tools invoked during this turn, in
	// call order. Duplicates are kept, so two bash calls yield
	// ["bash", "bash"].
	ToolNames []string

	// LLMCallCount is the number of LLM round-trips (tool-loop iterations)
	// performed during this turn. Always >= 1 for a successful turn.
	LLMCallCount int

	// InputTokensDelta is the sum of input tokens consumed across every LLM
	// call in this turn, including cache-hit input tokens.
	InputTokensDelta int

	// OutputTokensDelta is the sum of output tokens generated in this turn.
	OutputTokensDelta int

	// CacheReadTokensDelta is the sum of cache-read tokens in this turn.
	CacheReadTokensDelta int

	// CacheWriteTokensDelta is the sum of cache-write tokens in this turn.
	CacheWriteTokensDelta int

	// CostDelta is the total cost in USD attributable to this turn, computed
	// from per-step usage and current model pricing. It is zero when pricing
	// is unknown or OAuth credentials are in use.
	CostDelta float64

	// DurationMs is the elapsed wall-clock time from AgentStart to AgentEnd,
	// in milliseconds.
	DurationMs int64
}
|
||||
|
||||
// Type identifies an AgentEndEvent as the AgentEnd event type.
func (e AgentEndEvent) Type() EventType {
	return AgentEnd
}
|
||||
@@ -2403,6 +2500,43 @@ type PrepareStepResult struct {
|
||||
|
||||
// isResult is an empty marker method on PrepareStepResult.
// NOTE(review): presumably this is what lets PrepareStepResult be returned
// where a Result is expected — confirm against the Result interface.
func (PrepareStepResult) isResult() {}
|
||||
|
||||
// LLMUsageEvent is emitted after every LLM provider call and carries the
// token and cost deltas for that single call. It enables accurate budget
// tracking, cost dashboards, and any logic that has to react between LLM
// calls inside one agent turn instead of only at turn boundaries.
//
// One agent turn normally yields several LLMUsageEvents (one per tool-loop
// iteration). Model and Provider describe the model used for this specific
// call, which can differ from earlier calls when the extension switched
// models mid-turn via ctx.SetModel().
type LLMUsageEvent struct {
	// InputTokens counts the input tokens of this call.
	InputTokens int
	// OutputTokens counts the output tokens generated by this call.
	OutputTokens int
	// CacheReadTokens counts cache-hit input tokens (provider-specific).
	CacheReadTokens int
	// CacheWriteTokens counts cache-write tokens.
	CacheWriteTokens int
	// Cost is this call's cost in USD, derived from the model's per-token
	// pricing. It is zero when pricing is unknown or OAuth credentials are
	// in use.
	Cost float64
	// Model is the identifier of the model used for this call
	// (e.g. "claude-sonnet-4-5-20250929").
	Model string
	// Provider is the provider identifier (e.g. "anthropic", "openai").
	Provider string
	// RequestID is an optional correlation id for the underlying provider
	// call. It may be empty when the provider does not surface one.
	RequestID string
	// StepNumber is the zero-based step index within the current agent turn.
	StepNumber int
	// FinishReason mirrors the provider's finish reason for this call
	// (e.g. "stop", "tool_calls", "length"). It may be empty.
	FinishReason string
}
|
||||
|
||||
// Type identifies an LLMUsageEvent as the LLMUsage event type.
func (e LLMUsageEvent) Type() EventType {
	return LLMUsage
}
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -125,6 +125,11 @@ const (
|
||||
// after steering messages are injected and before messages are sent
|
||||
// to the LLM. Handlers can replace the context window for this step.
|
||||
PrepareStep EventType = "prepare_step"
|
||||
|
||||
// LLMUsage fires after each LLM provider call with the token and cost
|
||||
// deltas for that single call. Extensions use it to attribute usage to
|
||||
// specific calls/models and to drive budget enforcement between calls.
|
||||
LLMUsage EventType = "llm_usage"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -139,7 +144,7 @@ func AllEventTypes() []EventType {
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
|
||||
PrepareStep,
|
||||
PrepareStep, LLMUsage,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 32 {
|
||||
t.Fatalf("expected 32 event types, got %d", len(all))
|
||||
if len(all) != 33 {
|
||||
t.Fatalf("expected 33 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package extensions
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRunner_EmitLLMUsage(t *testing.T) {
|
||||
var got LLMUsageEvent
|
||||
var called bool
|
||||
ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{
|
||||
LLMUsage: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(LLMUsageEvent)
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(LLMUsageEvent{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
Cost: 0.0012,
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
Provider: "anthropic",
|
||||
StepNumber: 2,
|
||||
FinishReason: "tool_calls",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected LLMUsage handler to be called")
|
||||
}
|
||||
if got.InputTokens != 100 || got.OutputTokens != 50 {
|
||||
t.Errorf("token fields not propagated: %+v", got)
|
||||
}
|
||||
if got.Cost != 0.0012 {
|
||||
t.Errorf("cost not propagated, got %v", got.Cost)
|
||||
}
|
||||
if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" {
|
||||
t.Errorf("model/provider not propagated: %+v", got)
|
||||
}
|
||||
if got.StepNumber != 2 || got.FinishReason != "tool_calls" {
|
||||
t.Errorf("step/finish reason not propagated: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) {
|
||||
// Verify NewTestAPI wires up onLLMUsage so the extension can call
|
||||
// api.OnLLMUsage during Init.
|
||||
ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)}
|
||||
api := NewTestAPI(ext)
|
||||
|
||||
var calls int
|
||||
api.OnLLMUsage(func(e LLMUsageEvent, c Context) {
|
||||
calls++
|
||||
})
|
||||
|
||||
if len(ext.Handlers[LLMUsage]) != 1 {
|
||||
t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage]))
|
||||
}
|
||||
|
||||
r := makeRunner(*ext)
|
||||
_, _ = r.Emit(LLMUsageEvent{InputTokens: 1})
|
||||
if calls != 1 {
|
||||
t.Errorf("expected handler called once, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentEndEvent_EnrichedFields(t *testing.T) {
|
||||
// Verify the enriched event carries through Emit without mangling.
|
||||
var got AgentEndEvent
|
||||
ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{
|
||||
AgentEnd: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(AgentEndEvent)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(AgentEndEvent{
|
||||
Response: "done",
|
||||
StopReason: "completed",
|
||||
ToolCallCount: 3,
|
||||
ToolNames: []string{"bash", "read", "bash"},
|
||||
LLMCallCount: 4,
|
||||
InputTokensDelta: 1500,
|
||||
OutputTokensDelta: 400,
|
||||
CacheReadTokensDelta: 200,
|
||||
CacheWriteTokensDelta: 100,
|
||||
CostDelta: 0.0123,
|
||||
DurationMs: 2500,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if got.ToolCallCount != 3 {
|
||||
t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount)
|
||||
}
|
||||
if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" {
|
||||
t.Errorf("ToolNames: %v", got.ToolNames)
|
||||
}
|
||||
if got.LLMCallCount != 4 {
|
||||
t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount)
|
||||
}
|
||||
if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 {
|
||||
t.Errorf("token deltas: %+v", got)
|
||||
}
|
||||
if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 {
|
||||
t.Errorf("cache deltas: %+v", got)
|
||||
}
|
||||
if got.CostDelta != 0.0123 {
|
||||
t.Errorf("CostDelta: got %v", got.CostDelta)
|
||||
}
|
||||
if got.DurationMs != 2500 {
|
||||
t.Errorf("DurationMs: got %d", got.DurationMs)
|
||||
}
|
||||
}
|
||||
@@ -669,6 +669,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -2,9 +2,12 @@ package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -99,6 +102,10 @@ type Runner struct {
|
||||
customEventSubs map[string][]func(string) // inter-extension event bus
|
||||
optionOverrides map[string]string // runtime option overrides
|
||||
configStore *viper.Viper // per-instance config store (nil = global)
|
||||
state map[string]string // session-scoped extension state (last-write-wins)
|
||||
stateMu sync.RWMutex // guards state independently of mu
|
||||
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
|
||||
stateSaver func() // optional persistence hook invoked after each state mutation
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -264,6 +271,18 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.SetState == nil {
|
||||
ctx.SetState = func(string, string) {}
|
||||
}
|
||||
if ctx.GetState == nil {
|
||||
ctx.GetState = func(string) (string, bool) { return "", false }
|
||||
}
|
||||
if ctx.DeleteState == nil {
|
||||
ctx.DeleteState = func(string) {}
|
||||
}
|
||||
if ctx.ListState == nil {
|
||||
ctx.ListState = func() []string { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
@@ -745,6 +764,168 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Extension state store (session-scoped, last-write-wins)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SetState records a key-value pair in the runner's session-scoped extension
|
||||
// state store. The store is in-memory; callers wire SetStateSaver to persist
|
||||
// changes to a sidecar file. Thread-safe.
|
||||
//
|
||||
// When a saver is installed, concurrent SetState/DeleteState invocations are
|
||||
// serialized through saverMu so that overlapping snapshot-and-rename writes
|
||||
// cannot interleave (which would otherwise race on the shared tmp file and
|
||||
// risk persisting an older snapshot after a newer one).
|
||||
func (r *Runner) SetState(key, value string) {
|
||||
r.stateMu.Lock()
|
||||
if r.state == nil {
|
||||
r.state = make(map[string]string)
|
||||
}
|
||||
r.state[key] = value
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// GetState returns the value previously stored via SetState, plus a bool
|
||||
// indicating whether the key was present. Thread-safe.
|
||||
func (r *Runner) GetState(key string) (string, bool) {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
v, ok := r.state[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// DeleteState removes a key from the state store. No-op if the key is
|
||||
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
|
||||
// SetState for the rationale.
|
||||
func (r *Runner) DeleteState(key string) {
|
||||
r.stateMu.Lock()
|
||||
_, existed := r.state[key]
|
||||
if existed {
|
||||
delete(r.state, key)
|
||||
}
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
if !existed {
|
||||
return
|
||||
}
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// runSaver invokes the optional persistence callback under saverMu so
|
||||
// concurrent SetState/DeleteState writers cannot race on the shared tmp
|
||||
// file used by SaveStateToFile's atomic rename. The deferred Unlock
|
||||
// guarantees saverMu is released even if the saver panics.
|
||||
func (r *Runner) runSaver(saver func()) {
|
||||
if saver == nil {
|
||||
return
|
||||
}
|
||||
r.saverMu.Lock()
|
||||
defer r.saverMu.Unlock()
|
||||
saver()
|
||||
}
|
||||
|
||||
// ListState returns all keys currently in the state store, in unspecified
|
||||
// order. Thread-safe.
|
||||
func (r *Runner) ListState() []string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(r.state))
|
||||
for k := range r.state {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// SetStateSaver installs an optional persistence hook invoked after each
|
||||
// mutation to the state store (SetState / DeleteState / LoadStateFromFile).
|
||||
// Pass nil to disable persistence. Thread-safe.
|
||||
func (r *Runner) SetStateSaver(saver func()) {
|
||||
r.stateMu.Lock()
|
||||
defer r.stateMu.Unlock()
|
||||
r.stateSaver = saver
|
||||
}
|
||||
|
||||
// SnapshotState returns a copy of the current state store as a
|
||||
// fresh map. Useful for persisting to disk without holding the lock.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SnapshotState() map[string]string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
copyMap := make(map[string]string, len(r.state))
|
||||
maps.Copy(copyMap, r.state)
|
||||
return copyMap
|
||||
}
|
||||
|
||||
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
|
||||
// state store with its contents. Missing or empty files are treated as
|
||||
// "no prior state": the in-memory store is replaced with an empty map so
|
||||
// callers can safely switch sessions without leaking keys from a prior
|
||||
// session into a new one. Malformed JSON returns the parse error without
|
||||
// touching the existing store. Thread-safe.
|
||||
func (r *Runner) LoadStateFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading extension state: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var loaded map[string]string
|
||||
if err := json.Unmarshal(data, &loaded); err != nil {
|
||||
return fmt.Errorf("parsing extension state: %w", err)
|
||||
}
|
||||
r.stateMu.Lock()
|
||||
r.state = loaded
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveStateToFile writes the current state store to path as JSON, creating
|
||||
// parent directories as needed. An empty store writes an empty object so
|
||||
// that consumers can distinguish "loaded but empty" from "never saved".
|
||||
// Writes are atomic via a tmp-file-and-rename sequence. Thread-safe.
|
||||
func (r *Runner) SaveStateToFile(path string) error {
|
||||
snap := r.SnapshotState()
|
||||
if snap == nil {
|
||||
snap = map[string]string{}
|
||||
}
|
||||
data, err := json.MarshalIndent(snap, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling extension state: %w", err)
|
||||
}
|
||||
if dir := filepath.Dir(path); dir != "." && dir != "" {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("creating state directory: %w", err)
|
||||
}
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return fmt.Errorf("writing extension state: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("renaming extension state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hot-reload
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -768,7 +949,9 @@ func (r *Runner) Reload(exts []LoadedExtension) {
|
||||
r.uiVisibility = nil
|
||||
r.disabledTools = nil
|
||||
r.customEventSubs = nil
|
||||
// optionOverrides are intentionally preserved.
|
||||
// optionOverrides and state are intentionally preserved across reloads:
|
||||
// they represent user/session intent (not extension code) and would be
|
||||
// surprising to lose on a hot-reload.
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunner_State_BasicSetGetDelete(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
|
||||
if _, ok := r.GetState("missing"); ok {
|
||||
t.Fatal("expected GetState to return ok=false for missing key")
|
||||
}
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("b", "2")
|
||||
r.SetState("a", "3") // last-write-wins
|
||||
|
||||
if v, ok := r.GetState("a"); !ok || v != "3" {
|
||||
t.Errorf("expected GetState(a)=(3,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r.GetState("b"); !ok || v != "2" {
|
||||
t.Errorf("expected GetState(b)=(2,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
|
||||
keys := r.ListState()
|
||||
if len(keys) != 2 {
|
||||
t.Errorf("expected 2 keys, got %d (%v)", len(keys), keys)
|
||||
}
|
||||
|
||||
r.DeleteState("a")
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected key a to be gone after DeleteState")
|
||||
}
|
||||
if len(r.ListState()) != 1 {
|
||||
t.Errorf("expected 1 key after delete, got %v", r.ListState())
|
||||
}
|
||||
|
||||
// Deleting missing key is a no-op.
|
||||
r.DeleteState("never-there")
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverFires(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
r.SetStateSaver(func() {
|
||||
mu.Lock()
|
||||
calls++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("a", "2")
|
||||
r.DeleteState("a")
|
||||
r.DeleteState("a") // missing → no save
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if calls != 3 {
|
||||
t.Errorf("expected saver to fire 3 times (2 sets + 1 delete), got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r1 := NewRunner(nil)
|
||||
r1.SetState("k1", "v1")
|
||||
r1.SetState("k2", `{"json":"value"}`)
|
||||
if err := r1.SaveStateToFile(path); err != nil {
|
||||
t.Fatalf("SaveStateToFile: %v", err)
|
||||
}
|
||||
|
||||
// Verify file contains JSON map.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["k1"] != "v1" || parsed["k2"] != `{"json":"value"}` {
|
||||
t.Errorf("unexpected file contents: %v", parsed)
|
||||
}
|
||||
|
||||
r2 := NewRunner(nil)
|
||||
if err := r2.LoadStateFromFile(path); err != nil {
|
||||
t.Fatalf("LoadStateFromFile: %v", err)
|
||||
}
|
||||
if v, ok := r2.GetState("k1"); !ok || v != "v1" {
|
||||
t.Errorf("expected k1=v1 after load, got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r2.GetState("k2"); !ok || v != `{"json":"value"}` {
|
||||
t.Errorf("expected k2 to round-trip, got %q", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
|
||||
// LoadStateFromFile is documented to "replace the in-memory state store
|
||||
// with its contents"; for a missing file that means clearing the store.
|
||||
// This is what makes session-switching safe: a new session that has not
|
||||
// yet written a sidecar must not inherit keys from a prior session.
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is missing")
|
||||
}
|
||||
if keys := r.ListState(); keys != nil {
|
||||
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.json")
|
||||
if err := os.WriteFile(path, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(path); err != nil {
|
||||
t.Errorf("expected nil error for empty file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMalformedFileError(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad.json")
|
||||
if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
if err := r.LoadStateFromFile(path); err == nil {
|
||||
t.Error("expected error loading malformed JSON, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_PersistenceViaSaver(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r := NewRunner(nil)
|
||||
r.SetStateSaver(func() {
|
||||
_ = r.SaveStateToFile(path)
|
||||
})
|
||||
r.SetState("hello", "world")
|
||||
|
||||
// File should exist with the value already.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["hello"] != "world" {
|
||||
t.Errorf("expected file to contain hello=world, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ConcurrentSet(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 16
|
||||
const iterations = 100
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
r.SetState("k", "v")
|
||||
_, _ = r.GetState("k")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if v, ok := r.GetState("k"); !ok || v != "v" {
|
||||
t.Errorf("expected k=v after concurrent writes, got (%q,%v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ContextNoOpsWhenUnset(t *testing.T) {
|
||||
// Verify normalizeContext installs safe no-ops for SetState/GetState/etc.
|
||||
// when not provided by the caller.
|
||||
ext := makeHandlerExt("state.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// All four state functions should be non-nil and safe to call.
|
||||
c.SetState("a", "b")
|
||||
if v, ok := c.GetState("a"); ok || v != "" {
|
||||
t.Errorf("no-op GetState should return (\"\", false); got (%q,%v)", v, ok)
|
||||
}
|
||||
c.DeleteState("a")
|
||||
if keys := c.ListState(); keys != nil {
|
||||
t.Errorf("no-op ListState should return nil; got %v", keys)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
// SetContext with empty Context to exercise normalizeContext defaults.
|
||||
r.SetContext(Context{})
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverPanicReleasesSaverMu(t *testing.T) {
|
||||
// If the saver callback panics (e.g. disk full mid-write), runSaver
|
||||
// must still release saverMu so subsequent SetState/DeleteState calls
|
||||
// can make progress. Without `defer Unlock()` the lock would be
|
||||
// permanently held and the next write would deadlock.
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
r.SetStateSaver(func() {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
panic("simulated disk-write failure")
|
||||
}
|
||||
})
|
||||
|
||||
// First call panics. Recover, then verify a follow-up call still works
|
||||
// without blocking (proving saverMu was released).
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec == nil {
|
||||
t.Fatal("expected panic from first saver invocation")
|
||||
}
|
||||
}()
|
||||
r.SetState("a", "1")
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.SetState("b", "2") // would deadlock if saverMu were still held
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("SetState after saver panic blocked — saverMu was not released")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Errorf("expected saver to fire twice (panic + recovery write), got %d", calls)
|
||||
}
|
||||
}
|
||||
@@ -183,6 +183,7 @@ func Symbols() interp.Exports {
|
||||
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
|
||||
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
|
||||
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
|
||||
"LLMUsageEvent": reflect.ValueOf((*LLMUsageEvent)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,5 +189,11 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user