mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
49f8b485be
* feat(extensions): add OnLLMUsage, SetState, enriched AgentEndEvent (#53) Three additive primitives to the extension API: - OnLLMUsage event: per-LLM-call token + cost deltas attributed to the specific model/provider used for each round-trip. Derived from the SDK StepFinishEvent in the extension bridge. Enables accurate budget enforcement between calls instead of only at turn boundaries. - ctx.SetState / GetState / DeleteState / ListState: session-scoped, last-write-wins key-value store backed by a sidecar file (<session>.ext-state.json) outside the conversation tree. Reads are O(1), writes don't grow the JSONL, and the store is not duplicated on fork. State is preserved across hot-reloads. - Enriched AgentEndEvent: ToolCallCount, ToolNames, LLMCallCount, token deltas (input/output/cache-read/cache-write), CostDelta, and DurationMs populated by a per-turn aggregator. Existing handlers reading only Response/StopReason are unaffected. Includes unit tests for the state store, LLMUsage registration, enriched AgentEndEvent, turn aggregator, llmUsageMeta, and sidecar path derivation. Adds examples/extensions/usage-budget.go demoing all three primitives together. Documents the additions in README, the docs site (extensions overview, capabilities, examples), and the kit-extensions and kit-sdk skill guides. Fixes #53 * fix(extensions): address review feedback on state store and llmUsageMeta - Serialize SetState/DeleteState saver invocations through a new saverMu so overlapping atomic-rename writes can no longer race on the shared .tmp file and persist an older snapshot after a newer one. - LoadStateFromFile now clears the in-memory store when the sidecar is missing or empty, matching the documented "replace … with its contents" contract. This makes session-switching safe by preventing keys from a prior session leaking into a new one. Tests updated to cover both the missing-file and empty-file cases. 
- llmUsageMeta now detects Anthropic OAuth credentials and returns Cost=0, matching the comment and the existing usage_tracker behavior for OAuth users. Mirrors the OAuth detection already used in cmd/extension_context.go. - Document the single-in-flight-turn assumption baked into the per-turn aggregator, with a clear migration path (per-turn ID) in case concurrent turns ever become a supported use case. * fix(extensions): release saverMu on panic in state store Extract a runSaver helper that locks saverMu and defers Unlock before invoking the persistence callback. Without the deferred Unlock, a panic inside the saver (e.g. disk full mid-write) would leave saverMu held forever and deadlock the next SetState/DeleteState. Both SetState and DeleteState now route through the helper. A new test, TestRunner_State_SaverPanicReleasesSaverMu, reproduces the deadlock window with a 2s deadline and proves the mutex is released after a panic.
120 lines
3.2 KiB
Go
120 lines
3.2 KiB
Go
package extensions
|
|
|
|
import "testing"
|
|
|
|
func TestRunner_EmitLLMUsage(t *testing.T) {
|
|
var got LLMUsageEvent
|
|
var called bool
|
|
ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{
|
|
LLMUsage: {
|
|
func(e Event, c Context) Result {
|
|
got = e.(LLMUsageEvent)
|
|
called = true
|
|
return nil
|
|
},
|
|
},
|
|
})
|
|
|
|
r := makeRunner(ext)
|
|
_, err := r.Emit(LLMUsageEvent{
|
|
InputTokens: 100,
|
|
OutputTokens: 50,
|
|
Cost: 0.0012,
|
|
Model: "claude-sonnet-4-5-20250929",
|
|
Provider: "anthropic",
|
|
StepNumber: 2,
|
|
FinishReason: "tool_calls",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("emit: %v", err)
|
|
}
|
|
if !called {
|
|
t.Fatal("expected LLMUsage handler to be called")
|
|
}
|
|
if got.InputTokens != 100 || got.OutputTokens != 50 {
|
|
t.Errorf("token fields not propagated: %+v", got)
|
|
}
|
|
if got.Cost != 0.0012 {
|
|
t.Errorf("cost not propagated, got %v", got.Cost)
|
|
}
|
|
if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" {
|
|
t.Errorf("model/provider not propagated: %+v", got)
|
|
}
|
|
if got.StepNumber != 2 || got.FinishReason != "tool_calls" {
|
|
t.Errorf("step/finish reason not propagated: %+v", got)
|
|
}
|
|
}
|
|
|
|
func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) {
|
|
// Verify NewTestAPI wires up onLLMUsage so the extension can call
|
|
// api.OnLLMUsage during Init.
|
|
ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)}
|
|
api := NewTestAPI(ext)
|
|
|
|
var calls int
|
|
api.OnLLMUsage(func(e LLMUsageEvent, c Context) {
|
|
calls++
|
|
})
|
|
|
|
if len(ext.Handlers[LLMUsage]) != 1 {
|
|
t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage]))
|
|
}
|
|
|
|
r := makeRunner(*ext)
|
|
_, _ = r.Emit(LLMUsageEvent{InputTokens: 1})
|
|
if calls != 1 {
|
|
t.Errorf("expected handler called once, got %d", calls)
|
|
}
|
|
}
|
|
|
|
func TestAgentEndEvent_EnrichedFields(t *testing.T) {
|
|
// Verify the enriched event carries through Emit without mangling.
|
|
var got AgentEndEvent
|
|
ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{
|
|
AgentEnd: {
|
|
func(e Event, c Context) Result {
|
|
got = e.(AgentEndEvent)
|
|
return nil
|
|
},
|
|
},
|
|
})
|
|
r := makeRunner(ext)
|
|
_, err := r.Emit(AgentEndEvent{
|
|
Response: "done",
|
|
StopReason: "completed",
|
|
ToolCallCount: 3,
|
|
ToolNames: []string{"bash", "read", "bash"},
|
|
LLMCallCount: 4,
|
|
InputTokensDelta: 1500,
|
|
OutputTokensDelta: 400,
|
|
CacheReadTokensDelta: 200,
|
|
CacheWriteTokensDelta: 100,
|
|
CostDelta: 0.0123,
|
|
DurationMs: 2500,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("emit: %v", err)
|
|
}
|
|
if got.ToolCallCount != 3 {
|
|
t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount)
|
|
}
|
|
if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" {
|
|
t.Errorf("ToolNames: %v", got.ToolNames)
|
|
}
|
|
if got.LLMCallCount != 4 {
|
|
t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount)
|
|
}
|
|
if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 {
|
|
t.Errorf("token deltas: %+v", got)
|
|
}
|
|
if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 {
|
|
t.Errorf("cache deltas: %+v", got)
|
|
}
|
|
if got.CostDelta != 0.0123 {
|
|
t.Errorf("CostDelta: got %v", got.CostDelta)
|
|
}
|
|
if got.DurationMs != 2500 {
|
|
t.Errorf("DurationMs: got %d", got.DurationMs)
|
|
}
|
|
}
|