From 49f8b485be3733382a660e575c2f654479cb631a Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Tue, 9 Jun 2026 16:18:10 +0300 Subject: [PATCH] feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53) (#54) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53) Three additive primitives to the extension API: - OnLLMUsage event: per-LLM-call token + cost deltas attributed to the specific model/provider used for each round-trip. Derived from the SDK StepFinishEvent in the extension bridge. Enables accurate budget enforcement between calls instead of only at turn boundaries. - ctx.SetState / GetState / DeleteState / ListState: session-scoped, last-write-wins key-value store backed by a sidecar file (.ext-state.json) outside the conversation tree. Reads are O(1), writes don't grow the JSONL, and the store is not duplicated on fork. State is preserved across hot-reloads. - Enriched AgentEndEvent: ToolCallCount, ToolNames, LLMCallCount, token deltas (input/output/cache-read/cache-write), CostDelta, and DurationMs populated by a per-turn aggregator. Existing handlers reading only Response/StopReason are unaffected. Includes unit tests for the state store, LLMUsage registration, enriched AgentEndEvent, turn aggregator, llmUsageMeta, and sidecar path derivation. Adds examples/extensions/usage-budget.go demoing all three primitives together. Documents the additions in README, the docs site (extensions overview, capabilities, examples), and the kit-extensions and kit-sdk skill guides. Fixes #53 * fix(extensions): address review feedback on state store and llmUsageMeta - Serialize SetState/DeleteState saver invocations through a new saverMu so overlapping atomic-rename writes can no longer race on the shared .tmp file and persist an older snapshot after a newer one. - LoadStateFromFile now clears the in-memory store when the sidecar is missing or empty, matching the documented "replace … with its contents" contract. This makes session-switching safe by preventing keys from a prior session leaking into a new one. Tests updated to cover both the missing-file and empty-file cases. - llmUsageMeta now detects Anthropic OAuth credentials and returns Cost=0, matching the comment and the existing usage_tracker behavior for OAuth users. Mirrors the OAuth detection already used in cmd/extension_context.go. - Document the single-in-flight-turn assumption baked into the per-turn aggregator with a clear migration path (per-turn ID) for if concurrent turns ever become a supported use case. * fix(extensions): release saverMu on panic in state store Extract a runSaver helper that locks saverMu and defers Unlock before invoking the persistence callback. Without the deferred Unlock, a panic inside the saver (e.g. disk full mid-write) would leave saverMu held forever and deadlock the next SetState/DeleteState. Both SetState and DeleteState now route through the helper. New TestRunner_State_Saver PanicReleasesSaverMu reproduces the deadlock window with a 2s deadline and proves the mutex is released after a panic. --- README.md | 6 +- cmd/extension_context.go | 12 ++ cmd/root.go | 3 + examples/extensions/README.md | 1 + examples/extensions/usage-budget.go | 87 +++++++++ internal/extensions/api.go | 136 +++++++++++++- internal/extensions/events.go | 7 +- internal/extensions/events_test.go | 4 +- internal/extensions/llmusage_test.go | 119 ++++++++++++ internal/extensions/loader.go | 6 + internal/extensions/runner.go | 185 ++++++++++++++++++- internal/extensions/state_test.go | 262 +++++++++++++++++++++++++++ internal/extensions/symbols.go | 1 + internal/extensions/test_api.go | 6 + pkg/kit/extension_api.go | 80 ++++++++ pkg/kit/extensions_bridge.go | 233 +++++++++++++++++++++++- pkg/kit/extensions_bridge_test.go | 140 ++++++++++++++ skills/kit-extensions/SKILL.md | 59 +++++- skills/kit-sdk/SKILL.md | 13 ++ www/pages/extensions/capabilities.md | 80 +++++++- www/pages/extensions/examples.md | 1 + www/pages/extensions/overview.md | 3 +- 22 files changed, 1429 insertions(+), 15 deletions(-) create mode 100644 examples/extensions/usage-budget.go create mode 100644 internal/extensions/llmusage_test.go create mode 100644 internal/extensions/state_test.go create mode 100644 pkg/kit/extensions_bridge_test.go diff --git a/README.md b/README.md index bbefbb8f..9f8f354a 100644 --- a/README.md +++ b/README.md @@ -312,12 +312,15 @@ kit -e examples/extensions/minimal.go ### Extension Capabilities -**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd +**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnLLMUsage, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd + +`OnAgentEnd` carries per-turn aggregates (`ToolCallCount`, `ToolNames`, `LLMCallCount`, `InputTokensDelta`, `OutputTokensDelta`, `CostDelta`, `DurationMs`) so observers don't need to maintain parallel bookkeeping. `OnLLMUsage` fires after each LLM provider call with token + cost deltas attributed to that specific call/model — use it for accurate budget enforcement *between* calls instead of waiting for the turn to finish. **Custom Components**: - **Tools**: Add new tools the LLM can invoke - **Commands**: Register slash commands (e.g., `/mycommand`) - **Options**: Register configurable extension options +- **Session State**: Last-write-wins key-value store via `ctx.SetState` / `GetState` / `DeleteState` / `ListState`, persisted to a per-session sidecar file outside the conversation tree - **Widgets**: Persistent status displays above/below input - **Headers/Footers**: Persistent content above/below the conversation - **Status Bar**: Custom status bar entries @@ -373,6 +376,7 @@ See the `examples/extensions/` directory: - [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls - [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching - [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering +- [`usage-budget.go`](examples/extensions/usage-budget.go) - Per-call usage callback (`OnLLMUsage`), session state, and enriched `OnAgentEnd` per-turn report - [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits. diff --git a/cmd/extension_context.go b/cmd/extension_context.go index 8fd75268..e38002bf 100644 --- a/cmd/extension_context.go +++ b/cmd/extension_context.go @@ -190,6 +190,18 @@ func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Cont GetEntries: func(entryType string) []extensions.ExtensionEntry { return kitInstance.Extensions().GetEntries(entryType) }, + SetState: func(key string, value string) { + kitInstance.Extensions().SetState(key, value) + }, + GetState: func(key string) (string, bool) { + return kitInstance.Extensions().GetState(key) + }, + DeleteState: func(key string) { + kitInstance.Extensions().DeleteState(key) + }, + ListState: func() []string { + return kitInstance.Extensions().ListState() + }, SetEditorText: func(text string) { appInstance.SetEditorTextFromExtension(text) }, diff --git a/cmd/root.go b/cmd/root.go index 7e4bffb9..99d81c63 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -931,6 +931,9 @@ func runNormalMode(ctx context.Context) error { startupExtensionMessages = append(startupExtensionMessages, text) } kitInstance.Extensions().SetContext(extCtx) + if err := kitInstance.Extensions().InitStatePersistence(); err != nil { + log.Printf("WARN extension state init failed: %v", err) + } kitInstance.Extensions().EmitSessionStart() // Restore normal print functions for runtime use. diff --git a/examples/extensions/README.md b/examples/extensions/README.md index c1e60f05..40465572 100644 --- a/examples/extensions/README.md +++ b/examples/extensions/README.md @@ -58,6 +58,7 @@ kit install github.com/mark3labs/kit/examples/extensions --local | `project-rules.go` | Project-specific rules | Session data, file reading | | `protected-paths.go` | Block dangerous operations | `OnToolCall` with blocking | | `permission-gate.go` | Confirm destructive actions | `OnToolCall` with confirmation | +| `usage-budget.go` | Soft cost cap + per-turn report | `OnLLMUsage`, `SetState`/`GetState`, enriched `AgentEndEvent` | ### Tools & Commands diff --git a/examples/extensions/usage-budget.go b/examples/extensions/usage-budget.go new file mode 100644 index 00000000..56b2b6ab --- /dev/null +++ b/examples/extensions/usage-budget.go @@ -0,0 +1,87 @@ +//go:build ignore + +package main + +import ( + "fmt" + "strconv" + + "kit/ext" +) + +// Init demonstrates the three primitives added in issue #53: +// +// 1. api.OnLLMUsage(...) — per-LLM-call usage callback with token + cost +// deltas. Use this for budget enforcement that reacts between calls +// within a single agent turn, rather than only at turn boundaries. +// +// 2. ctx.SetState / ctx.GetState / ctx.DeleteState / ctx.ListState — +// last-write-wins, session-scoped key-value store backed by a sidecar +// file. Use this for snapshot state (current value of X) instead of +// ctx.AppendEntry, which is append-only and bloats branch reads. +// +// 3. ext.AgentEndEvent.ToolCallCount / .ToolNames / .LLMCallCount / +// .InputTokensDelta / .OutputTokensDelta / .CostDelta / .DurationMs — +// per-turn aggregates so observer extensions don't need to maintain +// parallel bookkeeping. +// +// Together these support a simple soft-budget cap: warn when the +// cumulative cost in this session exceeds a threshold, and print a +// per-turn report on AgentEnd. +// +// Usage: kit -e examples/extensions/usage-budget.go +func Init(api ext.API) { + const warnAtKey = "usage-budget:warn-at-usd" + + // 1. Print per-LLM-call usage with provider, model, and cost. + api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) { + ctx.Print(fmt.Sprintf( + "[usage] step=%d %s/%s tokens=↑%d ↓%d cache=↑%d/↓%d cost=$%.4f (%s)", + e.StepNumber, e.Provider, e.Model, + e.InputTokens, e.OutputTokens, + e.CacheWriteTokens, e.CacheReadTokens, + e.Cost, e.FinishReason, + )) + + // 2. Persist running total in last-write-wins state. + current := 0.0 + if raw, ok := ctx.GetState("usage-budget:total-cost"); ok { + current, _ = strconv.ParseFloat(raw, 64) + } + current += e.Cost + ctx.SetState("usage-budget:total-cost", strconv.FormatFloat(current, 'f', 6, 64)) + + // Soft warn-at threshold (configurable via state). + warnAt := 0.50 + if raw, ok := ctx.GetState(warnAtKey); ok { + if v, err := strconv.ParseFloat(raw, 64); err == nil { + warnAt = v + } + } + if current > warnAt { + ctx.PrintError(fmt.Sprintf( + "[usage] session cost $%.4f exceeds soft cap $%.2f", + current, warnAt, + )) + } + }) + + // 3. Print a per-turn summary using the enriched AgentEndEvent. + api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) { + ctx.Print(fmt.Sprintf( + "[turn] stop=%s tools=%d llm-calls=%d tokens=↑%d ↓%d cost=$%.4f duration=%dms", + e.StopReason, e.ToolCallCount, e.LLMCallCount, + e.InputTokensDelta, e.OutputTokensDelta, e.CostDelta, e.DurationMs, + )) + if len(e.ToolNames) > 0 { + ctx.Print(fmt.Sprintf("[turn] tool order: %v", e.ToolNames)) + } + }) + + // Bootstrap default soft cap once per session. + api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) { + if _, ok := ctx.GetState(warnAtKey); !ok { + ctx.SetState(warnAtKey, "0.50") + } + }) +} diff --git a/internal/extensions/api.go b/internal/extensions/api.go index 16d3e377..ac4b9132 100644 --- a/internal/extensions/api.go +++ b/internal/extensions/api.go @@ -341,6 +341,13 @@ type Context struct { // The data survives across session restarts and can be retrieved via // GetEntries. Use entryType to namespace your data (e.g. "myext:state"). // + // AppendEntry is append-only and lives in the conversation tree, which + // makes it the right tool for audit logs and event histories. For + // last-write-wins snapshot state — "what's the current value of X?" — + // prefer SetState / GetState instead. Those primitives store data in a + // sidecar file outside the conversation tree, are O(1) to read/write, + // and do not bloat branch reads or duplicate on fork. + // // Example: // // data, _ := json.Marshal(myState) @@ -360,6 +367,45 @@ type Context struct { // } GetEntries func(entryType string) []ExtensionEntry + // SetState stores a key-value pair in session-scoped, last-write-wins + // extension state. Unlike AppendEntry the value is kept in a sidecar + // file outside the conversation tree, so: + // - reads are O(1) (no branch walk) + // - writes don't bloat the session JSONL + // - state is not duplicated on fork (branches share the sidecar) + // - state is invisible to the LLM + // + // Use SetState for snapshot state ("current value of X"); use + // AppendEntry for audit logs and event histories. Namespace keys with + // your extension name to avoid collisions (e.g. "myext:budget-cap"). + // + // State persists for the lifetime of the session. For ephemeral or + // in-memory sessions the state lives only in memory. + // + // Example: + // + // ctx.SetState("myext:budget-cap", "10.00") + SetState func(key string, value string) + + // GetState returns the value previously stored via SetState. The bool + // is false when the key was never written. Returns ("", false) when + // state is unavailable. + // + // Example: + // + // if cap, ok := ctx.GetState("myext:budget-cap"); ok { + // fmt.Println("current cap:", cap) + // } + GetState func(key string) (string, bool) + + // DeleteState removes a key from session-scoped extension state. + // No-op when the key is missing. + DeleteState func(key string) + + // ListState returns all keys currently stored in session-scoped + // extension state, in unspecified order. + ListState func() []string + // SetEditorText sets the text content of the input editor. This can // be used to pre-fill the editor with suggested text (e.g. extracted // questions, handoff prompts). The cursor is moved to the end. @@ -1102,6 +1148,7 @@ type API struct { onError func(func(ErrorEvent, Context)) onRetry func(func(RetryEvent, Context)) onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult) + onLLMUsage func(func(LLMUsageEvent, Context)) } // OnToolCall registers a handler that fires before a tool executes. @@ -1359,6 +1406,19 @@ func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStep a.onPrepareStep(handler) } +// OnLLMUsage registers a handler that fires after each LLM provider call +// with the token and cost deltas for that single call. Use this for +// per-call usage attribution, real-time budget enforcement, and cost +// dashboards that need to react between calls within a single agent turn. +// +// Handlers receive an LLMUsageEvent describing the call's input/output +// tokens, cache tokens, computed cost, model, and provider. A single agent +// turn typically fires multiple LLMUsageEvents (one per tool-loop +// iteration). +func (a *API) OnLLMUsage(handler func(LLMUsageEvent, Context)) { + a.onLLMUsage(handler) +} + // RegisterToolRenderer registers a custom renderer for a specific tool's // display in the TUI. The renderer controls the header (parameter summary) // and/or body (result display) of the tool's output block. If multiple @@ -2091,10 +2151,47 @@ type AgentStartEvent struct { func (e AgentStartEvent) Type() EventType { return AgentStart } -// AgentEndEvent fires when the agent finishes responding. +// AgentEndEvent fires when the agent finishes responding. In addition to the +// final response and stop reason, the event carries per-turn aggregates so +// observer-style extensions don't have to maintain parallel bookkeeping in +// OnToolResult / OnStepFinish handlers. type AgentEndEvent struct { Response string StopReason string // "completed", "cancelled", "error" + + // ToolCallCount is the total number of tool invocations observed during + // this turn (sum across all steps). + ToolCallCount int + + // ToolNames lists the tool names invoked during this turn, in call order. + // Duplicates are preserved (e.g. two bash calls produce ["bash", "bash"]). + ToolNames []string + + // LLMCallCount is the number of LLM round-trips (tool-loop iterations) + // performed during this turn. Always >= 1 for a successful turn. + LLMCallCount int + + // InputTokensDelta is the sum of input tokens consumed during this turn + // across every LLM call (including cache-hit input tokens). + InputTokensDelta int + + // OutputTokensDelta is the sum of output tokens generated during this turn. + OutputTokensDelta int + + // CacheReadTokensDelta is the sum of cache-read tokens during this turn. + CacheReadTokensDelta int + + // CacheWriteTokensDelta is the sum of cache-write tokens during this turn. + CacheWriteTokensDelta int + + // CostDelta is the total cost in USD attributable to this turn. Computed + // from per-step usage and current model pricing. Zero when pricing is + // unknown or OAuth credentials are in use. + CostDelta float64 + + // DurationMs is the elapsed wall-clock time from AgentStart to AgentEnd, + // in milliseconds. + DurationMs int64 } func (e AgentEndEvent) Type() EventType { return AgentEnd } @@ -2403,6 +2500,43 @@ type PrepareStepResult struct { func (PrepareStepResult) isResult() {} +// LLMUsageEvent fires after each LLM provider call with the per-call token +// and cost deltas. Use this for accurate budget tracking, cost dashboards, +// and any logic that needs to react between LLM calls within a single agent +// turn (rather than only at turn boundaries). +// +// A single agent turn typically produces multiple LLMUsageEvents (one per +// tool-loop iteration). The Model and Provider fields reflect the model used +// for that specific call, which may differ from earlier calls if the +// extension switched models mid-turn via ctx.SetModel(). +type LLMUsageEvent struct { + // InputTokens is the number of input tokens for this call. + InputTokens int + // OutputTokens is the number of output tokens generated by this call. + OutputTokens int + // CacheReadTokens is the number of cache-hit input tokens (provider-specific). + CacheReadTokens int + // CacheWriteTokens is the number of cache-write tokens. + CacheWriteTokens int + // Cost is the USD cost of this call computed from the model's per-token + // pricing. Zero when pricing is unknown or OAuth credentials are in use. + Cost float64 + // Model is the model identifier used for this call (e.g. "claude-sonnet-4-5-20250929"). + Model string + // Provider is the provider identifier (e.g. "anthropic", "openai"). + Provider string + // RequestID is an optional correlation id for the underlying provider + // call. May be empty when the provider does not surface one. + RequestID string + // StepNumber is the zero-based step index within the current agent turn. + StepNumber int + // FinishReason mirrors the provider's finish reason for this call + // (e.g. "stop", "tool_calls", "length"). May be empty. + FinishReason string +} + +func (e LLMUsageEvent) Type() EventType { return LLMUsage } + // ThemeColor is an adaptive color pair with light and dark hex values. // Either field may be empty to inherit from the default theme. type ThemeColor struct { diff --git a/internal/extensions/events.go b/internal/extensions/events.go index c88a26a7..ac191051 100644 --- a/internal/extensions/events.go +++ b/internal/extensions/events.go @@ -125,6 +125,11 @@ const ( // after steering messages are injected and before messages are sent // to the LLM. Handlers can replace the context window for this step. PrepareStep EventType = "prepare_step" + + // LLMUsage fires after each LLM provider call with the token and cost + // deltas for that single call. Extensions use it to attribute usage to + // specific calls/models and to drive budget enforcement between calls. + LLMUsage EventType = "llm_usage" ) // AllEventTypes returns every supported event type. @@ -139,7 +144,7 @@ func AllEventTypes() []EventType { BeforeFork, BeforeSessionSwitch, BeforeCompact, SubagentStart, SubagentChunk, SubagentEnd, StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry, - PrepareStep, + PrepareStep, LLMUsage, } } diff --git a/internal/extensions/events_test.go b/internal/extensions/events_test.go index 8ef9e791..9faa491a 100644 --- a/internal/extensions/events_test.go +++ b/internal/extensions/events_test.go @@ -4,8 +4,8 @@ import "testing" func TestAllEventTypes_Count(t *testing.T) { all := AllEventTypes() - if len(all) != 32 { - t.Fatalf("expected 32 event types, got %d", len(all)) + if len(all) != 33 { + t.Fatalf("expected 33 event types, got %d", len(all)) } } diff --git a/internal/extensions/llmusage_test.go b/internal/extensions/llmusage_test.go new file mode 100644 index 00000000..e0f6f2ae --- /dev/null +++ b/internal/extensions/llmusage_test.go @@ -0,0 +1,119 @@ +package extensions + +import "testing" + +func TestRunner_EmitLLMUsage(t *testing.T) { + var got LLMUsageEvent + var called bool + ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{ + LLMUsage: { + func(e Event, c Context) Result { + got = e.(LLMUsageEvent) + called = true + return nil + }, + }, + }) + + r := makeRunner(ext) + _, err := r.Emit(LLMUsageEvent{ + InputTokens: 100, + OutputTokens: 50, + Cost: 0.0012, + Model: "claude-sonnet-4-5-20250929", + Provider: "anthropic", + StepNumber: 2, + FinishReason: "tool_calls", + }) + if err != nil { + t.Fatalf("emit: %v", err) + } + if !called { + t.Fatal("expected LLMUsage handler to be called") + } + if got.InputTokens != 100 || got.OutputTokens != 50 { + t.Errorf("token fields not propagated: %+v", got) + } + if got.Cost != 0.0012 { + t.Errorf("cost not propagated, got %v", got.Cost) + } + if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" { + t.Errorf("model/provider not propagated: %+v", got) + } + if got.StepNumber != 2 || got.FinishReason != "tool_calls" { + t.Errorf("step/finish reason not propagated: %+v", got) + } +} + +func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) { + // Verify NewTestAPI wires up onLLMUsage so the extension can call + // api.OnLLMUsage during Init. + ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)} + api := NewTestAPI(ext) + + var calls int + api.OnLLMUsage(func(e LLMUsageEvent, c Context) { + calls++ + }) + + if len(ext.Handlers[LLMUsage]) != 1 { + t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage])) + } + + r := makeRunner(*ext) + _, _ = r.Emit(LLMUsageEvent{InputTokens: 1}) + if calls != 1 { + t.Errorf("expected handler called once, got %d", calls) + } +} + +func TestAgentEndEvent_EnrichedFields(t *testing.T) { + // Verify the enriched event carries through Emit without mangling. + var got AgentEndEvent + ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{ + AgentEnd: { + func(e Event, c Context) Result { + got = e.(AgentEndEvent) + return nil + }, + }, + }) + r := makeRunner(ext) + _, err := r.Emit(AgentEndEvent{ + Response: "done", + StopReason: "completed", + ToolCallCount: 3, + ToolNames: []string{"bash", "read", "bash"}, + LLMCallCount: 4, + InputTokensDelta: 1500, + OutputTokensDelta: 400, + CacheReadTokensDelta: 200, + CacheWriteTokensDelta: 100, + CostDelta: 0.0123, + DurationMs: 2500, + }) + if err != nil { + t.Fatalf("emit: %v", err) + } + if got.ToolCallCount != 3 { + t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount) + } + if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" { + t.Errorf("ToolNames: %v", got.ToolNames) + } + if got.LLMCallCount != 4 { + t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount) + } + if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 { + t.Errorf("token deltas: %+v", got) + } + if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 { + t.Errorf("cache deltas: %+v", got) + } + if got.CostDelta != 0.0123 { + t.Errorf("CostDelta: got %v", got.CostDelta) + } + if got.DurationMs != 2500 { + t.Errorf("DurationMs: got %d", got.DurationMs) + } +} diff --git a/internal/extensions/loader.go b/internal/extensions/loader.go index ab68e86f..dcec3777 100644 --- a/internal/extensions/loader.go +++ b/internal/extensions/loader.go @@ -669,6 +669,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) { return *r }) }, + onLLMUsage: func(h func(LLMUsageEvent, Context)) { + reg(LLMUsage, func(e Event, c Context) Result { + h(e.(LLMUsageEvent), c) + return nil + }) + }, } // Call Init — the extension registers its handlers, tools, commands. diff --git a/internal/extensions/runner.go b/internal/extensions/runner.go index 919cfbf5..a48888a6 100644 --- a/internal/extensions/runner.go +++ b/internal/extensions/runner.go @@ -2,9 +2,12 @@ package extensions import ( "bytes" + "encoding/json" "fmt" "log" + "maps" "os" + "path/filepath" "runtime" "sort" "strconv" @@ -99,6 +102,10 @@ type Runner struct { customEventSubs map[string][]func(string) // inter-extension event bus optionOverrides map[string]string // runtime option overrides configStore *viper.Viper // per-instance config store (nil = global) + state map[string]string // session-scoped extension state (last-write-wins) + stateMu sync.RWMutex // guards state independently of mu + saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave + stateSaver func() // optional persistence hook invoked after each state mutation mu sync.RWMutex } @@ -264,6 +271,18 @@ func normalizeContext(ctx Context) Context { if ctx.GetEntries == nil { ctx.GetEntries = func(string) []ExtensionEntry { return nil } } + if ctx.SetState == nil { + ctx.SetState = func(string, string) {} + } + if ctx.GetState == nil { + ctx.GetState = func(string) (string, bool) { return "", false } + } + if ctx.DeleteState == nil { + ctx.DeleteState = func(string) {} + } + if ctx.ListState == nil { + ctx.ListState = func() []string { return nil } + } if ctx.GetOption == nil { ctx.GetOption = func(string) string { return "" } } @@ -745,6 +764,168 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig { return nil } +// --------------------------------------------------------------------------- +// Extension state store (session-scoped, last-write-wins) +// --------------------------------------------------------------------------- + +// SetState records a key-value pair in the runner's session-scoped extension +// state store. The store is in-memory; callers wire SetStateSaver to persist +// changes to a sidecar file. Thread-safe. +// +// When a saver is installed, concurrent SetState/DeleteState invocations are +// serialized through saverMu so that overlapping snapshot-and-rename writes +// cannot interleave (which would otherwise race on the shared tmp file and +// risk persisting an older snapshot after a newer one). +func (r *Runner) SetState(key, value string) { + r.stateMu.Lock() + if r.state == nil { + r.state = make(map[string]string) + } + r.state[key] = value + saver := r.stateSaver + r.stateMu.Unlock() + r.runSaver(saver) +} + +// GetState returns the value previously stored via SetState, plus a bool +// indicating whether the key was present. Thread-safe. +func (r *Runner) GetState(key string) (string, bool) { + r.stateMu.RLock() + defer r.stateMu.RUnlock() + v, ok := r.state[key] + return v, ok +} + +// DeleteState removes a key from the state store. No-op if the key is +// missing. Thread-safe. Saver invocations are serialized via saverMu — see +// SetState for the rationale. +func (r *Runner) DeleteState(key string) { + r.stateMu.Lock() + _, existed := r.state[key] + if existed { + delete(r.state, key) + } + saver := r.stateSaver + r.stateMu.Unlock() + if !existed { + return + } + r.runSaver(saver) +} + +// runSaver invokes the optional persistence callback under saverMu so +// concurrent SetState/DeleteState writers cannot race on the shared tmp +// file used by SaveStateToFile's atomic rename. The deferred Unlock +// guarantees saverMu is released even if the saver panics. +func (r *Runner) runSaver(saver func()) { + if saver == nil { + return + } + r.saverMu.Lock() + defer r.saverMu.Unlock() + saver() +} + +// ListState returns all keys currently in the state store, in unspecified +// order. Thread-safe. +func (r *Runner) ListState() []string { + r.stateMu.RLock() + defer r.stateMu.RUnlock() + if len(r.state) == 0 { + return nil + } + keys := make([]string, 0, len(r.state)) + for k := range r.state { + keys = append(keys, k) + } + return keys +} + +// SetStateSaver installs an optional persistence hook invoked after each +// mutation to the state store (SetState / DeleteState / LoadStateFromFile). +// Pass nil to disable persistence. Thread-safe. +func (r *Runner) SetStateSaver(saver func()) { + r.stateMu.Lock() + defer r.stateMu.Unlock() + r.stateSaver = saver +} + +// SnapshotState returns a copy of the current state store as a +// fresh map. Useful for persisting to disk without holding the lock. +// Thread-safe. +func (r *Runner) SnapshotState() map[string]string { + r.stateMu.RLock() + defer r.stateMu.RUnlock() + if len(r.state) == 0 { + return nil + } + copyMap := make(map[string]string, len(r.state)) + maps.Copy(copyMap, r.state) + return copyMap +} + +// LoadStateFromFile reads a JSON map from path and replaces the in-memory +// state store with its contents. Missing or empty files are treated as +// "no prior state": the in-memory store is replaced with an empty map so +// callers can safely switch sessions without leaking keys from a prior +// session into a new one. Malformed JSON returns the parse error without +// touching the existing store. Thread-safe. +func (r *Runner) LoadStateFromFile(path string) error { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + r.stateMu.Lock() + r.state = map[string]string{} + r.stateMu.Unlock() + return nil + } + return fmt.Errorf("reading extension state: %w", err) + } + if len(data) == 0 { + r.stateMu.Lock() + r.state = map[string]string{} + r.stateMu.Unlock() + return nil + } + var loaded map[string]string + if err := json.Unmarshal(data, &loaded); err != nil { + return fmt.Errorf("parsing extension state: %w", err) + } + r.stateMu.Lock() + r.state = loaded + r.stateMu.Unlock() + return nil +} + +// SaveStateToFile writes the current state store to path as JSON, creating +// parent directories as needed. An empty store writes an empty object so +// that consumers can distinguish "loaded but empty" from "never saved". +// Writes are atomic via a tmp-file-and-rename sequence. Thread-safe. +func (r *Runner) SaveStateToFile(path string) error { + snap := r.SnapshotState() + if snap == nil { + snap = map[string]string{} + } + data, err := json.MarshalIndent(snap, "", " ") + if err != nil { + return fmt.Errorf("marshalling extension state: %w", err) + } + if dir := filepath.Dir(path); dir != "." && dir != "" { + if err := os.MkdirAll(dir, 0o755); err != nil { + return fmt.Errorf("creating state directory: %w", err) + } + } + tmp := path + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return fmt.Errorf("writing extension state: %w", err) + } + if err := os.Rename(tmp, path); err != nil { + _ = os.Remove(tmp) + return fmt.Errorf("renaming extension state: %w", err) + } + return nil +} + // --------------------------------------------------------------------------- // Hot-reload // --------------------------------------------------------------------------- @@ -768,7 +949,9 @@ func (r *Runner) Reload(exts []LoadedExtension) { r.uiVisibility = nil r.disabledTools = nil r.customEventSubs = nil - // optionOverrides are intentionally preserved. + // optionOverrides and state are intentionally preserved across reloads: + // they represent user/session intent (not extension code) and would be + // surprising to lose on a hot-reload. } // --------------------------------------------------------------------------- diff --git a/internal/extensions/state_test.go b/internal/extensions/state_test.go new file mode 100644 index 00000000..3dd0c045 --- /dev/null +++ b/internal/extensions/state_test.go @@ -0,0 +1,262 @@ +package extensions + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func TestRunner_State_BasicSetGetDelete(t *testing.T) { + r := NewRunner(nil) + + if _, ok := r.GetState("missing"); ok { + t.Fatal("expected GetState to return ok=false for missing key") + } + + r.SetState("a", "1") + r.SetState("b", "2") + r.SetState("a", "3") // last-write-wins + + if v, ok := r.GetState("a"); !ok || v != "3" { + t.Errorf("expected GetState(a)=(3,true), got (%q,%v)", v, ok) + } + if v, ok := r.GetState("b"); !ok || v != "2" { + t.Errorf("expected GetState(b)=(2,true), got (%q,%v)", v, ok) + } + + keys := r.ListState() + if len(keys) != 2 { + t.Errorf("expected 2 keys, got %d (%v)", len(keys), keys) + } + + r.DeleteState("a") + if _, ok := r.GetState("a"); ok { + t.Error("expected key a to be gone after DeleteState") + } + if len(r.ListState()) != 1 { + t.Errorf("expected 1 key after delete, got %v", r.ListState()) + } + + // Deleting missing key is a no-op. + r.DeleteState("never-there") +} + +func TestRunner_State_SaverFires(t *testing.T) { + r := NewRunner(nil) + var calls int + var mu sync.Mutex + r.SetStateSaver(func() { + mu.Lock() + calls++ + mu.Unlock() + }) + + r.SetState("a", "1") + r.SetState("a", "2") + r.DeleteState("a") + r.DeleteState("a") // missing → no save + + mu.Lock() + defer mu.Unlock() + if calls != 3 { + t.Errorf("expected saver to fire 3 times (2 sets + 1 delete), got %d", calls) + } +} + +func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "ext-state.json") + + r1 := NewRunner(nil) + r1.SetState("k1", "v1") + r1.SetState("k2", `{"json":"value"}`) + if err := r1.SaveStateToFile(path); err != nil { + t.Fatalf("SaveStateToFile: %v", err) + } + + // Verify file contains JSON map. + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading saved file: %v", err) + } + var parsed map[string]string + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshalling: %v", err) + } + if parsed["k1"] != "v1" || parsed["k2"] != `{"json":"value"}` { + t.Errorf("unexpected file contents: %v", parsed) + } + + r2 := NewRunner(nil) + if err := r2.LoadStateFromFile(path); err != nil { + t.Fatalf("LoadStateFromFile: %v", err) + } + if v, ok := r2.GetState("k1"); !ok || v != "v1" { + t.Errorf("expected k1=v1 after load, got (%q,%v)", v, ok) + } + if v, ok := r2.GetState("k2"); !ok || v != `{"json":"value"}` { + t.Errorf("expected k2 to round-trip, got %q", v) + } +} + +func TestRunner_State_LoadMissingFileClearsState(t *testing.T) { + // LoadStateFromFile is documented to "replace the in-memory state store + // with its contents"; for a missing file that means clearing the store. + // This is what makes session-switching safe: a new session that has not + // yet written a sidecar must not inherit keys from a prior session. + r := NewRunner(nil) + r.SetState("a", "1") + if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil { + t.Errorf("expected nil error for missing file, got %v", err) + } + if _, ok := r.GetState("a"); ok { + t.Error("expected pre-existing state to be cleared when target file is missing") + } + if keys := r.ListState(); keys != nil { + t.Errorf("expected ListState() to be nil after clearing, got %v", keys) + } +} + +func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "empty.json") + if err := os.WriteFile(path, nil, 0o644); err != nil { + t.Fatal(err) + } + r := NewRunner(nil) + r.SetState("a", "1") + if err := r.LoadStateFromFile(path); err != nil { + t.Errorf("expected nil error for empty file, got %v", err) + } + if _, ok := r.GetState("a"); ok { + t.Error("expected pre-existing state to be cleared when target file is empty") + } +} + +func TestRunner_State_LoadMalformedFileError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.json") + if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil { + t.Fatal(err) + } + r := NewRunner(nil) + if err := r.LoadStateFromFile(path); err == nil { + t.Error("expected error loading malformed JSON, got nil") + } +} + +func TestRunner_State_PersistenceViaSaver(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "ext-state.json") + + r := NewRunner(nil) + r.SetStateSaver(func() { + _ = r.SaveStateToFile(path) + }) + r.SetState("hello", "world") + + // File should exist with the value already. + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading saved file: %v", err) + } + var parsed map[string]string + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshalling: %v", err) + } + if parsed["hello"] != "world" { + t.Errorf("expected file to contain hello=world, got %v", parsed) + } +} + +func TestRunner_State_ConcurrentSet(t *testing.T) { + r := NewRunner(nil) + var wg sync.WaitGroup + const goroutines = 16 + const iterations = 100 + wg.Add(goroutines) + for range goroutines { + go func() { + defer wg.Done() + for range iterations { + r.SetState("k", "v") + _, _ = r.GetState("k") + } + }() + } + wg.Wait() + if v, ok := r.GetState("k"); !ok || v != "v" { + t.Errorf("expected k=v after concurrent writes, got (%q,%v)", v, ok) + } +} + +func TestRunner_State_ContextNoOpsWhenUnset(t *testing.T) { + // Verify normalizeContext installs safe no-ops for SetState/GetState/etc. + // when not provided by the caller. + ext := makeHandlerExt("state.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { + // All four state functions should be non-nil and safe to call. + c.SetState("a", "b") + if v, ok := c.GetState("a"); ok || v != "" { + t.Errorf("no-op GetState should return (\"\", false); got (%q,%v)", v, ok) + } + c.DeleteState("a") + if keys := c.ListState(); keys != nil { + t.Errorf("no-op ListState should return nil; got %v", keys) + } + return nil + }, + }, + }) + r := makeRunner(ext) + // SetContext with empty Context to exercise normalizeContext defaults. + r.SetContext(Context{}) + _, err := r.Emit(SessionStartEvent{}) + if err != nil { + t.Fatalf("emit: %v", err) + } +} + +func TestRunner_State_SaverPanicReleasesSaverMu(t *testing.T) { + // If the saver callback panics (e.g. disk full mid-write), runSaver + // must still release saverMu so subsequent SetState/DeleteState calls + // can make progress. Without `defer Unlock()` the lock would be + // permanently held and the next write would deadlock. + r := NewRunner(nil) + var calls int + r.SetStateSaver(func() { + calls++ + if calls == 1 { + panic("simulated disk-write failure") + } + }) + + // First call panics. Recover, then verify a follow-up call still works + // without blocking (proving saverMu was released). + func() { + defer func() { + if rec := recover(); rec == nil { + t.Fatal("expected panic from first saver invocation") + } + }() + r.SetState("a", "1") + }() + + done := make(chan struct{}) + go func() { + r.SetState("b", "2") // would deadlock if saverMu were still held + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("SetState after saver panic blocked — saverMu was not released") + } + if calls != 2 { + t.Errorf("expected saver to fire twice (panic + recovery write), got %d", calls) + } +} diff --git a/internal/extensions/symbols.go b/internal/extensions/symbols.go index 84f5a866..03a812fa 100644 --- a/internal/extensions/symbols.go +++ b/internal/extensions/symbols.go @@ -183,6 +183,7 @@ func Symbols() interp.Exports { "RetryEvent": reflect.ValueOf((*RetryEvent)(nil)), "PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)), "PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)), + "LLMUsageEvent": reflect.ValueOf((*LLMUsageEvent)(nil)), }, } } diff --git a/internal/extensions/test_api.go b/internal/extensions/test_api.go index 222c4cbc..3357618a 100644 --- a/internal/extensions/test_api.go +++ b/internal/extensions/test_api.go @@ -189,5 +189,11 @@ func NewTestAPI(ext *LoadedExtension) API { return nil }) }, + onLLMUsage: func(h func(LLMUsageEvent, Context)) { + reg(LLMUsage, func(e Event, c Context) Result { + h(e.(LLMUsageEvent), c) + return nil + }) + }, } } diff --git a/pkg/kit/extension_api.go b/pkg/kit/extension_api.go index eea9a9d8..679fc3be 100644 --- a/pkg/kit/extension_api.go +++ b/pkg/kit/extension_api.go @@ -2,6 +2,8 @@ package kit import ( "fmt" + "log" + "strings" "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/message" @@ -96,6 +98,23 @@ type ExtensionAPI interface { AppendEntry(extType, data string) (string, error) GetEntries(extType string) []ExtensionEntry + // Session-scoped extension state (last-write-wins key-value store). + // Backed by an in-memory map and (optionally) a sidecar file per session; + // state lives outside the conversation tree and is not visible to the LLM. + SetState(key, value string) + GetState(key string) (string, bool) + DeleteState(key string) + ListState() []string + + // InitStatePersistence loads any existing state from the per-session + // sidecar file and installs a saver hook so that subsequent SetState / + // DeleteState mutations are flushed to disk. Safe to call multiple times; + // repeat calls simply reload and reinstall the saver. + // + // For ephemeral or in-memory sessions (no session file path), the call + // is a no-op and state remains in memory for the lifetime of the runner. + InitStatePersistence() error + // Status bar SetStatus(entry ExtensionStatusBarEntry) RemoveStatus(key string) @@ -332,6 +351,67 @@ func (e *extensionAPI) AppendEntry(extType, data string) (string, error) { return e.kit.session.AppendExtensionData(extType, data) } +func (e *extensionAPI) SetState(key, value string) { + if e.kit.extRunner != nil { + e.kit.extRunner.SetState(key, value) + } +} + +func (e *extensionAPI) GetState(key string) (string, bool) { + if e.kit.extRunner == nil { + return "", false + } + return e.kit.extRunner.GetState(key) +} + +func (e *extensionAPI) DeleteState(key string) { + if e.kit.extRunner != nil { + e.kit.extRunner.DeleteState(key) + } +} + +func (e *extensionAPI) ListState() []string { + if e.kit.extRunner == nil { + return nil + } + return e.kit.extRunner.ListState() +} + +func (e *extensionAPI) InitStatePersistence() error { + if e.kit.extRunner == nil { + return nil + } + path := extStateSidecarPath(e.kit.GetSessionPath()) + if path == "" { + // Ephemeral or in-memory session; no on-disk state. + e.kit.extRunner.SetStateSaver(nil) + return nil + } + if err := e.kit.extRunner.LoadStateFromFile(path); err != nil { + return err + } + runner := e.kit.extRunner + runner.SetStateSaver(func() { + if err := runner.SaveStateToFile(path); err != nil { + log.Printf("WARN extension state save failed: path=%s err=%v", path, err) + } + }) + return nil +} + +// extStateSidecarPath returns the path to the per-session extension state +// sidecar file derived from the session's JSONL path. Returns empty for +// ephemeral / in-memory sessions where no JSONL is being written. +func extStateSidecarPath(sessionPath string) string { + if sessionPath == "" { + return "" + } + if trimmed, ok := strings.CutSuffix(sessionPath, ".jsonl"); ok { + return trimmed + ".ext-state.json" + } + return sessionPath + ".ext-state.json" +} + func (e *extensionAPI) GetEntries(extType string) []ExtensionEntry { if e.kit.session == nil { return nil diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go index 03ac8983..36e069bb 100644 --- a/pkg/kit/extensions_bridge.go +++ b/pkg/kit/extensions_bridge.go @@ -3,8 +3,11 @@ package kit import ( "strings" "sync" + "time" + "github.com/mark3labs/kit/internal/auth" "github.com/mark3labs/kit/internal/extensions" + "github.com/mark3labs/kit/internal/models" ) // bridgeExtensions registers extension event handlers as SDK hooks and @@ -19,6 +22,30 @@ import ( // wrapper (internal/extensions/wrapper.go) which composes underneath the SDK // hook wrapper. func (m *Kit) bridgeExtensions(runner *extensions.Runner) { + // Per-turn aggregator: collects tool/LLM/usage signals between AgentStart + // and AgentEnd so the enriched AgentEndEvent can be populated without + // requiring extensions to maintain parallel bookkeeping. + // + // NOTE: this aggregator assumes a single in-flight turn per *Kit instance, + // which is the current contract — runTurn does not serialize callers and + // the SDK's TurnStartEvent/TurnEndEvent do not carry a turn ID, so two + // concurrent Prompt() calls on the same *Kit would clobber the counters. + // All current callers (TUI app layer, CLI runner, SDK examples) serialize + // turns above this layer. If concurrent turns become a supported use case, + // extend TurnStartEvent/TurnEndEvent with a turn ID and key this map per + // turn instead. + turnAgg := &turnAggregator{kit: m} + m.Subscribe(func(e Event) { + switch ev := e.(type) { + case TurnStartEvent: + turnAgg.start() + case ToolResultEvent: + turnAgg.recordTool(ev.ToolName) + case StepFinishEvent: + turnAgg.recordStep(ev.Usage) + } + }) + // --- Interception hooks --- // Extension Input → BeforeTurn hook (high priority, runs first). @@ -109,9 +136,19 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { } else if stopReason == "" { stopReason = "completed" } + agg := turnAgg.consume() _, _ = runner.Emit(extensions.AgentEndEvent{ - Response: response, - StopReason: stopReason, + Response: response, + StopReason: stopReason, + ToolCallCount: agg.toolCallCount, + ToolNames: agg.toolNames, + LLMCallCount: agg.llmCallCount, + InputTokensDelta: agg.inputTokens, + OutputTokensDelta: agg.outputTokens, + CacheReadTokensDelta: agg.cacheReadTokens, + CacheWriteTokensDelta: agg.cacheWriteTokens, + CostDelta: agg.cost, + DurationMs: agg.durationMs(), }) } }) @@ -302,6 +339,32 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { } }) + // LLMUsage: derive per-call usage from StepFinish. Each step corresponds + // to one LLM provider call, so the step's usage is the per-call delta. + // Cost is computed from the current model's pricing (zero when unknown + // or OAuth credentials are in use). RequestID is left empty until the + // SDK surfaces a correlation id from the underlying provider. + if runner.HasHandlers(extensions.LLMUsage) { + m.Subscribe(func(e Event) { + ev, ok := e.(StepFinishEvent) + if !ok { + return + } + provider, modelID, cost := llmUsageMeta(m, ev.Usage) + _, _ = runner.Emit(extensions.LLMUsageEvent{ + InputTokens: int(ev.Usage.InputTokens), + OutputTokens: int(ev.Usage.OutputTokens), + CacheReadTokens: int(ev.Usage.CacheReadTokens), + CacheWriteTokens: int(ev.Usage.CacheCreationTokens), + Cost: cost, + Model: modelID, + Provider: provider, + StepNumber: ev.StepNumber, + FinishReason: ev.FinishReason, + }) + }) + } + bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event { return extensions.ReasoningStartEvent{ID: ev.ID} }) @@ -363,6 +426,172 @@ func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions. }) } +// turnAggregator collects per-turn signals (tool calls, LLM round-trips, token +// usage, wall-clock duration) so that the enriched AgentEndEvent can be +// populated without requiring extensions to maintain parallel bookkeeping. +// +// The aggregator resets on each TurnStartEvent and is consumed (snapshotted + +// reset) on TurnEndEvent. All access is serialized via a mutex because the +// underlying event bus may fan handlers across goroutines in the future. +type turnAggregator struct { + mu sync.Mutex + started time.Time + ended time.Time + toolCallCount int + toolNames []string + llmCallCount int + inputTokens int + outputTokens int + cacheReadTokens int + cacheWriteTokens int + cost float64 + kit *Kit +} + +type turnSnapshot struct { + started time.Time + ended time.Time + toolCallCount int + toolNames []string + llmCallCount int + inputTokens int + outputTokens int + cacheReadTokens int + cacheWriteTokens int + cost float64 +} + +func (s turnSnapshot) durationMs() int64 { + if s.started.IsZero() { + return 0 + } + end := s.ended + if end.IsZero() { + end = time.Now() + } + return end.Sub(s.started).Milliseconds() +} + +// start resets all counters and records the turn's start time. Called from +// the TurnStartEvent subscriber. +func (a *turnAggregator) start() { + a.mu.Lock() + defer a.mu.Unlock() + a.started = time.Now() + a.ended = time.Time{} + a.toolCallCount = 0 + a.toolNames = nil + a.llmCallCount = 0 + a.inputTokens = 0 + a.outputTokens = 0 + a.cacheReadTokens = 0 + a.cacheWriteTokens = 0 + a.cost = 0 +} + +func (a *turnAggregator) recordTool(name string) { + a.mu.Lock() + defer a.mu.Unlock() + a.toolCallCount++ + if name != "" { + a.toolNames = append(a.toolNames, name) + } +} + +func (a *turnAggregator) recordStep(usage LLMUsage) { + a.mu.Lock() + defer a.mu.Unlock() + a.llmCallCount++ + a.inputTokens += int(usage.InputTokens) + a.outputTokens += int(usage.OutputTokens) + a.cacheReadTokens += int(usage.CacheReadTokens) + a.cacheWriteTokens += int(usage.CacheCreationTokens) + if a.kit != nil { + _, _, c := llmUsageMeta(a.kit, usage) + a.cost += c + } +} + +// consume returns a snapshot of the current turn and marks it ended. +// Subsequent start() calls clear the snapshot. +func (a *turnAggregator) consume() turnSnapshot { + a.mu.Lock() + defer a.mu.Unlock() + a.ended = time.Now() + names := a.toolNames + if len(names) > 0 { + copied := make([]string, len(names)) + copy(copied, names) + names = copied + } + return turnSnapshot{ + started: a.started, + ended: a.ended, + toolCallCount: a.toolCallCount, + toolNames: names, + llmCallCount: a.llmCallCount, + inputTokens: a.inputTokens, + outputTokens: a.outputTokens, + cacheReadTokens: a.cacheReadTokens, + cacheWriteTokens: a.cacheWriteTokens, + cost: a.cost, + } +} + +// llmUsageMeta returns the current provider, model id, and computed cost for +// the given usage values using the Kit instance's active model. Cost is zero +// in any of the following cases: +// - the *Kit pointer is nil or has no active model; +// - the model is not in the registry (custom fine-tunes, unknown providers); +// - the model has no pricing fields set; +// - the active credential is an Anthropic OAuth token (matches the +// existing usage_tracker behavior of suppressing cost for OAuth users). +func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float64) { + if m == nil { + return "", "", 0 + } + modelString := m.GetModelString() + if modelString == "" { + return "", "", 0 + } + p, id, err := models.ParseModelString(modelString) + if err != nil { + return "", "", 0 + } + provider, modelID = p, id + info := models.GetGlobalRegistry().LookupModel(provider, modelID) + if info == nil { + return provider, modelID, 0 + } + if isAnthropicOAuth(m, provider) { + return provider, modelID, 0 + } + cost = float64(usage.InputTokens) * info.Cost.Input / 1_000_000 + cost += float64(usage.OutputTokens) * info.Cost.Output / 1_000_000 + if info.Cost.CacheRead != nil { + cost += float64(usage.CacheReadTokens) * (*info.Cost.CacheRead) / 1_000_000 + } + if info.Cost.CacheWrite != nil { + cost += float64(usage.CacheCreationTokens) * (*info.Cost.CacheWrite) / 1_000_000 + } + return provider, modelID, cost +} + +// isAnthropicOAuth reports whether the current Anthropic credential resolves +// to a stored OAuth token (in which case the user is not billed per-token). +// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker +// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage(). +func isAnthropicOAuth(m *Kit, provider string) bool { + if m == nil || provider != "anthropic" { + return false + } + _, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key")) + if err != nil { + return false + } + return strings.HasPrefix(source, "stored OAuth") +} + // llmToContextMessages converts a slice of LLM messages to extension // ContextMessage values, extracting plain text from each message. func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage { diff --git a/pkg/kit/extensions_bridge_test.go b/pkg/kit/extensions_bridge_test.go new file mode 100644 index 00000000..cf0dd976 --- /dev/null +++ b/pkg/kit/extensions_bridge_test.go @@ -0,0 +1,140 @@ +package kit + +import ( + "testing" + "time" +) + +// TestTurnAggregator_BasicLifecycle exercises the per-turn aggregator: +// start → record several tools and steps → consume → snapshot should reflect +// the accumulated counts and zero out for the next turn. +func TestTurnAggregator_BasicLifecycle(t *testing.T) { + agg := &turnAggregator{} + + agg.start() + agg.recordTool("bash") + agg.recordTool("read") + agg.recordTool("bash") + agg.recordStep(LLMUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadTokens: 10, + CacheCreationTokens: 5, + }) + agg.recordStep(LLMUsage{ + InputTokens: 200, + OutputTokens: 75, + }) + + snap := agg.consume() + if snap.toolCallCount != 3 { + t.Errorf("toolCallCount: got %d want 3", snap.toolCallCount) + } + wantNames := []string{"bash", "read", "bash"} + if len(snap.toolNames) != len(wantNames) { + t.Fatalf("toolNames length: got %d want %d", len(snap.toolNames), len(wantNames)) + } + for i, n := range wantNames { + if snap.toolNames[i] != n { + t.Errorf("toolNames[%d]: got %q want %q", i, snap.toolNames[i], n) + } + } + if snap.llmCallCount != 2 { + t.Errorf("llmCallCount: got %d want 2", snap.llmCallCount) + } + if snap.inputTokens != 300 { + t.Errorf("inputTokens: got %d want 300", snap.inputTokens) + } + if snap.outputTokens != 125 { + t.Errorf("outputTokens: got %d want 125", snap.outputTokens) + } + if snap.cacheReadTokens != 10 { + t.Errorf("cacheReadTokens: got %d want 10", snap.cacheReadTokens) + } + if snap.cacheWriteTokens != 5 { + t.Errorf("cacheWriteTokens: got %d want 5", snap.cacheWriteTokens) + } + if snap.durationMs() < 0 { + t.Errorf("durationMs should not be negative, got %d", snap.durationMs()) + } +} + +func TestTurnAggregator_StartResetsCounters(t *testing.T) { + agg := &turnAggregator{} + agg.start() + agg.recordTool("bash") + agg.recordStep(LLMUsage{InputTokens: 50}) + + // Begin a new turn — previous counters should be cleared. + agg.start() + snap := agg.consume() + + if snap.toolCallCount != 0 || snap.llmCallCount != 0 || snap.inputTokens != 0 { + t.Errorf("expected counters zeroed after start(), got %+v", snap) + } + if snap.toolNames != nil { + t.Errorf("expected toolNames=nil after start(), got %v", snap.toolNames) + } +} + +// TestTurnAggregator_DurationMs verifies the snapshot computes a positive +// duration when consume() runs after start(). +func TestTurnAggregator_DurationMs(t *testing.T) { + agg := &turnAggregator{} + agg.start() + time.Sleep(5 * time.Millisecond) + snap := agg.consume() + if snap.durationMs() < 1 { + t.Errorf("expected positive duration, got %d", snap.durationMs()) + } +} + +// TestTurnAggregator_ZeroStartSafe ensures a snapshot taken without a prior +// start() doesn't crash and reports zero duration. +func TestTurnAggregator_ZeroStartSafe(t *testing.T) { + agg := &turnAggregator{} + snap := agg.consume() + if snap.durationMs() != 0 { + t.Errorf("expected zero duration for unstarted aggregator, got %d", snap.durationMs()) + } +} + +// TestLLMUsageMeta_NilKit verifies the helper degrades gracefully when given +// a nil Kit instance (zero values, no panic). +func TestLLMUsageMeta_NilKit(t *testing.T) { + provider, modelID, cost := llmUsageMeta(nil, LLMUsage{InputTokens: 100}) + if provider != "" || modelID != "" || cost != 0 { + t.Errorf("expected zero values for nil kit, got (%q,%q,%v)", provider, modelID, cost) + } +} + +// TestIsAnthropicOAuth_NonAnthropic verifies the helper short-circuits for any +// provider other than "anthropic" without touching the credential store. +func TestIsAnthropicOAuth_NonAnthropic(t *testing.T) { + for _, provider := range []string{"openai", "google", "openrouter", ""} { + if isAnthropicOAuth(nil, provider) { + t.Errorf("isAnthropicOAuth(nil, %q) = true, want false", provider) + } + } +} + +func TestExtStateSidecarPath(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + {"empty", "", ""}, + {"jsonl", "/tmp/sessions/abc.jsonl", "/tmp/sessions/abc.ext-state.json"}, + {"jsonl with subdir", "/a/b/c.jsonl", "/a/b/c.ext-state.json"}, + {"no extension", "/tmp/session-blob", "/tmp/session-blob.ext-state.json"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := extStateSidecarPath(tc.in) + if got != tc.want { + t.Errorf("extStateSidecarPath(%q): got %q want %q", tc.in, got, tc.want) + } + }) + } +} diff --git a/skills/kit-extensions/SKILL.md b/skills/kit-extensions/SKILL.md index d5bb3ff8..47ac9c5c 100644 --- a/skills/kit-extensions/SKILL.md +++ b/skills/kit-extensions/SKILL.md @@ -88,7 +88,8 @@ api.OnAgentStart(func(e ext.AgentStartEvent, ctx ext.Context) { // e.Prompt string }) -// Agent finished responding. +// Agent finished responding. Carries per-turn aggregates so observer-style +// extensions don't need to maintain parallel bookkeeping. api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) { // e.Response string // e.StopReason string — "error" (on failure), "completed" (when LLM returns @@ -96,6 +97,33 @@ api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) { // (e.g. "stop", "length" (max output tokens hit), "tool-calls", "content-filter"). // To detect errors, check e.StopReason == "error". // Do NOT compare against "completed" for success — instead check != "error". + // + // Per-turn aggregates (computed by Kit's runtime): + // e.ToolCallCount int — total tool invocations this turn + // e.ToolNames []string — tool names in call order (duplicates preserved) + // e.LLMCallCount int — LLM round-trips / tool-loop iterations + // e.InputTokensDelta int — sum of input tokens across LLM calls this turn + // e.OutputTokensDelta int + // e.CacheReadTokensDelta int + // e.CacheWriteTokensDelta int + // e.CostDelta float64 — USD cost (zero when pricing unknown / OAuth) + // e.DurationMs int64 — wall-clock duration AgentStart→AgentEnd +}) + +// Per-LLM-call usage — fires after each provider round-trip with token + cost +// deltas attributed to that specific call. A single turn typically produces +// multiple LLMUsageEvents (one per tool-loop iteration). Use this for accurate +// budget enforcement that needs to react between calls instead of waiting +// for the turn to finish. +api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) { + // e.InputTokens, e.OutputTokens int + // e.CacheReadTokens, e.CacheWriteTokens int + // e.Cost float64 — USD; zero when pricing unknown / OAuth + // e.Model, e.Provider string — model used for THIS call + // (may differ across calls if SetModel was called) + // e.StepNumber int — zero-based step index in this turn + // e.FinishReason string — "stop" / "tool_calls" / "length" / ... + // e.RequestID string — optional provider correlation id (may be empty) }) ``` @@ -528,11 +556,38 @@ stats := ctx.GetContextStats() // .EstimatedTokens, .ContextLimit, .UsagePer msgs := ctx.GetMessages() // []ext.SessionMessage on current branch path := ctx.GetSessionPath() // file path of session JSONL -// Persist custom data in the session tree: +// Append-only log in the session tree (fork-aware, walked on every branch read): id, err := ctx.AppendEntry("my-type", "data string") entries := ctx.GetEntries("my-type") // []ext.ExtensionEntry{ID, EntryType, Data, Timestamp} ``` +### Session State (last-write-wins) + +Key-value store scoped to the session, persisted to a sidecar file +(`.ext-state.json`) outside the conversation tree. Reads are O(1) +(no branch walk), writes don't grow the JSONL, and the store is not +duplicated on fork. State is invisible to the LLM and survives session +resume. For ephemeral / in-memory sessions, state lives only in memory. + +```go +ctx.SetState("myext:budget-cap", "10.00") // last write wins +val, ok := ctx.GetState("myext:budget-cap") // (string, bool) +ctx.DeleteState("myext:budget-cap") // no-op if missing +keys := ctx.ListState() // []string, unspecified order +``` + +**When to use which:** + +| Need | Use | +|------|-----| +| Snapshot state ("current value of X") | `SetState` / `GetState` | +| Audit log / event history | `AppendEntry` / `GetEntries` | +| One-shot per-turn signal | enriched `AgentEndEvent` fields | +| Per-LLM-call observation | `OnLLMUsage` event | + +Namespace keys with your extension name (e.g. `"myext:budget-cap"`) to avoid +collisions across extensions. + ### Model Management ```go diff --git a/skills/kit-sdk/SKILL.md b/skills/kit-sdk/SKILL.md index a46f1baf..e578fbcd 100644 --- a/skills/kit-sdk/SKILL.md +++ b/skills/kit-sdk/SKILL.md @@ -1104,6 +1104,19 @@ if extAPI.HasExtensions() { tools := extAPI.GetToolInfos() extAPI.SetActiveTools([]string{"bash", "read"}) + // Session-scoped extension state (last-write-wins key-value store). + // Backed by an in-memory map and a per-session sidecar file + // (.ext-state.json) outside the conversation tree. + extAPI.SetState("myext:budget-cap", "10.00") + val, ok := extAPI.GetState("myext:budget-cap") + extAPI.DeleteState("myext:budget-cap") + keys := extAPI.ListState() + + // Load any existing state from the sidecar and install a saver hook so + // subsequent SetState/DeleteState mutations are flushed atomically. + // No-op for ephemeral / in-memory sessions. Safe to call multiple times. + _ = extAPI.InitStatePersistence() + // Events extAPI.EmitSessionStart() extAPI.EmitModelChange("new/model", "old/model", "extension") diff --git a/www/pages/extensions/capabilities.md b/www/pages/extensions/capabilities.md index dd061006..2eeded4f 100644 --- a/www/pages/extensions/capabilities.md +++ b/www/pages/extensions/capabilities.md @@ -7,7 +7,7 @@ description: All extension capabilities — lifecycle events, tools, commands, w ## Lifecycle events -Extensions can hook into 26 lifecycle events: +Extensions can hook into 27 lifecycle events: | Event | Description | |-------|-------------| @@ -15,7 +15,8 @@ Extensions can hook into 26 lifecycle events: | `OnSessionShutdown` | Session ending | | `OnBeforeAgentStart` | Before the agent loop begins | | `OnAgentStart` | Agent loop started | -| `OnAgentEnd` | Agent loop completed | +| `OnAgentEnd` | Agent loop completed (carries per-turn aggregates: tool counts, token deltas, cost, duration) | +| `OnLLMUsage` | Per-LLM-call token + cost delta (fires once per provider round-trip) | | `OnToolCall` | Tool call requested by the model | | `OnToolCallInputStart` | LLM began generating tool call arguments (tool name known, args streaming) | | `OnToolCallInputDelta` | Streamed JSON fragment of tool call arguments | @@ -45,11 +46,52 @@ api.OnToolCall(func(event ext.ToolCallEvent, ctx ext.Context) { ctx.PrintInfo("Calling tool: " + event.Name) }) -api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) { - ctx.PrintInfo("Agent finished") +api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) { + // Per-turn aggregates populated by Kit's runtime — no parallel + // bookkeeping required in the handler. + ctx.PrintInfo(fmt.Sprintf( + "Turn finished: %d tool calls (%v), %d LLM round-trips, $%.4f, %dms", + e.ToolCallCount, e.ToolNames, e.LLMCallCount, e.CostDelta, e.DurationMs, + )) +}) + +// Per-LLM-call usage — fires multiple times per turn (once per round-trip). +// Use for accurate budget enforcement between calls. +api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) { + ctx.PrintInfo(fmt.Sprintf( + "%s/%s step=%d tokens=↑%d ↓%d cost=$%.4f (%s)", + e.Provider, e.Model, e.StepNumber, + e.InputTokens, e.OutputTokens, e.Cost, e.FinishReason, + )) }) ``` +**`AgentEndEvent` fields** (in addition to `Response` and `StopReason`): + +| Field | Type | Description | +|-------|------|-------------| +| `ToolCallCount` | `int` | Total tool invocations during the turn | +| `ToolNames` | `[]string` | Tool names in call order (duplicates preserved) | +| `LLMCallCount` | `int` | LLM round-trips / tool-loop iterations | +| `InputTokensDelta` | `int` | Sum of input tokens across all LLM calls this turn | +| `OutputTokensDelta` | `int` | Sum of output tokens across all LLM calls this turn | +| `CacheReadTokensDelta` | `int` | Sum of cache-read tokens this turn | +| `CacheWriteTokensDelta` | `int` | Sum of cache-write tokens this turn | +| `CostDelta` | `float64` | Cost in USD (zero when pricing is unknown or OAuth credentials) | +| `DurationMs` | `int64` | Wall-clock time from `AgentStart` to `AgentEnd` | + +**`LLMUsageEvent` fields**: + +| Field | Type | Description | +|-------|------|-------------| +| `InputTokens` / `OutputTokens` | `int` | Per-call token deltas | +| `CacheReadTokens` / `CacheWriteTokens` | `int` | Per-call cache token deltas | +| `Cost` | `float64` | Per-call USD cost (zero when pricing unknown) | +| `Model` / `Provider` | `string` | Model used for this specific call — may differ from earlier calls if `ctx.SetModel` was called mid-turn | +| `StepNumber` | `int` | Zero-based step index within the turn | +| `FinishReason` | `string` | Provider finish reason for this call (`"stop"`, `"tool_calls"`, `"length"`, ...) | +| `RequestID` | `string` | Optional provider correlation id (may be empty) | + ## Tools Register custom tools that the LLM can invoke: @@ -338,6 +380,36 @@ api.OnCustomEvent("my-extension:data-ready", func(data any, ctx ext.Context) { }) ``` +## Session state + +Last-write-wins key-value store, scoped to the current session and persisted to a sidecar file (`.ext-state.json`) outside the conversation tree: + +```go +ctx.SetState("myext:budget-cap", "10.00") + +if cap, ok := ctx.GetState("myext:budget-cap"); ok { + // ... +} + +ctx.DeleteState("myext:budget-cap") +keys := ctx.ListState() // []string, unspecified order +``` + +Reads are O(1) (no branch walk), writes don't grow the session JSONL, and the store is not duplicated when the conversation forks. State is invisible to the LLM and survives session resume. + +### When to use which persistence primitive + +| Need | Use | Why | +|------|-----|-----| +| Snapshot state ("current value of X") | `SetState` / `GetState` | O(1) reads, sidecar file, last-write-wins | +| Audit log / event history | `AppendEntry` / `GetEntries` | Append-only, lives in conversation tree, fork-aware | +| One-shot per-turn signal | Enriched `AgentEndEvent` fields | No persistence needed; runtime tracks it for you | +| Per-LLM-call observation | `OnLLMUsage` event | Already attributed to model/provider/step | + +Using `AppendEntry` for snapshot state has a cost: it's O(branch_length) to read, fsyncs into the JSONL on every write, and the entry list duplicates on every fork. Prefer `SetState` for "what's the current value of X?"-style data. + +For ephemeral / in-memory sessions (no JSONL path) the state lives only in memory for the lifetime of the runner. + ## Bridged SDK APIs Extensions can access powerful internal SDK capabilities that enable advanced features like conversation tree navigation, dynamic skill loading, template parsing, and model resolution. diff --git a/www/pages/extensions/examples.md b/www/pages/extensions/examples.md index 8bf0a83e..5cb2cef4 100644 --- a/www/pages/extensions/examples.md +++ b/www/pages/extensions/examples.md @@ -50,6 +50,7 @@ Kit ships with a rich set of example extensions in the `examples/extensions/` di | [`context-inject.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/context-inject.go) | Inject context into conversations | | [`summarize.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/summarize.go) | Conversation summarization | | [`lsp-diagnostics.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/lsp-diagnostics.go) | LSP diagnostic integration | +| [`usage-budget.go`](https://github.com/mark3labs/kit/blob/master/examples/extensions/usage-budget.go) | Per-call usage callback (`OnLLMUsage`), session state (`SetState`/`GetState`), and enriched `OnAgentEnd` per-turn report | ## Bridged SDK APIs diff --git a/www/pages/extensions/overview.md b/www/pages/extensions/overview.md index f376cb19..931a7116 100644 --- a/www/pages/extensions/overview.md +++ b/www/pages/extensions/overview.md @@ -65,7 +65,8 @@ Passed to event handlers, the `Context` object provides runtime access to Kit's - **Model** — `ctx.SetModel(...)`, `ctx.GetAvailableModels()` - **Tools** — `ctx.GetAllTools()`, `ctx.SetActiveTools(...)` - **Context stats** — `ctx.GetContextStats()` -- **Session data** — `ctx.AppendEntry(...)`, `ctx.GetEntries(...)` +- **Session data** — `ctx.AppendEntry(...)`, `ctx.GetEntries(...)` (append-only, in conversation tree) +- **Session state** — `ctx.SetState(...)`, `ctx.GetState(...)`, `ctx.DeleteState(...)`, `ctx.ListState()` (last-write-wins, sidecar file) - **Subagents** — `ctx.SpawnSubagent(...)` - **LLM completion** — `ctx.Complete(...)` - **Custom events** — `ctx.EmitCustomEvent(...)`