From 470ec43636ff8bc9cff80bc9e63f2dbde7a1b1bb Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Fri, 27 Feb 2026 13:03:26 +0300 Subject: [PATCH] add extension hook system with priority ordering and tool interception (Plan 09) --- pkg/kit/extensions_bridge.go | 45 +++ pkg/kit/hooks.go | 260 +++++++++++++++++ pkg/kit/hooks_test.go | 548 +++++++++++++++++++++++++++++++++++ pkg/kit/kit.go | 114 ++++++-- pkg/kit/setup.go | 31 +- 5 files changed, 972 insertions(+), 26 deletions(-) create mode 100644 pkg/kit/extensions_bridge.go create mode 100644 pkg/kit/hooks.go create mode 100644 pkg/kit/hooks_test.go diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go new file mode 100644 index 00000000..8b34ff6a --- /dev/null +++ b/pkg/kit/extensions_bridge.go @@ -0,0 +1,45 @@ +package kit + +import "github.com/mark3labs/kit/internal/extensions" + +// bridgeExtensions registers extension event handlers as SDK hooks. This makes +// the existing extension system a consumer of the SDK hook API, proving the +// hook surface is production-ready. +// +// Phase 1 (this plan): bridge BeforeAgentStart and Input as BeforeTurn hooks. +// Tool-level events (ToolCall, ToolResult) are already handled by the extension +// tool wrapper (internal/extensions/wrapper.go) which composes underneath the +// SDK hook wrapper. +// +// Phase 2 (future): app.executeStep() migrates to SDK hooks exclusively. +// Phase 3 (future): extension runner emits SDK events/hooks natively. +func (m *Kit) bridgeExtensions(runner *extensions.Runner) { + // Extension Input → BeforeTurn hook (high priority, runs first). + // An Input handler with Action="transform" replaces the prompt text. + if runner.HasHandlers(extensions.Input) { + m.OnBeforeTurn(HookPriorityHigh, func(h BeforeTurnHook) *BeforeTurnResult { + result, _ := runner.Emit(extensions.InputEvent{Text: h.Prompt}) + if r, ok := result.(extensions.InputResult); ok { + if r.Action == "transform" { + return &BeforeTurnResult{Prompt: &r.Text} + } + } + return nil + }) + } + + // Extension BeforeAgentStart → BeforeTurn hook (normal priority). + // Can inject a system prompt prefix and/or context text. + if runner.HasHandlers(extensions.BeforeAgentStart) { + m.OnBeforeTurn(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult { + result, _ := runner.Emit(extensions.BeforeAgentStartEvent{Prompt: h.Prompt}) + if r, ok := result.(extensions.BeforeAgentStartResult); ok { + return &BeforeTurnResult{ + SystemPrompt: r.SystemPrompt, + InjectText: r.InjectText, + } + } + return nil + }) + } +} diff --git a/pkg/kit/hooks.go b/pkg/kit/hooks.go new file mode 100644 index 00000000..bc31c5ce --- /dev/null +++ b/pkg/kit/hooks.go @@ -0,0 +1,260 @@ +package kit + +import ( + "context" + "fmt" + "sort" + "sync" + + "charm.land/fantasy" +) + +// --------------------------------------------------------------------------- +// Priority +// --------------------------------------------------------------------------- + +// HookPriority controls execution order of hooks. Lower values run first. +type HookPriority int + +const ( + // HookPriorityHigh runs before normal hooks. + HookPriorityHigh HookPriority = 0 + // HookPriorityNormal is the default priority. + HookPriorityNormal HookPriority = 50 + // HookPriorityLow runs after normal hooks. + HookPriorityLow HookPriority = 100 +) + +// --------------------------------------------------------------------------- +// Hook input/result types +// --------------------------------------------------------------------------- + +// BeforeToolCallHook is the input for hooks that fire before a tool executes. +type BeforeToolCallHook struct { + ToolName string + ToolArgs string +} + +// BeforeToolCallResult controls whether the tool call proceeds. +type BeforeToolCallResult struct { + Block bool // true prevents the tool from running + Reason string // human-readable reason for blocking +} + +// AfterToolResultHook is the input for hooks that fire after a tool executes. +type AfterToolResultHook struct { + ToolName string + ToolArgs string + Result string + IsError bool +} + +// AfterToolResultResult can modify the tool's output before it reaches the LLM. +type AfterToolResultResult struct { + Result *string // non-nil overrides the result text + IsError *bool // non-nil overrides the error flag +} + +// BeforeTurnHook is the input for hooks that fire before a prompt turn. +type BeforeTurnHook struct { + Prompt string +} + +// BeforeTurnResult can modify the prompt, inject system messages, or add context. +type BeforeTurnResult struct { + Prompt *string // override prompt text in the user message + SystemPrompt *string // prepend a system message + InjectText *string // prepend a user context message +} + +// AfterTurnHook is the input for hooks that fire after a prompt turn completes. +type AfterTurnHook struct { + Response string + Error error +} + +// AfterTurnResult is a placeholder — after-turn hooks are observation-only. +type AfterTurnResult struct{} + +// --------------------------------------------------------------------------- +// Generic hook registry with priority ordering +// --------------------------------------------------------------------------- + +type hookEntry[In any, Out any] struct { + id int + priority HookPriority + handler func(In) *Out +} + +type hookRegistry[In any, Out any] struct { + mu sync.RWMutex + hooks []hookEntry[In, Out] + next int +} + +func newHookRegistry[In any, Out any]() *hookRegistry[In, Out] { + return &hookRegistry[In, Out]{} +} + +// register adds a hook with the given priority and returns an unregister +// function. Within the same priority, hooks run in registration order. +func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func() { + hr.mu.Lock() + id := hr.next + hr.next++ + hr.hooks = append(hr.hooks, hookEntry[In, Out]{id: id, priority: p, handler: h}) + // Stable sort preserves insertion order within the same priority. + sort.SliceStable(hr.hooks, func(i, j int) bool { + return hr.hooks[i].priority < hr.hooks[j].priority + }) + hr.mu.Unlock() + + return func() { + hr.mu.Lock() + defer hr.mu.Unlock() + for i, entry := range hr.hooks { + if entry.id == id { + hr.hooks = append(hr.hooks[:i], hr.hooks[i+1:]...) + return + } + } + } +} + +// run executes all hooks in priority order. The first non-nil result wins. +func (hr *hookRegistry[In, Out]) run(input In) *Out { + hr.mu.RLock() + snapshot := make([]hookEntry[In, Out], len(hr.hooks)) + copy(snapshot, hr.hooks) + hr.mu.RUnlock() + + for _, entry := range snapshot { + if result := entry.handler(input); result != nil { + return result + } + } + return nil +} + +// hasHooks returns true if any hooks are registered. +func (hr *hookRegistry[In, Out]) hasHooks() bool { + hr.mu.RLock() + defer hr.mu.RUnlock() + return len(hr.hooks) > 0 +} + +// --------------------------------------------------------------------------- +// Hook registration methods on Kit +// --------------------------------------------------------------------------- + +// OnBeforeToolCall registers a hook that fires before each tool execution. +// Return a non-nil BeforeToolCallResult with Block=true to prevent the tool +// from running. Hooks execute in priority order; the first non-nil result wins. +// Returns an unregister function. +func (m *Kit) OnBeforeToolCall(p HookPriority, h func(BeforeToolCallHook) *BeforeToolCallResult) func() { + return m.beforeToolCall.register(p, h) +} + +// OnAfterToolResult registers a hook that fires after each tool execution. +// Return a non-nil AfterToolResultResult to modify the tool's output before +// it reaches the LLM. Hooks execute in priority order; the first non-nil +// result wins. Returns an unregister function. +func (m *Kit) OnAfterToolResult(p HookPriority, h func(AfterToolResultHook) *AfterToolResultResult) func() { + return m.afterToolResult.register(p, h) +} + +// OnBeforeTurn registers a hook that fires before each prompt turn. Return +// a non-nil BeforeTurnResult to modify the prompt, inject a system message, +// or prepend context. Hooks execute in priority order; the first non-nil +// result wins. Returns an unregister function. +func (m *Kit) OnBeforeTurn(p HookPriority, h func(BeforeTurnHook) *BeforeTurnResult) func() { + return m.beforeTurn.register(p, h) +} + +// OnAfterTurn registers a hook that fires after each prompt turn completes. +// This is observation-only — the handler cannot modify the response. Hooks +// execute in priority order. Returns an unregister function. +func (m *Kit) OnAfterTurn(p HookPriority, h func(AfterTurnHook)) func() { + return m.afterTurn.register(p, func(input AfterTurnHook) *AfterTurnResult { + h(input) + return nil + }) +} + +// --------------------------------------------------------------------------- +// Tool wrapping via hooks +// --------------------------------------------------------------------------- + +// hookedTool wraps a fantasy.AgentTool to run BeforeToolCall and +// AfterToolResult hooks around each execution. The registries are referenced +// by pointer so hooks added after agent creation are still invoked. +type hookedTool struct { + inner fantasy.AgentTool + beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult] + afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult] +} + +func (h *hookedTool) Info() fantasy.ToolInfo { return h.inner.Info() } +func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions { return h.inner.ProviderOptions() } +func (h *hookedTool) SetProviderOptions(o fantasy.ProviderOptions) { h.inner.SetProviderOptions(o) } + +func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + toolName := h.inner.Info().Name + + // 1. BeforeToolCall — can block execution. + if h.beforeToolCall.hasHooks() { + if result := h.beforeToolCall.run(BeforeToolCallHook{ + ToolName: toolName, + ToolArgs: call.Input, + }); result != nil && result.Block { + reason := result.Reason + if reason == "" { + reason = "blocked by hook" + } + return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)), + fmt.Errorf("tool blocked by hook: %s", reason) + } + } + + // 2. Execute actual tool. + resp, err := h.inner.Run(ctx, call) + + // 3. AfterToolResult — can modify output. + if h.afterToolResult.hasHooks() { + if result := h.afterToolResult.run(AfterToolResultHook{ + ToolName: toolName, + ToolArgs: call.Input, + Result: resp.Content, + IsError: err != nil || resp.IsError, + }); result != nil { + if result.Result != nil { + resp.Content = *result.Result + } + if result.IsError != nil { + resp.IsError = *result.IsError + } + } + } + + return resp, err +} + +// hookToolWrapper creates a tool wrapper function that applies hook-based +// tool interception. The wrapper references the hook registries directly, +// so hooks registered after agent creation are still called at execution time. +func hookToolWrapper( + beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult], + afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult], +) func([]fantasy.AgentTool) []fantasy.AgentTool { + return func(tools []fantasy.AgentTool) []fantasy.AgentTool { + wrapped := make([]fantasy.AgentTool, len(tools)) + for i, tool := range tools { + wrapped[i] = &hookedTool{ + inner: tool, + beforeToolCall: beforeToolCall, + afterToolResult: afterToolResult, + } + } + return wrapped + } +} diff --git a/pkg/kit/hooks_test.go b/pkg/kit/hooks_test.go new file mode 100644 index 00000000..db36251f --- /dev/null +++ b/pkg/kit/hooks_test.go @@ -0,0 +1,548 @@ +package kit + +import ( + "context" + "fmt" + "sync" + "testing" + + "charm.land/fantasy" +) + +// --------------------------------------------------------------------------- +// Hook registry tests +// --------------------------------------------------------------------------- + +func TestHookRegistry_RegisterAndRun(t *testing.T) { + hr := newHookRegistry[string, string]() + + hr.register(HookPriorityNormal, func(input string) *string { + result := "handled: " + input + return &result + }) + + got := hr.run("hello") + if got == nil { + t.Fatal("expected non-nil result") + } + if *got != "handled: hello" { + t.Errorf("expected 'handled: hello', got %q", *got) + } +} + +func TestHookRegistry_FirstNonNilWins(t *testing.T) { + hr := newHookRegistry[string, string]() + + // First hook returns nil. + hr.register(HookPriorityNormal, func(_ string) *string { + return nil + }) + // Second hook returns a result. + hr.register(HookPriorityNormal, func(input string) *string { + result := "second: " + input + return &result + }) + // Third hook would also return, but should never be reached. + hr.register(HookPriorityNormal, func(input string) *string { + result := "third: " + input + return &result + }) + + got := hr.run("test") + if got == nil { + t.Fatal("expected non-nil result") + } + if *got != "second: test" { + t.Errorf("expected 'second: test', got %q", *got) + } +} + +func TestHookRegistry_PriorityOrdering(t *testing.T) { + hr := newHookRegistry[string, string]() + + // Register in reverse priority order. + hr.register(HookPriorityLow, func(_ string) *string { + result := "low" + return &result + }) + hr.register(HookPriorityHigh, func(_ string) *string { + result := "high" + return &result + }) + hr.register(HookPriorityNormal, func(_ string) *string { + result := "normal" + return &result + }) + + got := hr.run("x") + if got == nil { + t.Fatal("expected non-nil result") + } + if *got != "high" { + t.Errorf("expected 'high' (priority 0 runs first), got %q", *got) + } +} + +func TestHookRegistry_SamePriorityPreservesOrder(t *testing.T) { + hr := newHookRegistry[int, string]() + + hr.register(HookPriorityNormal, func(n int) *string { + result := "first" + return &result + }) + hr.register(HookPriorityNormal, func(n int) *string { + result := "second" + return &result + }) + + got := hr.run(0) + if got == nil || *got != "first" { + t.Errorf("expected 'first' (insertion order), got %v", got) + } +} + +func TestHookRegistry_Unregister(t *testing.T) { + hr := newHookRegistry[string, string]() + + unregister := hr.register(HookPriorityNormal, func(input string) *string { + result := "should be gone" + return &result + }) + + if !hr.hasHooks() { + t.Fatal("expected hasHooks to be true after registration") + } + + unregister() + + if hr.hasHooks() { + t.Fatal("expected hasHooks to be false after unregister") + } + + got := hr.run("test") + if got != nil { + t.Errorf("expected nil after unregister, got %v", *got) + } +} + +func TestHookRegistry_NoHooksReturnsNil(t *testing.T) { + hr := newHookRegistry[string, string]() + + got := hr.run("test") + if got != nil { + t.Errorf("expected nil when no hooks registered, got %v", *got) + } +} + +func TestHookRegistry_HasHooks(t *testing.T) { + hr := newHookRegistry[string, string]() + + if hr.hasHooks() { + t.Error("expected hasHooks to be false initially") + } + + unsub := hr.register(HookPriorityNormal, func(_ string) *string { return nil }) + if !hr.hasHooks() { + t.Error("expected hasHooks to be true after registration") + } + + unsub() + if hr.hasHooks() { + t.Error("expected hasHooks to be false after unregister") + } +} + +func TestHookRegistry_ConcurrentAccess(t *testing.T) { + hr := newHookRegistry[int, int]() + + var wg sync.WaitGroup + const n = 100 + + // Concurrent registrations. + for range n { + wg.Go(func() { + unsub := hr.register(HookPriorityNormal, func(x int) *int { + result := x * 2 + return &result + }) + // Immediately unregister half the time. + unsub() + }) + } + + // Concurrent runs while registrations are happening. + for range n { + wg.Go(func() { + hr.run(42) + }) + } + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// hookedTool tests +// --------------------------------------------------------------------------- + +// mockAgentTool implements fantasy.AgentTool for testing. +type mockAgentTool struct { + name string + runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) + popts fantasy.ProviderOptions +} + +func (m *mockAgentTool) Info() fantasy.ToolInfo { + return fantasy.ToolInfo{Name: m.name, Description: "mock tool"} +} +func (m *mockAgentTool) ProviderOptions() fantasy.ProviderOptions { return m.popts } +func (m *mockAgentTool) SetProviderOptions(o fantasy.ProviderOptions) { m.popts = o } +func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + if m.runFn != nil { + return m.runFn(ctx, call) + } + return fantasy.NewTextResponse("default output"), nil +} + +func TestHookedTool_Passthrough(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + mock := &mockAgentTool{ + name: "test_tool", + runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("hello world"), nil + }, + } + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + + resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "hello world" { + t.Errorf("expected 'hello world', got %q", resp.Content) + } +} + +func TestHookedTool_BeforeToolCallBlock(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + toolRan := false + mock := &mockAgentTool{ + name: "dangerous_tool", + runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + toolRan = true + return fantasy.NewTextResponse("should not run"), nil + }, + } + + before.register(HookPriorityHigh, func(h BeforeToolCallHook) *BeforeToolCallResult { + if h.ToolName == "dangerous_tool" { + return &BeforeToolCallResult{Block: true, Reason: "too dangerous"} + } + return nil + }) + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + + resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"}) + if err == nil { + t.Fatal("expected error from blocked tool") + } + if toolRan { + t.Error("tool should not have run when blocked") + } + if resp.Content != "Error: too dangerous" { + t.Errorf("expected block error message, got %q", resp.Content) + } +} + +func TestHookedTool_BeforeToolCallBlockDefaultReason(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + mock := &mockAgentTool{name: "tool"} + before.register(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult { + return &BeforeToolCallResult{Block: true} + }) + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + resp, _ := ht.Run(context.Background(), fantasy.ToolCall{}) + if resp.Content != "Error: blocked by hook" { + t.Errorf("expected default block reason, got %q", resp.Content) + } +} + +func TestHookedTool_AfterToolResultModify(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + mock := &mockAgentTool{ + name: "tool", + runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("secret data"), nil + }, + } + + after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult { + redacted := "[REDACTED]" + return &AfterToolResultResult{Result: &redacted} + }) + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "[REDACTED]" { + t.Errorf("expected '[REDACTED]', got %q", resp.Content) + } +} + +func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + mock := &mockAgentTool{ + name: "tool", + runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("ok"), nil + }, + } + + isErr := true + after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult { + return &AfterToolResultResult{IsError: &isErr} + }) + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + resp, err := ht.Run(context.Background(), fantasy.ToolCall{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !resp.IsError { + t.Error("expected IsError to be overridden to true") + } +} + +func TestHookedTool_HookReceivesToolInfo(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + mock := &mockAgentTool{ + name: "my_tool", + runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("result"), nil + }, + } + + var capturedBefore BeforeToolCallHook + var capturedAfter AfterToolResultHook + + before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult { + capturedBefore = h + return nil // don't block + }) + after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult { + capturedAfter = h + return nil // don't modify + }) + + ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after} + _, _ = ht.Run(context.Background(), fantasy.ToolCall{Input: `{"key":"value"}`}) + + if capturedBefore.ToolName != "my_tool" { + t.Errorf("BeforeToolCall: expected tool name 'my_tool', got %q", capturedBefore.ToolName) + } + if capturedBefore.ToolArgs != `{"key":"value"}` { + t.Errorf("BeforeToolCall: expected args, got %q", capturedBefore.ToolArgs) + } + if capturedAfter.ToolName != "my_tool" { + t.Errorf("AfterToolResult: expected tool name 'my_tool', got %q", capturedAfter.ToolName) + } + if capturedAfter.Result != "result" { + t.Errorf("AfterToolResult: expected result 'result', got %q", capturedAfter.Result) + } +} + +func TestHookedTool_InfoDelegates(t *testing.T) { + mock := &mockAgentTool{name: "delegate_test"} + ht := &hookedTool{ + inner: mock, + beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](), + afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](), + } + + if ht.Info().Name != "delegate_test" { + t.Errorf("expected Info() to delegate to inner tool") + } +} + +// --------------------------------------------------------------------------- +// hookToolWrapper tests +// --------------------------------------------------------------------------- + +func TestHookToolWrapper(t *testing.T) { + before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + + wrapper := hookToolWrapper(before, after) + + tools := []fantasy.AgentTool{ + &mockAgentTool{name: "tool_a"}, + &mockAgentTool{name: "tool_b"}, + } + + wrapped := wrapper(tools) + if len(wrapped) != 2 { + t.Fatalf("expected 2 wrapped tools, got %d", len(wrapped)) + } + + // Verify tools are wrapped (different pointer than original). + for i, wt := range wrapped { + if _, ok := wt.(*hookedTool); !ok { + t.Errorf("tool %d: expected *hookedTool, got %T", i, wt) + } + if wt.Info().Name != tools[i].Info().Name { + t.Errorf("tool %d: expected name %q, got %q", i, tools[i].Info().Name, wt.Info().Name) + } + } + + // Hooks registered after wrapping should still work. + var blocked bool + before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult { + blocked = true + return &BeforeToolCallResult{Block: true, Reason: "late hook"} + }) + + _, err := wrapped[0].Run(context.Background(), fantasy.ToolCall{}) + if err == nil { + t.Error("expected error from late-registered blocking hook") + } + if !blocked { + t.Error("late-registered hook should have been called") + } +} + +// --------------------------------------------------------------------------- +// Hook type tests (BeforeTurn, AfterTurn) +// --------------------------------------------------------------------------- + +func TestBeforeTurnHook_PromptOverride(t *testing.T) { + hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]() + + override := "modified prompt" + hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult { + return &BeforeTurnResult{Prompt: &override} + }) + + result := hr.run(BeforeTurnHook{Prompt: "original"}) + if result == nil { + t.Fatal("expected non-nil result") + } + if result.Prompt == nil || *result.Prompt != "modified prompt" { + t.Errorf("expected prompt override, got %v", result.Prompt) + } +} + +func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) { + hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]() + + sysPr := "be concise" + ctx := "project context here" + hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult { + return &BeforeTurnResult{ + SystemPrompt: &sysPr, + InjectText: &ctx, + } + }) + + result := hr.run(BeforeTurnHook{Prompt: "hello"}) + if result == nil { + t.Fatal("expected non-nil result") + } + if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" { + t.Errorf("expected system prompt injection") + } + if result.InjectText == nil || *result.InjectText != "project context here" { + t.Errorf("expected context injection") + } +} + +func TestAfterTurnHook_ObservationOnly(t *testing.T) { + hr := newHookRegistry[AfterTurnHook, AfterTurnResult]() + + var captured AfterTurnHook + hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult { + captured = h + return nil // observation only + }) + + hr.run(AfterTurnHook{Response: "agent replied"}) + if captured.Response != "agent replied" { + t.Errorf("expected captured response, got %q", captured.Response) + } +} + +func TestAfterTurnHook_WithError(t *testing.T) { + hr := newHookRegistry[AfterTurnHook, AfterTurnResult]() + + var captured AfterTurnHook + hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult { + captured = h + return nil + }) + + testErr := fmt.Errorf("generation failed") + hr.run(AfterTurnHook{Error: testErr}) + if captured.Error != testErr { + t.Errorf("expected captured error, got %v", captured.Error) + } +} + +// --------------------------------------------------------------------------- +// Priority constants sanity check +// --------------------------------------------------------------------------- + +func TestHookPriorityOrdering(t *testing.T) { + if HookPriorityHigh >= HookPriorityNormal { + t.Error("HookPriorityHigh should be less than HookPriorityNormal") + } + if HookPriorityNormal >= HookPriorityLow { + t.Error("HookPriorityNormal should be less than HookPriorityLow") + } +} + +// --------------------------------------------------------------------------- +// Kit method compilation tests (verify API surface exists) +// --------------------------------------------------------------------------- + +func TestKit_HookMethodsExist(t *testing.T) { + k := &Kit{ + events: newEventBus(), + beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](), + afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](), + beforeTurn: newHookRegistry[BeforeTurnHook, BeforeTurnResult](), + afterTurn: newHookRegistry[AfterTurnHook, AfterTurnResult](), + } + + // Verify all hook registration methods return unsubscribe functions. + u1 := k.OnBeforeToolCall(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult { + return nil + }) + u2 := k.OnAfterToolResult(HookPriorityNormal, func(_ AfterToolResultHook) *AfterToolResultResult { + return nil + }) + u3 := k.OnBeforeTurn(HookPriorityNormal, func(_ BeforeTurnHook) *BeforeTurnResult { + return nil + }) + u4 := k.OnAfterTurn(HookPriorityNormal, func(_ AfterTurnHook) {}) + + // All should be callable. + u1() + u2() + u3() + u4() +} diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 6322053d..1fe1f6a1 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -27,6 +27,12 @@ type Kit struct { autoCompact bool compactionOpts *CompactionOptions skills []*skills.Skill + + // Hook registries — interception layer (see hooks.go). + beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult] + afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult] + beforeTurn *hookRegistry[BeforeTurnHook, BeforeTurnResult] + afterTurn *hookRegistry[AfterTurnHook, AfterTurnResult] } // Subscribe registers an EventListener that will be called for every lifecycle @@ -47,6 +53,7 @@ type Options struct { Streaming bool // Enable streaming (default from config) Quiet bool // Suppress debug output Tools []Tool // Custom tool set. If empty, AllTools() is used. + ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools. // Session configuration SessionDir string // Base directory for session discovery (default: cwd) @@ -148,11 +155,21 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { return nil, fmt.Errorf("failed to load MCP config: %w", err) } - // Create agent using shared setup. + // Pre-create hook registries so the tool wrapper can reference them. + // Hooks registered after New() returns are still invoked because the + // wrapper captures the registries by pointer. + beforeToolCall := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]() + afterToolResult := newHookRegistry[AfterToolResultHook, AfterToolResultResult]() + beforeTurn := newHookRegistry[BeforeTurnHook, BeforeTurnResult]() + afterTurn := newHookRegistry[AfterTurnHook, AfterTurnResult]() + + // Create agent using shared setup with the hook tool wrapper. agentResult, err := SetupAgent(ctx, AgentSetupOptions{ - MCPConfig: mcpConfig, - Quiet: opts.Quiet, - CoreTools: opts.Tools, + MCPConfig: mcpConfig, + Quiet: opts.Quiet, + CoreTools: opts.Tools, + ExtraTools: opts.ExtraTools, + ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult), }) if err != nil { return nil, err @@ -165,15 +182,26 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { return nil, fmt.Errorf("failed to initialize session: %w", err) } - return &Kit{ - agent: agentResult.Agent, - treeSession: treeSession, - modelString: viper.GetString("model"), - events: newEventBus(), - autoCompact: opts.AutoCompact, - compactionOpts: opts.CompactionOptions, - skills: loadedSkills, - }, nil + k := &Kit{ + agent: agentResult.Agent, + treeSession: treeSession, + modelString: viper.GetString("model"), + events: newEventBus(), + autoCompact: opts.AutoCompact, + compactionOpts: opts.CompactionOptions, + skills: loadedSkills, + beforeToolCall: beforeToolCall, + afterToolResult: afterToolResult, + beforeTurn: beforeTurn, + afterTurn: afterTurn, + } + + // Bridge extension events to SDK hooks. + if agentResult.ExtRunner != nil { + k.bridgeExtensions(agentResult.ExtRunner) + } + + return k, nil } // GetSkills returns the skills loaded during initialisation. @@ -277,15 +305,44 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent. } // runTurn is the shared lifecycle for every prompt mode: -// 1. Persist pre-generation messages to the tree session. -// 2. Build context from the tree (walks leaf-to-root for current branch). -// 3. Emit turn/message start events. -// 4. Run generation. -// 5. Emit turn/message end events. -// 6. Persist post-generation messages (tool calls, results, assistant). +// 1. Run BeforeTurn hooks (can modify prompt, inject messages). +// 2. Persist pre-generation messages to the tree session. +// 3. Build context from the tree (walks leaf-to-root for current branch). +// 4. Emit turn/message start events. +// 5. Run generation. +// 6. Emit turn/message end events. +// 7. Persist post-generation messages (tool calls, results, assistant). +// 8. Run AfterTurn hooks. // // promptLabel is the human-readable label emitted in TurnStartEvent.Prompt. -func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fantasy.Message) (string, error) { +// prompt is the raw user text passed to BeforeTurn hooks. +func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (string, error) { + // Run BeforeTurn hooks — can modify the prompt, inject system/context messages. + if m.beforeTurn.hasHooks() { + if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil { + // Override prompt text in the last user message. + if hookResult.Prompt != nil { + for i := len(preMessages) - 1; i >= 0; i-- { + if preMessages[i].Role == fantasy.MessageRoleUser { + preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt) + break + } + } + } + // Inject messages before the original preMessages. + var injected []fantasy.Message + if hookResult.SystemPrompt != nil { + injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt)) + } + if hookResult.InjectText != nil { + injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText)) + } + if len(injected) > 0 { + preMessages = append(injected, preMessages...) + } + } + } + // Persist pre-generation messages to tree session. for _, msg := range preMessages { _, _ = m.treeSession.AppendFantasyMessage(msg) @@ -306,6 +363,10 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan result, err := m.generate(ctx, messages) if err != nil { m.events.emit(TurnEndEvent{Error: err}) + // Run AfterTurn hooks even on error. + if m.afterTurn.hasHooks() { + m.afterTurn.run(AfterTurnHook{Error: err}) + } return "", err } @@ -321,6 +382,11 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan } } + // Run AfterTurn hooks. + if m.afterTurn.hasHooks() { + m.afterTurn.run(AfterTurnHook{Response: responseText}) + } + return responseText, nil } @@ -333,7 +399,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan // automatically maintained in the tree session. Lifecycle events are emitted // to all registered subscribers. Returns an error if generation fails. func (m *Kit) Prompt(ctx context.Context, message string) (string, error) { - return m.runTurn(ctx, message, []fantasy.Message{ + return m.runTurn(ctx, message, message, []fantasy.Message{ fantasy.NewUserMessage(message), }) } @@ -346,7 +412,7 @@ func (m *Kit) Prompt(ctx context.Context, message string) (string, error) { // a synthetic user message so the agent acknowledges and follows the directive. // Both messages are persisted to the session. func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) { - return m.runTurn(ctx, "[steer] "+instruction, []fantasy.Message{ + return m.runTurn(ctx, "[steer] "+instruction, instruction, []fantasy.Message{ fantasy.NewSystemMessage(instruction), fantasy.NewUserMessage("Please acknowledge and follow the above instruction."), }) @@ -367,7 +433,7 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) { text = "Continue." } - return m.runTurn(ctx, "[follow-up]", []fantasy.Message{ + return m.runTurn(ctx, "[follow-up]", text, []fantasy.Message{ fantasy.NewUserMessage(text), }) } @@ -390,7 +456,7 @@ func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOpti } preMessages = append(preMessages, fantasy.NewUserMessage(msg)) - return m.runTurn(ctx, msg, preMessages) + return m.runTurn(ctx, msg, msg, preMessages) } // PromptWithCallbacks sends a message with callbacks for monitoring tool diff --git a/pkg/kit/setup.go b/pkg/kit/setup.go index 49890f04..17e9ad8a 100644 --- a/pkg/kit/setup.go +++ b/pkg/kit/setup.go @@ -31,6 +31,13 @@ type AgentSetupOptions struct { // CoreTools overrides the default core tool set. If empty, core.AllTools() // is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir). CoreTools []fantasy.AgentTool + // ExtraTools are additional tools added alongside core, MCP, and extension + // tools. They do not replace the defaults — they extend them. + ExtraTools []fantasy.AgentTool + // ToolWrapper is an optional function that wraps tools after extension + // wrapping. Used by the SDK hook system. Both wrappers compose: + // extension wrapper runs first (inner), then this wrapper (outer). + ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool } // AgentSetupResult bundles the created agent and any debug logger so the caller @@ -106,6 +113,26 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, } } + // Compose tool wrappers: extension wrapper (inner) + caller wrapper (outer). + toolWrapper := extCreationOpts.toolWrapper + if opts.ToolWrapper != nil { + if toolWrapper != nil { + inner := toolWrapper + outer := opts.ToolWrapper + toolWrapper = func(t []fantasy.AgentTool) []fantasy.AgentTool { + return outer(inner(t)) + } + } else { + toolWrapper = opts.ToolWrapper + } + } + + // Merge extra tools: extension tools + caller extra tools. + extraTools := extCreationOpts.extraTools + if len(opts.ExtraTools) > 0 { + extraTools = append(extraTools, opts.ExtraTools...) + } + a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ ModelConfig: modelConfig, MCPConfig: opts.MCPConfig, @@ -117,8 +144,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, SpinnerFunc: opts.SpinnerFunc, DebugLogger: debugLogger, CoreTools: opts.CoreTools, - ToolWrapper: extCreationOpts.toolWrapper, - ExtraTools: extCreationOpts.extraTools, + ToolWrapper: toolWrapper, + ExtraTools: extraTools, }) if err != nil { return nil, fmt.Errorf("failed to create agent: %w", err)