mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
add extension hook system with priority ordering and tool interception (Plan 09)
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package kit
|
||||
|
||||
import "github.com/mark3labs/kit/internal/extensions"
|
||||
|
||||
// bridgeExtensions registers extension event handlers as SDK hooks. This makes
|
||||
// the existing extension system a consumer of the SDK hook API, proving the
|
||||
// hook surface is production-ready.
|
||||
//
|
||||
// Phase 1 (this plan): bridge BeforeAgentStart and Input as BeforeTurn hooks.
|
||||
// Tool-level events (ToolCall, ToolResult) are already handled by the extension
|
||||
// tool wrapper (internal/extensions/wrapper.go) which composes underneath the
|
||||
// SDK hook wrapper.
|
||||
//
|
||||
// Phase 2 (future): app.executeStep() migrates to SDK hooks exclusively.
|
||||
// Phase 3 (future): extension runner emits SDK events/hooks natively.
|
||||
func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Extension Input → BeforeTurn hook (high priority, runs first).
|
||||
// An Input handler with Action="transform" replaces the prompt text.
|
||||
if runner.HasHandlers(extensions.Input) {
|
||||
m.OnBeforeTurn(HookPriorityHigh, func(h BeforeTurnHook) *BeforeTurnResult {
|
||||
result, _ := runner.Emit(extensions.InputEvent{Text: h.Prompt})
|
||||
if r, ok := result.(extensions.InputResult); ok {
|
||||
if r.Action == "transform" {
|
||||
return &BeforeTurnResult{Prompt: &r.Text}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Extension BeforeAgentStart → BeforeTurn hook (normal priority).
|
||||
// Can inject a system prompt prefix and/or context text.
|
||||
if runner.HasHandlers(extensions.BeforeAgentStart) {
|
||||
m.OnBeforeTurn(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
|
||||
result, _ := runner.Emit(extensions.BeforeAgentStartEvent{Prompt: h.Prompt})
|
||||
if r, ok := result.(extensions.BeforeAgentStartResult); ok {
|
||||
return &BeforeTurnResult{
|
||||
SystemPrompt: r.SystemPrompt,
|
||||
InjectText: r.InjectText,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Priority
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// HookPriority controls execution order of hooks. Lower values run first.
|
||||
type HookPriority int
|
||||
|
||||
const (
|
||||
// HookPriorityHigh runs before normal hooks.
|
||||
HookPriorityHigh HookPriority = 0
|
||||
// HookPriorityNormal is the default priority.
|
||||
HookPriorityNormal HookPriority = 50
|
||||
// HookPriorityLow runs after normal hooks.
|
||||
HookPriorityLow HookPriority = 100
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hook input/result types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// BeforeToolCallHook is the input for hooks that fire before a tool executes.
|
||||
type BeforeToolCallHook struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
}
|
||||
|
||||
// BeforeToolCallResult controls whether the tool call proceeds.
|
||||
type BeforeToolCallResult struct {
|
||||
Block bool // true prevents the tool from running
|
||||
Reason string // human-readable reason for blocking
|
||||
}
|
||||
|
||||
// AfterToolResultHook is the input for hooks that fire after a tool executes.
|
||||
type AfterToolResultHook struct {
|
||||
ToolName string
|
||||
ToolArgs string
|
||||
Result string
|
||||
IsError bool
|
||||
}
|
||||
|
||||
// AfterToolResultResult can modify the tool's output before it reaches the LLM.
|
||||
type AfterToolResultResult struct {
|
||||
Result *string // non-nil overrides the result text
|
||||
IsError *bool // non-nil overrides the error flag
|
||||
}
|
||||
|
||||
// BeforeTurnHook is the input for hooks that fire before a prompt turn.
|
||||
type BeforeTurnHook struct {
|
||||
Prompt string
|
||||
}
|
||||
|
||||
// BeforeTurnResult can modify the prompt, inject system messages, or add context.
|
||||
type BeforeTurnResult struct {
|
||||
Prompt *string // override prompt text in the user message
|
||||
SystemPrompt *string // prepend a system message
|
||||
InjectText *string // prepend a user context message
|
||||
}
|
||||
|
||||
// AfterTurnHook is the input for hooks that fire after a prompt turn completes.
|
||||
type AfterTurnHook struct {
|
||||
Response string
|
||||
Error error
|
||||
}
|
||||
|
||||
// AfterTurnResult is a placeholder — after-turn hooks are observation-only.
|
||||
type AfterTurnResult struct{}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Generic hook registry with priority ordering
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type hookEntry[In any, Out any] struct {
|
||||
id int
|
||||
priority HookPriority
|
||||
handler func(In) *Out
|
||||
}
|
||||
|
||||
type hookRegistry[In any, Out any] struct {
|
||||
mu sync.RWMutex
|
||||
hooks []hookEntry[In, Out]
|
||||
next int
|
||||
}
|
||||
|
||||
func newHookRegistry[In any, Out any]() *hookRegistry[In, Out] {
|
||||
return &hookRegistry[In, Out]{}
|
||||
}
|
||||
|
||||
// register adds a hook with the given priority and returns an unregister
|
||||
// function. Within the same priority, hooks run in registration order.
|
||||
func (hr *hookRegistry[In, Out]) register(p HookPriority, h func(In) *Out) func() {
|
||||
hr.mu.Lock()
|
||||
id := hr.next
|
||||
hr.next++
|
||||
hr.hooks = append(hr.hooks, hookEntry[In, Out]{id: id, priority: p, handler: h})
|
||||
// Stable sort preserves insertion order within the same priority.
|
||||
sort.SliceStable(hr.hooks, func(i, j int) bool {
|
||||
return hr.hooks[i].priority < hr.hooks[j].priority
|
||||
})
|
||||
hr.mu.Unlock()
|
||||
|
||||
return func() {
|
||||
hr.mu.Lock()
|
||||
defer hr.mu.Unlock()
|
||||
for i, entry := range hr.hooks {
|
||||
if entry.id == id {
|
||||
hr.hooks = append(hr.hooks[:i], hr.hooks[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// run executes all hooks in priority order. The first non-nil result wins.
|
||||
func (hr *hookRegistry[In, Out]) run(input In) *Out {
|
||||
hr.mu.RLock()
|
||||
snapshot := make([]hookEntry[In, Out], len(hr.hooks))
|
||||
copy(snapshot, hr.hooks)
|
||||
hr.mu.RUnlock()
|
||||
|
||||
for _, entry := range snapshot {
|
||||
if result := entry.handler(input); result != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasHooks returns true if any hooks are registered.
|
||||
func (hr *hookRegistry[In, Out]) hasHooks() bool {
|
||||
hr.mu.RLock()
|
||||
defer hr.mu.RUnlock()
|
||||
return len(hr.hooks) > 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hook registration methods on Kit
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// OnBeforeToolCall registers a hook that fires before each tool execution.
|
||||
// Return a non-nil BeforeToolCallResult with Block=true to prevent the tool
|
||||
// from running. Hooks execute in priority order; the first non-nil result wins.
|
||||
// Returns an unregister function.
|
||||
func (m *Kit) OnBeforeToolCall(p HookPriority, h func(BeforeToolCallHook) *BeforeToolCallResult) func() {
|
||||
return m.beforeToolCall.register(p, h)
|
||||
}
|
||||
|
||||
// OnAfterToolResult registers a hook that fires after each tool execution.
|
||||
// Return a non-nil AfterToolResultResult to modify the tool's output before
|
||||
// it reaches the LLM. Hooks execute in priority order; the first non-nil
|
||||
// result wins. Returns an unregister function.
|
||||
func (m *Kit) OnAfterToolResult(p HookPriority, h func(AfterToolResultHook) *AfterToolResultResult) func() {
|
||||
return m.afterToolResult.register(p, h)
|
||||
}
|
||||
|
||||
// OnBeforeTurn registers a hook that fires before each prompt turn. Return
|
||||
// a non-nil BeforeTurnResult to modify the prompt, inject a system message,
|
||||
// or prepend context. Hooks execute in priority order; the first non-nil
|
||||
// result wins. Returns an unregister function.
|
||||
func (m *Kit) OnBeforeTurn(p HookPriority, h func(BeforeTurnHook) *BeforeTurnResult) func() {
|
||||
return m.beforeTurn.register(p, h)
|
||||
}
|
||||
|
||||
// OnAfterTurn registers a hook that fires after each prompt turn completes.
|
||||
// This is observation-only — the handler cannot modify the response. Hooks
|
||||
// execute in priority order. Returns an unregister function.
|
||||
func (m *Kit) OnAfterTurn(p HookPriority, h func(AfterTurnHook)) func() {
|
||||
return m.afterTurn.register(p, func(input AfterTurnHook) *AfterTurnResult {
|
||||
h(input)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool wrapping via hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// hookedTool wraps a fantasy.AgentTool to run BeforeToolCall and
|
||||
// AfterToolResult hooks around each execution. The registries are referenced
|
||||
// by pointer so hooks added after agent creation are still invoked.
|
||||
type hookedTool struct {
|
||||
inner fantasy.AgentTool
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
|
||||
}
|
||||
|
||||
func (h *hookedTool) Info() fantasy.ToolInfo { return h.inner.Info() }
|
||||
func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions { return h.inner.ProviderOptions() }
|
||||
func (h *hookedTool) SetProviderOptions(o fantasy.ProviderOptions) { h.inner.SetProviderOptions(o) }
|
||||
|
||||
func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolName := h.inner.Info().Name
|
||||
|
||||
// 1. BeforeToolCall — can block execution.
|
||||
if h.beforeToolCall.hasHooks() {
|
||||
if result := h.beforeToolCall.run(BeforeToolCallHook{
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
}); result != nil && result.Block {
|
||||
reason := result.Reason
|
||||
if reason == "" {
|
||||
reason = "blocked by hook"
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
fmt.Errorf("tool blocked by hook: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Execute actual tool.
|
||||
resp, err := h.inner.Run(ctx, call)
|
||||
|
||||
// 3. AfterToolResult — can modify output.
|
||||
if h.afterToolResult.hasHooks() {
|
||||
if result := h.afterToolResult.run(AfterToolResultHook{
|
||||
ToolName: toolName,
|
||||
ToolArgs: call.Input,
|
||||
Result: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
}); result != nil {
|
||||
if result.Result != nil {
|
||||
resp.Content = *result.Result
|
||||
}
|
||||
if result.IsError != nil {
|
||||
resp.IsError = *result.IsError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// hookToolWrapper creates a tool wrapper function that applies hook-based
|
||||
// tool interception. The wrapper references the hook registries directly,
|
||||
// so hooks registered after agent creation are still called at execution time.
|
||||
func hookToolWrapper(
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult],
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult],
|
||||
) func([]fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
wrapped := make([]fantasy.AgentTool, len(tools))
|
||||
for i, tool := range tools {
|
||||
wrapped[i] = &hookedTool{
|
||||
inner: tool,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
}
|
||||
}
|
||||
return wrapped
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,548 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hook registry tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHookRegistry_RegisterAndRun(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
hr.register(HookPriorityNormal, func(input string) *string {
|
||||
result := "handled: " + input
|
||||
return &result
|
||||
})
|
||||
|
||||
got := hr.run("hello")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if *got != "handled: hello" {
|
||||
t.Errorf("expected 'handled: hello', got %q", *got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_FirstNonNilWins(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
// First hook returns nil.
|
||||
hr.register(HookPriorityNormal, func(_ string) *string {
|
||||
return nil
|
||||
})
|
||||
// Second hook returns a result.
|
||||
hr.register(HookPriorityNormal, func(input string) *string {
|
||||
result := "second: " + input
|
||||
return &result
|
||||
})
|
||||
// Third hook would also return, but should never be reached.
|
||||
hr.register(HookPriorityNormal, func(input string) *string {
|
||||
result := "third: " + input
|
||||
return &result
|
||||
})
|
||||
|
||||
got := hr.run("test")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if *got != "second: test" {
|
||||
t.Errorf("expected 'second: test', got %q", *got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_PriorityOrdering(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
// Register in reverse priority order.
|
||||
hr.register(HookPriorityLow, func(_ string) *string {
|
||||
result := "low"
|
||||
return &result
|
||||
})
|
||||
hr.register(HookPriorityHigh, func(_ string) *string {
|
||||
result := "high"
|
||||
return &result
|
||||
})
|
||||
hr.register(HookPriorityNormal, func(_ string) *string {
|
||||
result := "normal"
|
||||
return &result
|
||||
})
|
||||
|
||||
got := hr.run("x")
|
||||
if got == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if *got != "high" {
|
||||
t.Errorf("expected 'high' (priority 0 runs first), got %q", *got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_SamePriorityPreservesOrder(t *testing.T) {
|
||||
hr := newHookRegistry[int, string]()
|
||||
|
||||
hr.register(HookPriorityNormal, func(n int) *string {
|
||||
result := "first"
|
||||
return &result
|
||||
})
|
||||
hr.register(HookPriorityNormal, func(n int) *string {
|
||||
result := "second"
|
||||
return &result
|
||||
})
|
||||
|
||||
got := hr.run(0)
|
||||
if got == nil || *got != "first" {
|
||||
t.Errorf("expected 'first' (insertion order), got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_Unregister(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
unregister := hr.register(HookPriorityNormal, func(input string) *string {
|
||||
result := "should be gone"
|
||||
return &result
|
||||
})
|
||||
|
||||
if !hr.hasHooks() {
|
||||
t.Fatal("expected hasHooks to be true after registration")
|
||||
}
|
||||
|
||||
unregister()
|
||||
|
||||
if hr.hasHooks() {
|
||||
t.Fatal("expected hasHooks to be false after unregister")
|
||||
}
|
||||
|
||||
got := hr.run("test")
|
||||
if got != nil {
|
||||
t.Errorf("expected nil after unregister, got %v", *got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_NoHooksReturnsNil(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
got := hr.run("test")
|
||||
if got != nil {
|
||||
t.Errorf("expected nil when no hooks registered, got %v", *got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_HasHooks(t *testing.T) {
|
||||
hr := newHookRegistry[string, string]()
|
||||
|
||||
if hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be false initially")
|
||||
}
|
||||
|
||||
unsub := hr.register(HookPriorityNormal, func(_ string) *string { return nil })
|
||||
if !hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be true after registration")
|
||||
}
|
||||
|
||||
unsub()
|
||||
if hr.hasHooks() {
|
||||
t.Error("expected hasHooks to be false after unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookRegistry_ConcurrentAccess(t *testing.T) {
|
||||
hr := newHookRegistry[int, int]()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const n = 100
|
||||
|
||||
// Concurrent registrations.
|
||||
for range n {
|
||||
wg.Go(func() {
|
||||
unsub := hr.register(HookPriorityNormal, func(x int) *int {
|
||||
result := x * 2
|
||||
return &result
|
||||
})
|
||||
// Immediately unregister half the time.
|
||||
unsub()
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrent runs while registrations are happening.
|
||||
for range n {
|
||||
wg.Go(func() {
|
||||
hr.run(42)
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hookedTool tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// mockAgentTool implements fantasy.AgentTool for testing.
|
||||
type mockAgentTool struct {
|
||||
name string
|
||||
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
|
||||
popts fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
func (m *mockAgentTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
|
||||
}
|
||||
func (m *mockAgentTool) ProviderOptions() fantasy.ProviderOptions { return m.popts }
|
||||
func (m *mockAgentTool) SetProviderOptions(o fantasy.ProviderOptions) { m.popts = o }
|
||||
func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if m.runFn != nil {
|
||||
return m.runFn(ctx, call)
|
||||
}
|
||||
return fantasy.NewTextResponse("default output"), nil
|
||||
}
|
||||
|
||||
func TestHookedTool_Passthrough(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "test_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("hello world"), nil
|
||||
},
|
||||
}
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "hello world" {
|
||||
t.Errorf("expected 'hello world', got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
toolRan := false
|
||||
mock := &mockAgentTool{
|
||||
name: "dangerous_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolRan = true
|
||||
return fantasy.NewTextResponse("should not run"), nil
|
||||
},
|
||||
}
|
||||
|
||||
before.register(HookPriorityHigh, func(h BeforeToolCallHook) *BeforeToolCallResult {
|
||||
if h.ToolName == "dangerous_tool" {
|
||||
return &BeforeToolCallResult{Block: true, Reason: "too dangerous"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from blocked tool")
|
||||
}
|
||||
if toolRan {
|
||||
t.Error("tool should not have run when blocked")
|
||||
}
|
||||
if resp.Content != "Error: too dangerous" {
|
||||
t.Errorf("expected block error message, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_BeforeToolCallBlockDefaultReason(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
mock := &mockAgentTool{name: "tool"}
|
||||
before.register(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult {
|
||||
return &BeforeToolCallResult{Block: true}
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, _ := ht.Run(context.Background(), fantasy.ToolCall{})
|
||||
if resp.Content != "Error: blocked by hook" {
|
||||
t.Errorf("expected default block reason, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_AfterToolResultModify(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("secret data"), nil
|
||||
},
|
||||
}
|
||||
|
||||
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
|
||||
redacted := "[REDACTED]"
|
||||
return &AfterToolResultResult{Result: &redacted}
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "[REDACTED]" {
|
||||
t.Errorf("expected '[REDACTED]', got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("ok"), nil
|
||||
},
|
||||
}
|
||||
|
||||
isErr := true
|
||||
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
|
||||
return &AfterToolResultResult{IsError: &isErr}
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected IsError to be overridden to true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "my_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("result"), nil
|
||||
},
|
||||
}
|
||||
|
||||
var capturedBefore BeforeToolCallHook
|
||||
var capturedAfter AfterToolResultHook
|
||||
|
||||
before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult {
|
||||
capturedBefore = h
|
||||
return nil // don't block
|
||||
})
|
||||
after.register(HookPriorityNormal, func(h AfterToolResultHook) *AfterToolResultResult {
|
||||
capturedAfter = h
|
||||
return nil // don't modify
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
_, _ = ht.Run(context.Background(), fantasy.ToolCall{Input: `{"key":"value"}`})
|
||||
|
||||
if capturedBefore.ToolName != "my_tool" {
|
||||
t.Errorf("BeforeToolCall: expected tool name 'my_tool', got %q", capturedBefore.ToolName)
|
||||
}
|
||||
if capturedBefore.ToolArgs != `{"key":"value"}` {
|
||||
t.Errorf("BeforeToolCall: expected args, got %q", capturedBefore.ToolArgs)
|
||||
}
|
||||
if capturedAfter.ToolName != "my_tool" {
|
||||
t.Errorf("AfterToolResult: expected tool name 'my_tool', got %q", capturedAfter.ToolName)
|
||||
}
|
||||
if capturedAfter.Result != "result" {
|
||||
t.Errorf("AfterToolResult: expected result 'result', got %q", capturedAfter.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookedTool_InfoDelegates(t *testing.T) {
|
||||
mock := &mockAgentTool{name: "delegate_test"}
|
||||
ht := &hookedTool{
|
||||
inner: mock,
|
||||
beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](),
|
||||
afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](),
|
||||
}
|
||||
|
||||
if ht.Info().Name != "delegate_test" {
|
||||
t.Errorf("expected Info() to delegate to inner tool")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// hookToolWrapper tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHookToolWrapper(t *testing.T) {
|
||||
before := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
after := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
|
||||
wrapper := hookToolWrapper(before, after)
|
||||
|
||||
tools := []fantasy.AgentTool{
|
||||
&mockAgentTool{name: "tool_a"},
|
||||
&mockAgentTool{name: "tool_b"},
|
||||
}
|
||||
|
||||
wrapped := wrapper(tools)
|
||||
if len(wrapped) != 2 {
|
||||
t.Fatalf("expected 2 wrapped tools, got %d", len(wrapped))
|
||||
}
|
||||
|
||||
// Verify tools are wrapped (different pointer than original).
|
||||
for i, wt := range wrapped {
|
||||
if _, ok := wt.(*hookedTool); !ok {
|
||||
t.Errorf("tool %d: expected *hookedTool, got %T", i, wt)
|
||||
}
|
||||
if wt.Info().Name != tools[i].Info().Name {
|
||||
t.Errorf("tool %d: expected name %q, got %q", i, tools[i].Info().Name, wt.Info().Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Hooks registered after wrapping should still work.
|
||||
var blocked bool
|
||||
before.register(HookPriorityNormal, func(h BeforeToolCallHook) *BeforeToolCallResult {
|
||||
blocked = true
|
||||
return &BeforeToolCallResult{Block: true, Reason: "late hook"}
|
||||
})
|
||||
|
||||
_, err := wrapped[0].Run(context.Background(), fantasy.ToolCall{})
|
||||
if err == nil {
|
||||
t.Error("expected error from late-registered blocking hook")
|
||||
}
|
||||
if !blocked {
|
||||
t.Error("late-registered hook should have been called")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hook type tests (BeforeTurn, AfterTurn)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBeforeTurnHook_PromptOverride(t *testing.T) {
|
||||
hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
|
||||
|
||||
override := "modified prompt"
|
||||
hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
|
||||
return &BeforeTurnResult{Prompt: &override}
|
||||
})
|
||||
|
||||
result := hr.run(BeforeTurnHook{Prompt: "original"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.Prompt == nil || *result.Prompt != "modified prompt" {
|
||||
t.Errorf("expected prompt override, got %v", result.Prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBeforeTurnHook_InjectSystemAndContext(t *testing.T) {
|
||||
hr := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
|
||||
|
||||
sysPr := "be concise"
|
||||
ctx := "project context here"
|
||||
hr.register(HookPriorityNormal, func(h BeforeTurnHook) *BeforeTurnResult {
|
||||
return &BeforeTurnResult{
|
||||
SystemPrompt: &sysPr,
|
||||
InjectText: &ctx,
|
||||
}
|
||||
})
|
||||
|
||||
result := hr.run(BeforeTurnHook{Prompt: "hello"})
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.SystemPrompt == nil || *result.SystemPrompt != "be concise" {
|
||||
t.Errorf("expected system prompt injection")
|
||||
}
|
||||
if result.InjectText == nil || *result.InjectText != "project context here" {
|
||||
t.Errorf("expected context injection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterTurnHook_ObservationOnly(t *testing.T) {
|
||||
hr := newHookRegistry[AfterTurnHook, AfterTurnResult]()
|
||||
|
||||
var captured AfterTurnHook
|
||||
hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult {
|
||||
captured = h
|
||||
return nil // observation only
|
||||
})
|
||||
|
||||
hr.run(AfterTurnHook{Response: "agent replied"})
|
||||
if captured.Response != "agent replied" {
|
||||
t.Errorf("expected captured response, got %q", captured.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAfterTurnHook_WithError(t *testing.T) {
|
||||
hr := newHookRegistry[AfterTurnHook, AfterTurnResult]()
|
||||
|
||||
var captured AfterTurnHook
|
||||
hr.register(HookPriorityNormal, func(h AfterTurnHook) *AfterTurnResult {
|
||||
captured = h
|
||||
return nil
|
||||
})
|
||||
|
||||
testErr := fmt.Errorf("generation failed")
|
||||
hr.run(AfterTurnHook{Error: testErr})
|
||||
if captured.Error != testErr {
|
||||
t.Errorf("expected captured error, got %v", captured.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Priority constants sanity check
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestHookPriorityOrdering(t *testing.T) {
|
||||
if HookPriorityHigh >= HookPriorityNormal {
|
||||
t.Error("HookPriorityHigh should be less than HookPriorityNormal")
|
||||
}
|
||||
if HookPriorityNormal >= HookPriorityLow {
|
||||
t.Error("HookPriorityNormal should be less than HookPriorityLow")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Kit method compilation tests (verify API surface exists)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestKit_HookMethodsExist(t *testing.T) {
|
||||
k := &Kit{
|
||||
events: newEventBus(),
|
||||
beforeToolCall: newHookRegistry[BeforeToolCallHook, BeforeToolCallResult](),
|
||||
afterToolResult: newHookRegistry[AfterToolResultHook, AfterToolResultResult](),
|
||||
beforeTurn: newHookRegistry[BeforeTurnHook, BeforeTurnResult](),
|
||||
afterTurn: newHookRegistry[AfterTurnHook, AfterTurnResult](),
|
||||
}
|
||||
|
||||
// Verify all hook registration methods return unsubscribe functions.
|
||||
u1 := k.OnBeforeToolCall(HookPriorityNormal, func(_ BeforeToolCallHook) *BeforeToolCallResult {
|
||||
return nil
|
||||
})
|
||||
u2 := k.OnAfterToolResult(HookPriorityNormal, func(_ AfterToolResultHook) *AfterToolResultResult {
|
||||
return nil
|
||||
})
|
||||
u3 := k.OnBeforeTurn(HookPriorityNormal, func(_ BeforeTurnHook) *BeforeTurnResult {
|
||||
return nil
|
||||
})
|
||||
u4 := k.OnAfterTurn(HookPriorityNormal, func(_ AfterTurnHook) {})
|
||||
|
||||
// All should be callable.
|
||||
u1()
|
||||
u2()
|
||||
u3()
|
||||
u4()
|
||||
}
|
||||
+90
-24
@@ -27,6 +27,12 @@ type Kit struct {
|
||||
autoCompact bool
|
||||
compactionOpts *CompactionOptions
|
||||
skills []*skills.Skill
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
|
||||
beforeTurn *hookRegistry[BeforeTurnHook, BeforeTurnResult]
|
||||
afterTurn *hookRegistry[AfterTurnHook, AfterTurnResult]
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
@@ -47,6 +53,7 @@ type Options struct {
|
||||
Streaming bool // Enable streaming (default from config)
|
||||
Quiet bool // Suppress debug output
|
||||
Tools []Tool // Custom tool set. If empty, AllTools() is used.
|
||||
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
|
||||
|
||||
// Session configuration
|
||||
SessionDir string // Base directory for session discovery (default: cwd)
|
||||
@@ -148,11 +155,21 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
return nil, fmt.Errorf("failed to load MCP config: %w", err)
|
||||
}
|
||||
|
||||
// Create agent using shared setup.
|
||||
// Pre-create hook registries so the tool wrapper can reference them.
|
||||
// Hooks registered after New() returns are still invoked because the
|
||||
// wrapper captures the registries by pointer.
|
||||
beforeToolCall := newHookRegistry[BeforeToolCallHook, BeforeToolCallResult]()
|
||||
afterToolResult := newHookRegistry[AfterToolResultHook, AfterToolResultResult]()
|
||||
beforeTurn := newHookRegistry[BeforeTurnHook, BeforeTurnResult]()
|
||||
afterTurn := newHookRegistry[AfterTurnHook, AfterTurnResult]()
|
||||
|
||||
// Create agent using shared setup with the hook tool wrapper.
|
||||
agentResult, err := SetupAgent(ctx, AgentSetupOptions{
|
||||
MCPConfig: mcpConfig,
|
||||
Quiet: opts.Quiet,
|
||||
CoreTools: opts.Tools,
|
||||
MCPConfig: mcpConfig,
|
||||
Quiet: opts.Quiet,
|
||||
CoreTools: opts.Tools,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -165,15 +182,26 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
return nil, fmt.Errorf("failed to initialize session: %w", err)
|
||||
}
|
||||
|
||||
return &Kit{
|
||||
agent: agentResult.Agent,
|
||||
treeSession: treeSession,
|
||||
modelString: viper.GetString("model"),
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
skills: loadedSkills,
|
||||
}, nil
|
||||
k := &Kit{
|
||||
agent: agentResult.Agent,
|
||||
treeSession: treeSession,
|
||||
modelString: viper.GetString("model"),
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
skills: loadedSkills,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
afterTurn: afterTurn,
|
||||
}
|
||||
|
||||
// Bridge extension events to SDK hooks.
|
||||
if agentResult.ExtRunner != nil {
|
||||
k.bridgeExtensions(agentResult.ExtRunner)
|
||||
}
|
||||
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// GetSkills returns the skills loaded during initialisation.
|
||||
@@ -277,15 +305,44 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
}
|
||||
|
||||
// runTurn is the shared lifecycle for every prompt mode:
|
||||
// 1. Persist pre-generation messages to the tree session.
|
||||
// 2. Build context from the tree (walks leaf-to-root for current branch).
|
||||
// 3. Emit turn/message start events.
|
||||
// 4. Run generation.
|
||||
// 5. Emit turn/message end events.
|
||||
// 6. Persist post-generation messages (tool calls, results, assistant).
|
||||
// 1. Run BeforeTurn hooks (can modify prompt, inject messages).
|
||||
// 2. Persist pre-generation messages to the tree session.
|
||||
// 3. Build context from the tree (walks leaf-to-root for current branch).
|
||||
// 4. Emit turn/message start events.
|
||||
// 5. Run generation.
|
||||
// 6. Emit turn/message end events.
|
||||
// 7. Persist post-generation messages (tool calls, results, assistant).
|
||||
// 8. Run AfterTurn hooks.
|
||||
//
|
||||
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
|
||||
func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fantasy.Message) (string, error) {
|
||||
// prompt is the raw user text passed to BeforeTurn hooks.
|
||||
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (string, error) {
|
||||
// Run BeforeTurn hooks — can modify the prompt, inject system/context messages.
|
||||
if m.beforeTurn.hasHooks() {
|
||||
if hookResult := m.beforeTurn.run(BeforeTurnHook{Prompt: prompt}); hookResult != nil {
|
||||
// Override prompt text in the last user message.
|
||||
if hookResult.Prompt != nil {
|
||||
for i := len(preMessages) - 1; i >= 0; i-- {
|
||||
if preMessages[i].Role == fantasy.MessageRoleUser {
|
||||
preMessages[i] = fantasy.NewUserMessage(*hookResult.Prompt)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// Inject messages before the original preMessages.
|
||||
var injected []fantasy.Message
|
||||
if hookResult.SystemPrompt != nil {
|
||||
injected = append(injected, fantasy.NewSystemMessage(*hookResult.SystemPrompt))
|
||||
}
|
||||
if hookResult.InjectText != nil {
|
||||
injected = append(injected, fantasy.NewUserMessage(*hookResult.InjectText))
|
||||
}
|
||||
if len(injected) > 0 {
|
||||
preMessages = append(injected, preMessages...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persist pre-generation messages to tree session.
|
||||
for _, msg := range preMessages {
|
||||
_, _ = m.treeSession.AppendFantasyMessage(msg)
|
||||
@@ -306,6 +363,10 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
|
||||
result, err := m.generate(ctx, messages)
|
||||
if err != nil {
|
||||
m.events.emit(TurnEndEvent{Error: err})
|
||||
// Run AfterTurn hooks even on error.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Error: err})
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -321,6 +382,11 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
|
||||
}
|
||||
}
|
||||
|
||||
// Run AfterTurn hooks.
|
||||
if m.afterTurn.hasHooks() {
|
||||
m.afterTurn.run(AfterTurnHook{Response: responseText})
|
||||
}
|
||||
|
||||
return responseText, nil
|
||||
}
|
||||
|
||||
@@ -333,7 +399,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, preMessages []fan
|
||||
// automatically maintained in the tree session. Lifecycle events are emitted
|
||||
// to all registered subscribers. Returns an error if generation fails.
|
||||
func (m *Kit) Prompt(ctx context.Context, message string) (string, error) {
|
||||
return m.runTurn(ctx, message, []fantasy.Message{
|
||||
return m.runTurn(ctx, message, message, []fantasy.Message{
|
||||
fantasy.NewUserMessage(message),
|
||||
})
|
||||
}
|
||||
@@ -346,7 +412,7 @@ func (m *Kit) Prompt(ctx context.Context, message string) (string, error) {
|
||||
// a synthetic user message so the agent acknowledges and follows the directive.
|
||||
// Both messages are persisted to the session.
|
||||
func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
|
||||
return m.runTurn(ctx, "[steer] "+instruction, []fantasy.Message{
|
||||
return m.runTurn(ctx, "[steer] "+instruction, instruction, []fantasy.Message{
|
||||
fantasy.NewSystemMessage(instruction),
|
||||
fantasy.NewUserMessage("Please acknowledge and follow the above instruction."),
|
||||
})
|
||||
@@ -367,7 +433,7 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
text = "Continue."
|
||||
}
|
||||
|
||||
return m.runTurn(ctx, "[follow-up]", []fantasy.Message{
|
||||
return m.runTurn(ctx, "[follow-up]", text, []fantasy.Message{
|
||||
fantasy.NewUserMessage(text),
|
||||
})
|
||||
}
|
||||
@@ -390,7 +456,7 @@ func (m *Kit) PromptWithOptions(ctx context.Context, msg string, opts PromptOpti
|
||||
}
|
||||
preMessages = append(preMessages, fantasy.NewUserMessage(msg))
|
||||
|
||||
return m.runTurn(ctx, msg, preMessages)
|
||||
return m.runTurn(ctx, msg, msg, preMessages)
|
||||
}
|
||||
|
||||
// PromptWithCallbacks sends a message with callbacks for monitoring tool
|
||||
|
||||
+29
-2
@@ -31,6 +31,13 @@ type AgentSetupOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
// ExtraTools are additional tools added alongside core, MCP, and extension
|
||||
// tools. They do not replace the defaults — they extend them.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
// ToolWrapper is an optional function that wraps tools after extension
|
||||
// wrapping. Used by the SDK hook system. Both wrappers compose:
|
||||
// extension wrapper runs first (inner), then this wrapper (outer).
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -106,6 +113,26 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
}
|
||||
}
|
||||
|
||||
// Compose tool wrappers: extension wrapper (inner) + caller wrapper (outer).
|
||||
toolWrapper := extCreationOpts.toolWrapper
|
||||
if opts.ToolWrapper != nil {
|
||||
if toolWrapper != nil {
|
||||
inner := toolWrapper
|
||||
outer := opts.ToolWrapper
|
||||
toolWrapper = func(t []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return outer(inner(t))
|
||||
}
|
||||
} else {
|
||||
toolWrapper = opts.ToolWrapper
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra tools: extension tools + caller extra tools.
|
||||
extraTools := extCreationOpts.extraTools
|
||||
if len(opts.ExtraTools) > 0 {
|
||||
extraTools = append(extraTools, opts.ExtraTools...)
|
||||
}
|
||||
|
||||
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
@@ -117,8 +144,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: extCreationOpts.toolWrapper,
|
||||
ExtraTools: extCreationOpts.extraTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user