add extension hook system with priority ordering and tool interception (Plan 09)

This commit is contained in:
Ed Zynda
2026-02-27 13:03:26 +03:00
parent 6c069907dd
commit 470ec43636
5 changed files with 972 additions and 26 deletions
+45
View File
@@ -0,0 +1,45 @@
package kit
import "github.com/mark3labs/kit/internal/extensions"
// bridgeExtensions registers extension event handlers as SDK hooks. This makes
// the existing extension system a consumer of the SDK hook API, proving the
// hook surface is production-ready.
//
// Phase 1 (this plan): bridge BeforeAgentStart and Input as BeforeTurn hooks.
// Tool-level events (ToolCall, ToolResult) are already handled by the extension
// tool wrapper (internal/extensions/wrapper.go) which composes underneath the
// SDK hook wrapper.
//
// Phase 2 (future): app.executeStep() migrates to SDK hooks exclusively.
// Phase 3 (future): extension runner emits SDK events/hooks natively.
func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
// Extension Input → BeforeTurn hook (high priority, runs first).
// An Input handler with Action="transform" replaces the prompt text.
if runner.HasHandlers(extensions.Input) {
m.OnBeforeTurn(HookPriorityHigh, func(h BeforeTurnHook) *BeforeTurnResult {
result, _ := runner.Emit(extensions.InputEvent{Text: h.Prompt})
if r, ok := result.(extensions.InputResult); ok {
if r.Action == "transform" {
return &BeforeTurnResult{Prompt: &r.Text}
}
}
return nil
})
}
// Extension BeforeAgentStart → BeforeTurn hook (normal priority).
// Can inject a system prompt prefix and/or context text.
if runner.HasHandlers(extensions.BeforeAgentStart) {
m.OnBeforeTurn(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
result, _ := runner.Emit(extensions.BeforeAgentStartEvent{Prompt: h.Prompt})
if r, ok := result.(extensions.BeforeAgentStartResult); ok {
return &BeforeTurnResult{
SystemPrompt: r.SystemPrompt,
InjectText: r.InjectText,
}
}
return nil
})
}
}
+260
View File
@@ -0,0 +1,260 @@
package kit
import (
"context"
"fmt"
"sort"
"sync"
"charm.land/fantasy"
)
// ---------------------------------------------------------------------------
// Priority
// ---------------------------------------------------------------------------
// HookPriority controls execution order of hooks. Lower values run first.
type HookPriority int
const (
// HookPriorityHigh runs before normal hooks.
HookPriorityHigh HookPriority = 0
// HookPriorityNormal is the default priority.
HookPriorityNormal HookPriority = 50
// HookPriorityLow runs after normal hooks.
HookPriorityLow HookPriority = 100
)
// ---------------------------------------------------------------------------
// Hook input/result types
// ---------------------------------------------------------------------------
// BeforeToolCallHook is the input for hooks that fire before a tool executes.
type BeforeToolCallHook struct {
ToolName string
ToolArgs string
}
// BeforeToolCallResult controls whether the tool call proceeds.
type BeforeToolCallResult struct {
Block bool // true prevents the tool from running
Reason string // human-readable reason for blocking
}
// AfterToolResultHook is the input for hooks that fire after a tool executes.
type AfterToolResultHook struct {
ToolName string
ToolArgs string
Result string
IsError bool
}
// AfterToolResultResult can modify the tool's output before it reaches the LLM.
type AfterToolResultResult struct {
Result *string // non-nil overrides the result text
IsError *bool // non-nil overrides the error flag
}
// BeforeTurnHook is the input for hooks that fire before a prompt turn.
type BeforeTurnHook struct {
Prompt string
}
// BeforeTurnResult can modify the prompt, inject system messages, or add context.
type BeforeTurnResult struct {
Prompt *string // override prompt text in the user message
SystemPrompt *string // prepend a system message
InjectText *string // prepend a user context message
}
// AfterTurnHook is the input for hooks that fire after a prompt turn completes.
type AfterTurnHook struct {
Response string
Error error
}
// AfterTurnResult is a placeholder — after-turn hooks are observation-only.
type AfterTurnResult struct{}
// ---------------------------------------------------------------------------
// Generic hook registry with priority ordering
// ---------------------------------------------------------------------------
type hookEntry[In any, Out any] struct {
id int
priority HookPriority
handler func(In) *Out
}
type hookRegistry[In any, Out any] struct {
mu sync.RWMutex
hooks []hookEntry[In, Out]
next int
}
func newHookRegistry[In any, Out any]() *hookRegistry[In, Out] {
return &hookRegistry[In, Out]{}
}
// register adds a hook with the given priority and returns an unregister
// function. Within the same priority, hooks run in registration order.
func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func() {
hr.mu.Lock()
id := hr.next
hr.next++
hr.hooks = append(hr.hooks, hookEntry[In, Out]{id: id, priority: p, handler: h})
// Stable sort preserves insertion order within the same priority.
sort.SliceStable(hr.hooks, func(i, j int) bool {
return hr.hooks[i].priority < hr.hooks[j].priority
})
hr.mu.Unlock()
return func() {
hr.mu.Lock()
defer hr.mu.Unlock()
for i, entry := range hr.hooks {
if entry.id == id {
hr.hooks = append(hr.hooks[:i], hr.hooks[i+1:]...)
return
}
}
}
}
// run executes all hooks in priority order. The first non-nil result wins.
func (hr *hookRegistry[In, Out]) run(input In) *Out {
hr.mu.RLock()
snapshot := make([]hookEntry[In, Out], len(hr.hooks))
copy(snapshot, hr.hooks)
hr.mu.RUnlock()
for _, entry := range snapshot {
if result := entry.handler(input); result != nil {
return result
}
}
return nil
}
// hasHooks returns true if any hooks are registered.
func (hr *hookRegistry[In, Out]) hasHooks() bool {
hr.mu.RLock()
defer hr.mu.RUnlock()
return len(hr.hooks) > 0
}
// ---------------------------------------------------------------------------
// Hook registration methods on Kit
// ---------------------------------------------------------------------------
// OnBeforeToolCall registers a hook that fires before each tool execution.
// Return a non-nil BeforeToolCallResult with Block=true to prevent the tool
// from running. Hooks execute in priority order; the first non-nil result wins.
// Returns an unregister function.
func (m *Kit) OnBeforeToolCall(p HookPriority, h func(BeforeToolCallHook) *BeforeToolCallResult) func() {
return m.beforeToolCall.register(p, h)
}
// OnAfterToolResult registers a hook that fires after each tool execution.
// Return a non-nil AfterToolResultResult to modify the tool's output before
// it reaches the LLM. Hooks execute in priority order; the first non-nil
// result wins. Returns an unregister function.
func (m *Kit) OnAfterToolResult(p HookPriority, h func(AfterToolResultHook) *AfterToolResultResult) func() {
return m.afterToolResult.register(p, h)
}
// OnBeforeTurn registers a hook that fires before each prompt turn. Return
// a non-nil BeforeTurnResult to modify the prompt, inject a system message,
// or prepend context. Hooks execute in priority order; the first non-nil
// result wins. Returns an unregister function.
func (m *Kit) OnBeforeTurn(p HookPriority, h func(BeforeTurnHook) *BeforeTurnResult) func() {
return m.beforeTurn.register(p, h)
}
// OnAfterTurn registers a hook that fires after each prompt turn completes.
// This is observation-only — the handler cannot modify the response. Hooks
// execute in priority order. Returns an unregister function.
func (m *Kit) OnAfterTurn(p HookPriority, h func(AfterTurnHook)) func() {
return m.afterTurn.register(p, func(input AfterTurnHook) *AfterTurnResult {
h(input)
return nil
})
}
// ---------------------------------------------------------------------------
// Tool wrapping via hooks
// ---------------------------------------------------------------------------
// hookedTool wraps a fantasy.AgentTool to run BeforeToolCall and
// AfterToolResult hooks around each execution. The registries are referenced
// by pointer so hooks added after agent creation are still invoked.
type hookedTool struct {
inner fantasy.AgentTool
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
}
func (h *hookedTool) Info() fantasy.ToolInfo { return h.inner.Info() }
func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions { return h.inner.ProviderOptions() }
func (h *hookedTool) SetProviderOptions(o fantasy.ProviderOptions) { h.inner.SetProviderOptions(o) }
func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolName := h.inner.Info().Name
// 1. BeforeToolCall — can block execution.
if h.beforeToolCall.hasHooks() {
if result := h.beforeToolCall.run(BeforeToolCallHook{
ToolName: toolName,
ToolArgs: call.Input,
}); result != nil && result.Block {
reason := result.Reason
if reason == "" {
reason = "blocked by hook"
}
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
fmt.Errorf("tool blocked by hook: %s", reason)
}
}
// 2. Execute actual tool.
resp, err := h.inner.Run(ctx, call)
// 3. AfterToolResult — can modify output.
if h.afterToolResult.hasHooks() {
if result := h.afterToolResult.run(AfterToolResultHook{
ToolName: toolName,
ToolArgs: call.Input,
Result: resp.Content,
IsError: err != nil || resp.IsError,
}); result != nil {
if result.Result != nil {
resp.Content = *result.Result
}
if result.IsError != nil {
resp.IsError = *result.IsError
}
}
}
return resp, err
}
// hookToolWrapper creates a tool wrapper function that applies hook-based
// tool interception. The wrapper references the hook registries directly,
// so hooks registered after agent creation are still called at execution time.
func hookToolWrapper(
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult],
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult],
) func([]fantasy.AgentTool) []fantasy.AgentTool {
return func(tools []fantasy.AgentTool) []fantasy.AgentTool {
wrapped := make([]fantasy.AgentTool, len(tools))
for i, tool := range tools {
wrapped[i] = &hookedTool{
inner: tool,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
}
}
return wrapped
}
}
+548
View File
@@ -0,0 +1,548 @@
package kit
import (
"context"
"fmt"
"sync"
"testing"
"charm.land/fantasy"
)
// ---------------------------------------------------------------------------
// Hook registry tests
// ---------------------------------------------------------------------------
func TestHookRegistry_RegisterAndRun(t *testing.T) {
hr := newHookRegistry[string, string]()
hr.register(HookPriorityNormal, func(input string) *string {
result := "handled: " + input
return &result
})
got := hr.run("hello")
if got == nil {
t.Fatal("expected non-nil result")
}
if *got != "handled: hello" {
t.Errorf("expected 'handled: hello', got %q", *got)
}
}
func TestHookRegistry_FirstNonNilWins(t *testing.T) {
hr := newHookRegistry[string, string]()
// First hook returns nil.
hr.register(HookPriorityNormal, func(_ string) *string {
return nil
})
// Second hook returns a result.
hr.register(HookPriorityNormal, func(input string) *string {
result := "second: " + input
return &result
})
// Third hook would also return, but should never be reached.
hr.register(HookPriorityNormal, func(input string) *string {
result := "third: " + input
return &result
})
got := hr.run("test")
if got == nil {
t.Fatal("expected non-nil result")
}
if *got != "second: test" {
t.Errorf("expected 'second: test', got %q", *got)
}
}
func TestHookRegistry_PriorityOrdering(t *testing.T) {
hr := newHookRegistry[string, string]()
// Register in reverse priority order.
hr.register(HookPriorityLow, func(_ string) *string {
result := "low"
return &result
})
hr.register(HookPriorityHigh, func(_ string) *string {
result := "high"
return &result
})
hr.register(HookPriorityNormal, func(_ string) *string {
result := "normal"
return &result
})
got := hr.run("x")
if got == nil {
t.Fatal("expected non-nil result")
}
if *got != "high" {
t.Errorf("expected 'high' (priority 0 runs first), got %q", *got)
}
}
func TestHookRegistry_SamePriorityPreservesOrder(t *testing.T) {
hr := newHookRegistry[int, string]()
hr.register(HookPriorityNormal, func(n int) *string {
result := "first"
return &result
})
hr.register(HookPriorityNormal, func(n int) *string {
result := "second"
return &result
})
got := hr.run(0)
if got == nil || *got != "first" {
t.Errorf("expected 'first' (insertion order), got %v", got)
}
}
func TestHookRegistry_Unregister(t *testing.T) {
hr := newHookRegistry[string, string]()
unregister := hr.register(HookPriorityNormal, func(input string) *string {
result := "should be gone"
return &result
})
if !hr.hasHooks() {
t.Fatal("expected hasHooks to be true after registration")
}
unregister()
if hr.hasHooks() {
t.Fatal("expected hasHooks to be false after unregister")
}
got := hr.run("test")
if got != nil {
t.Errorf("expected nil after unregister, got %v", *got)
}
}
func TestHookRegistry_NoHooksReturnsNil(t *testing.T) {
hr := newHookRegistry[string, string]()
got := hr.run("test")
if got != nil {
t.Errorf("expected nil when no hooks registered, got %v", *got)
}
}
func TestHookRegistry_HasHooks(t *testing.T) {
hr := newHookRegistry[string, string]()
if hr.hasHooks() {
t.Error("expected hasHooks to be false initially")
}
unsub := hr.register(HookPriorityNormal, func(_ string) *string { return nil })
if !hr.hasHooks() {
t.Error("expected hasHooks to be true after registration")
}
unsub()
if hr.hasHooks() {
t.Error("expected hasHooks to be false after unregister")
}
}
func TestHookRegistry_ConcurrentAccess(t *testing.T) {
hr := newHookRegistry[int, int]()
var wg sync.WaitGroup
const n = 100
// Concurrent registrations.
for range n {
wg.Go(func() {
unsub := hr.register(HookPriorityNormal, func(x int) *int {
result := x * 2
return &result
})
// Immediately unregister half the time.
unsub()
})
}
// Concurrent runs while registrations are happening.
for range n {
wg.Go(func() {
hr.run(42)
})
}
wg.Wait()
}
// ---------------------------------------------------------------------------
// hookedTool tests
// ---------------------------------------------------------------------------
// mockAgentTool implements fantasy.AgentTool for testing.
type mockAgentTool struct {
name string
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
popts fantasy.ProviderOptions
}
func (m *mockAgentTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
}
func (m *mockAgentTool) ProviderOptions() fantasy.ProviderOptions { return m.popts }
func (m *mockAgentTool) SetProviderOptions(o fantasy.ProviderOptions) { m.popts = o }
func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
if m.runFn != nil {
return m.runFn(ctx, call)
}
return fantasy.NewTextResponse("default output"), nil
}
func TestHookedTool_Passthrough(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
mock := &mockAgentTool{
name: "test_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("hello world"), nil
},
}
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "hello world" {
t.Errorf("expected 'hello world', got %q", resp.Content)
}
}
func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
toolRan := false
mock := &mockAgentTool{
name: "dangerous_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolRan = true
return fantasy.NewTextResponse("should not run"), nil
},
}
before.register(HookPriorityHigh, func(h BeforeToolCallHook) *BeforeToolCallResult {
if h.ToolName == "dangerous_tool" {
return &BeforeToolCallResult{Block: true, Reason: "too dangerous"}
}
return nil
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
if err == nil {
t.Fatal("expected error from blocked tool")
}
if toolRan {
t.Error("tool should not have run when blocked")
}
if resp.Content != "Error: too dangerous" {
t.Errorf("expected block error message, got %q", resp.Content)
}
}
func TestHookedTool_BeforeToolCallBlockDefaultReason(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
mock := &mockAgentTool{name: "tool"}
before.register(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult {
return &BeforeToolCallResult{Block: true}
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, _ := ht.Run(context.Background(), fantasy.ToolCall{})
if resp.Content != "Error: blocked by hook" {
t.Errorf("expected default block reason, got %q", resp.Content)
}
}
func TestHookedTool_AfterToolResultModify(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
mock := &mockAgentTool{
name: "tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("secret data"), nil
},
}
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
redacted := "[REDACTED]"
return &AfterToolResultResult{Result: &redacted}
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "[REDACTED]" {
t.Errorf("expected '[REDACTED]', got %q", resp.Content)
}
}
func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
mock := &mockAgentTool{
name: "tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("ok"), nil
},
}
isErr := true
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
return &AfterToolResultResult{IsError: &isErr}
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !resp.IsError {
t.Error("expected IsError to be overridden to true")
}
}
func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
mock := &mockAgentTool{
name: "my_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("result"), nil
},
}
var capturedBefore BeforeToolCallHook
var capturedAfter AfterToolResultHook
before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult {
capturedBefore = h
return nil // don't block
})
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
capturedAfter = h
return nil // don't modify
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
_, _ = ht.Run(context.Background(), fantasy.ToolCall{Input: `{"key":"value"}`})
if capturedBefore.ToolName != "my_tool" {
t.Errorf("BeforeToolCall: expected tool name 'my_tool', got %q", capturedBefore.ToolName)
}
if capturedBefore.ToolArgs != `{"key":"value"}` {
t.Errorf("BeforeToolCall: expected args, got %q", capturedBefore.ToolArgs)
}
if capturedAfter.ToolName != "my_tool" {
t.Errorf("AfterToolResult: expected tool name 'my_tool', got %q", capturedAfter.ToolName)
}
if capturedAfter.Result != "result" {
t.Errorf("AfterToolResult: expected result 'result', got %q", capturedAfter.Result)
}
}
func TestHookedTool_InfoDelegates(t *testing.T) {
mock := &mockAgentTool{name: "delegate_test"}
ht := &hookedTool{
inner: mock,
beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](),
afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](),
}
if ht.Info().Name != "delegate_test" {
t.Errorf("expected Info() to delegate to inner tool")
}
}
// ---------------------------------------------------------------------------
// hookToolWrapper tests
// ---------------------------------------------------------------------------
func TestHookToolWrapper(t *testing.T) {
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
wrapper := hookToolWrapper(before, after)
tools := []fantasy.AgentTool{
&mockAgentTool{name: "tool_a"},
&mockAgentTool{name: "tool_b"},
}
wrapped := wrapper(tools)
if len(wrapped) != 2 {
t.Fatalf("expected 2 wrapped tools, got %d", len(wrapped))
}
// Verify tools are wrapped (different pointer than original).
for i, wt := range wrapped {
if _, ok := wt.(*hookedTool); !ok {
t.Errorf("tool %d: expected *hookedTool, got %T", i, wt)
}
if wt.Info().Name != tools[i].Info().Name {
t.Errorf("tool %d: expected name %q, got %q", i, tools[i].Info().Name, wt.Info().Name)
}
}
// Hooks registered after wrapping should still work.
var blocked bool
before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult {
blocked = true
return &BeforeToolCallResult{Block: true, Reason: "late hook"}
})
_, err := wrapped[0].Run(context.Background(), fantasy.ToolCall{})
if err == nil {
t.Error("expected error from late-registered blocking hook")
}
if !blocked {
t.Error("late-registered hook should have been called")
}
}
// ---------------------------------------------------------------------------
// Hook type tests (BeforeTurn, AfterTurn)
// ---------------------------------------------------------------------------
func TestBeforeTurnHook_PromptOverride(t *testing.T) {
hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
override := "modified prompt"
hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
return &BeforeTurnResult{Prompt: &override}
})
result := hr.run(BeforeTurnHook{Prompt: "original"})
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Prompt == nil || *result.Prompt != "modified prompt" {
t.Errorf("expected prompt override, got %v", result.Prompt)
}
}
func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) {
hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
sysPr := "be concise"
ctx := "project context here"
hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
return &BeforeTurnResult{
SystemPrompt: &sysPr,
InjectText: &ctx,
}
})
result := hr.run(BeforeTurnHook{Prompt: "hello"})
if result == nil {
t.Fatal("expected non-nil result")
}
if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" {
t.Errorf("expected system prompt injection")
}
if result.InjectText == nil || *result.InjectText != "project context here" {
t.Errorf("expected context injection")
}
}
func TestAfterTurnHook_ObservationOnly(t *testing.T) {
hr := newHookRegistry[AfterTurnHook, AfterTurnResult]()
var captured AfterTurnHook
hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult {
captured = h
return nil // observation only
})
hr.run(AfterTurnHook{Response: "agent replied"})
if captured.Response != "agent replied" {
t.Errorf("expected captured response, got %q", captured.Response)
}
}
func TestAfterTurnHook_WithError(t *testing.T) {
hr := newHookRegistry[AfterTurnHook, AfterTurnResult]()
var captured AfterTurnHook
hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult {
captured = h
return nil
})
testErr := fmt.Errorf("generation failed")
hr.run(AfterTurnHook{Error: testErr})
if captured.Error != testErr {
t.Errorf("expected captured error, got %v", captured.Error)
}
}
// ---------------------------------------------------------------------------
// Priority constants sanity check
// ---------------------------------------------------------------------------
func TestHookPriorityOrdering(t *testing.T) {
if HookPriorityHigh >= HookPriorityNormal {
t.Error("HookPriorityHigh should be less than HookPriorityNormal")
}
if HookPriorityNormal >= HookPriorityLow {
t.Error("HookPriorityNormal should be less than HookPriorityLow")
}
}
// ---------------------------------------------------------------------------
// Kit method compilation tests (verify API surface exists)
// ---------------------------------------------------------------------------
func TestKit_HookMethodsExist(t *testing.T) {
k := &Kit{
events: newEventBus(),
beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](),
afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](),
beforeTurn: newHookRegistry[BeforeTurnHook, BeforeTurnResult](),
afterTurn: newHookRegistry[AfterTurnHook, AfterTurnResult](),
}
// Verify all hook registration methods return unsubscribe functions.
u1 := k.OnBeforeToolCall(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult {
return nil
})
u2 := k.OnAfterToolResult(HookPriorityNormal, func(_ AfterToolResultHook) *AfterToolResultResult {
return nil
})
u3 := k.OnBeforeTurn(HookPriorityNormal, func(_ BeforeTurnHook) *BeforeTurnResult {
return nil
})
u4 := k.OnAfterTurn(HookPriorityNormal, func(_ AfterTurnHook) {})
// All should be callable.
u1()
u2()
u3()
u4()
}
+90 -24
View File
@@ -27,6 +27,12 @@ type Kit struct {
autoCompact bool
compactionOpts *CompactionOptions
skills []*skills.Skill
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
beforeTurn *hookRegistry[BeforeTurnHook, BeforeTurnResult]
afterTurn *hookRegistry[AfterTurnHook, AfterTurnResult]
}
// Subscribe registers an EventListener that will be called for every lifecycle
@@ -47,6 +53,7 @@ type Options struct {
Streaming bool // Enable streaming (default from config)
Quiet bool // Suppress debug output
Tools []Tool // Custom tool set. If empty, AllTools() is used.
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
// Session configuration
SessionDir string // Base directory for session discovery (default: cwd)
@@ -148,11 +155,21 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, fmt.Errorf("failed to load MCP config: %w", err)
}
// Create agent using shared setup.
// Pre-create hook registries so the tool wrapper can reference them.
// Hooks registered after New() returns are still invoked because the
// wrapper captures the registries by pointer.
beforeToolCall := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
afterToolResult := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
beforeTurn := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
afterTurn := newHookRegistry[AfterTurnHook, AfterTurnResult]()
// Create agent using shared setup with the hook tool wrapper.
agentResult, err := SetupAgent(ctx, AgentSetupOptions{
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
})
if err != nil {
return nil, err
@@ -165,15 +182,26 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, fmt.Errorf("failed to initialize session: %w", err)
}
return &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
modelString: viper.GetString("model"),
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
skills: loadedSkills,
}, nil
k := &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
modelString: viper.GetString("model"),
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
skills: loadedSkills,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
afterTurn: afterTurn,
}
// Bridge extension events to SDK hooks.
if agentResult.ExtRunner != nil {
k.bridgeExtensions(agentResult.ExtRunner)
}
return k, nil
}
// GetSkills returns the skills loaded during initialisation.
@@ -277,15 +305,44 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
}
// runTurn is the shared lifecycle for every prompt mode:
// 1. Persist pre-generation messages to the tree session.
// 2. Build context from the tree (walks leaf-to-root for current branch).
// 3. Emit turn/message start events.
// 4. Run generation.
// 5. Emit turn/message end events.
// 6. Persist post-generation messages (tool calls, results, assistant).
// 1. Run BeforeTurn hooks (can modify prompt, inject messages).
// 2. Persist pre-generation messages to the tree session.
// 3. Build context from the tree (walks leaf-to-root for current branch).
// 4. Emit turn/message start events.
// 5. Run generation.
// 6. Emit turn/message end events.
// 7. Persist post-generation messages (tool calls, results, assistant).
// 8. Run AfterTurn hooks.
//
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fantasy.Message) (string, error) {
// prompt is the raw user text passed to BeforeTurn hooks.
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (string, error) {
// Run BeforeTurn hooks — can modify the prompt, inject system/context messages.
if m.beforeTurn.hasHooks() {
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
// Override prompt text in the last user message.
if hookResult.Prompt != nil {
for i := len(preMessages) - 1; i >= 0; i-- {
if preMessages[i].Role == fantasy.MessageRoleUser {
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt)
break
}
}
}
// Inject messages before the original preMessages.
var injected []fantasy.Message
if hookResult.SystemPrompt != nil {
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
}
if hookResult.InjectText != nil {
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
}
if len(injected) > 0 {
preMessages = append(injected, preMessages...)
}
}
}
// Persist pre-generation messages to tree session.
for _, msg := range preMessages {
_, _ = m.treeSession.AppendFantasyMessage(msg)
@@ -306,6 +363,10 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
result, err := m.generate(ctx, messages)
if err != nil {
m.events.emit(TurnEndEvent{Error: err})
// Run AfterTurn hooks even on error.
if m.afterTurn.hasHooks() {
m.afterTurn.run(AfterTurnHook{Error: err})
}
return "", err
}
@@ -321,6 +382,11 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
}
}
// Run AfterTurn hooks.
if m.afterTurn.hasHooks() {
m.afterTurn.run(AfterTurnHook{Response: responseText})
}
return responseText, nil
}
@@ -333,7 +399,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
// automatically maintained in the tree session. Lifecycle events are emitted
// to all registered subscribers. Returns an error if generation fails.
func (m *Kit) Prompt(ctx context.Context, message string) (string, error) {
return m.runTurn(ctx, message, []fantasy.Message{
return m.runTurn(ctx, message, message, []fantasy.Message{
fantasy.NewUserMessage(message),
})
}
@@ -346,7 +412,7 @@ func (m *Kit) Prompt(ctx context.Context, message string) (string, error) {
// a synthetic user message so the agent acknowledges and follows the directive.
// Both messages are persisted to the session.
func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
return m.runTurn(ctx, "[steer] "+instruction, []fantasy.Message{
return m.runTurn(ctx, "[steer] "+instruction, instruction, []fantasy.Message{
fantasy.NewSystemMessage(instruction),
fantasy.NewUserMessage("Please acknowledge and follow the above instruction."),
})
@@ -367,7 +433,7 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
text = "Continue."
}
return m.runTurn(ctx, "[follow-up]", []fantasy.Message{
return m.runTurn(ctx, "[follow-up]", text, []fantasy.Message{
fantasy.NewUserMessage(text),
})
}
@@ -390,7 +456,7 @@ func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOpti
}
preMessages = append(preMessages, fantasy.NewUserMessage(msg))
return m.runTurn(ctx, msg, preMessages)
return m.runTurn(ctx, msg, msg, preMessages)
}
// PromptWithCallbacks sends a message with callbacks for monitoring tool
+29 -2
View File
@@ -31,6 +31,13 @@ type AgentSetupOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
CoreTools []fantasy.AgentTool
// ExtraTools are additional tools added alongside core, MCP, and extension
// tools. They do not replace the defaults — they extend them.
ExtraTools []fantasy.AgentTool
// ToolWrapper is an optional function that wraps tools after extension
// wrapping. Used by the SDK hook system. Both wrappers compose:
// extension wrapper runs first (inner), then this wrapper (outer).
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
}
// AgentSetupResult bundles the created agent and any debug logger so the caller
@@ -106,6 +113,26 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
}
}
// Compose tool wrappers: extension wrapper (inner) + caller wrapper (outer).
toolWrapper := extCreationOpts.toolWrapper
if opts.ToolWrapper != nil {
if toolWrapper != nil {
inner := toolWrapper
outer := opts.ToolWrapper
toolWrapper = func(t []fantasy.AgentTool) []fantasy.AgentTool {
return outer(inner(t))
}
} else {
toolWrapper = opts.ToolWrapper
}
}
// Merge extra tools: extension tools + caller extra tools.
extraTools := extCreationOpts.extraTools
if len(opts.ExtraTools) > 0 {
extraTools = append(extraTools, opts.ExtraTools...)
}
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
@@ -117,8 +144,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
CoreTools: opts.CoreTools,
ToolWrapper: extCreationOpts.toolWrapper,
ExtraTools: extCreationOpts.extraTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)