mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-13 19:20:06 +00:00
add Yaegi-based in-process extension system with event handlers, tool/command registration, and styled output
Implement a Pi-style extension system where plain .go files are loaded at runtime via Yaegi. Extensions register typed event handlers against 13 lifecycle events (tool_call, tool_result, input, before_agent_start, etc.) using concrete-type-only API methods to avoid Yaegi interface panics. Key capabilities: - Tool interception: block calls, modify results (wrapper pattern) - Input handling: transform or fully handle user input (skip agent) - System prompt injection via BeforeAgentStartResult - Custom tool and slash command registration - Styled output: ctx.Print, PrintInfo, PrintError, PrintBlock - Legacy hooks.yml compatibility via adapter - Auto-discovery from ~/.config/kit/extensions/ and .kit/extensions/ - CLI: kit extensions list|validate|init, --no-extensions, -e flags - 58 unit tests covering runner, loader (Yaegi), wrapper, events
This commit is contained in:
@@ -0,0 +1,175 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var extensionsCmd = &cobra.Command{
|
||||
Use: "extensions",
|
||||
Short: "Manage KIT extensions",
|
||||
Long: "Commands for listing, validating, and scaffolding KIT extensions",
|
||||
}
|
||||
|
||||
var extensionsListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List discovered extensions and their handlers",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading extensions: %w", err)
|
||||
}
|
||||
|
||||
if len(loaded) == 0 {
|
||||
fmt.Println("No extensions found.")
|
||||
fmt.Println()
|
||||
fmt.Println("Extension search paths:")
|
||||
fmt.Println(" ~/.config/kit/extensions/*.go (global)")
|
||||
fmt.Println(" .kit/extensions/*.go (project)")
|
||||
fmt.Println()
|
||||
fmt.Println("Run 'kit extensions init' to create an example extension.")
|
||||
return nil
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "EXTENSION\tEVENT\tHANDLERS\tTOOLS\tCOMMANDS")
|
||||
|
||||
for _, ext := range loaded {
|
||||
totalHandlers := 0
|
||||
for _, handlers := range ext.Handlers {
|
||||
totalHandlers += len(handlers)
|
||||
}
|
||||
first := true
|
||||
for event, handlers := range ext.Handlers {
|
||||
if first {
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%d\t%d\t%d\n",
|
||||
ext.Path, event, len(handlers), len(ext.Tools), len(ext.Commands))
|
||||
first = false
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(w, "\t%s\t%d\t\t\n",
|
||||
event, len(handlers))
|
||||
}
|
||||
}
|
||||
if first {
|
||||
// Extension loaded but registered no handlers
|
||||
_, _ = fmt.Fprintf(w, "%s\t(none)\t0\t%d\t%d\n",
|
||||
ext.Path, len(ext.Tools), len(ext.Commands))
|
||||
}
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
},
|
||||
}
|
||||
|
||||
var extensionsValidateCmd = &cobra.Command{
|
||||
Use: "validate",
|
||||
Short: "Validate all extension files can be loaded",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension"))
|
||||
if err != nil {
|
||||
return fmt.Errorf("validation failed: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Loaded %d extension(s) successfully\n", len(loaded))
|
||||
for _, ext := range loaded {
|
||||
total := 0
|
||||
for _, h := range ext.Handlers {
|
||||
total += len(h)
|
||||
}
|
||||
fmt.Printf(" %s (%d handlers, %d tools, %d commands)\n",
|
||||
ext.Path, total, len(ext.Tools), len(ext.Commands))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
var extensionsInitCmd = &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Generate an example extension file",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
dir := ".kit/extensions"
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("creating extensions directory: %w", err)
|
||||
}
|
||||
|
||||
example := `package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Init is called when the extension is loaded. Register handlers here.
|
||||
func Init(api ext.API) {
|
||||
// Log every tool call to a file.
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
f, err := os.OpenFile("/tmp/kit-tool-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
fmt.Fprintf(f, "[%s] tool=%s\n", time.Now().Format(time.RFC3339), tc.ToolName)
|
||||
}
|
||||
return nil // don't block
|
||||
})
|
||||
|
||||
// Block dangerous bash commands.
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName == "bash" && strings.Contains(tc.Input, "rm -rf /") {
|
||||
return &ext.ToolCallResult{Block: true, Reason: "Blocked: dangerous command"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Handle custom ! commands. Use ctx.Print/PrintInfo/PrintError/PrintBlock
|
||||
// instead of fmt.Println — BubbleTea captures stdout in interactive mode.
|
||||
//
|
||||
// ctx.Print("text") — plain text
|
||||
// ctx.PrintInfo("text") — styled system message block
|
||||
// ctx.PrintError("text") — styled error block
|
||||
// ctx.PrintBlock(opts) — custom block with border color and subtitle
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
switch ie.Text {
|
||||
case "!time":
|
||||
ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339))
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
|
||||
case "!status":
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "my-extension",
|
||||
})
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
|
||||
path := dir + "/example.go"
|
||||
if err := os.WriteFile(path, []byte(example), 0644); err != nil {
|
||||
return fmt.Errorf("writing example: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Created %s with example extension\n", path)
|
||||
fmt.Println()
|
||||
fmt.Println("The extension will be auto-loaded on the next kit run.")
|
||||
fmt.Println("Use --no-extensions to disable all extensions.")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(extensionsCmd)
|
||||
extensionsCmd.AddCommand(extensionsListCmd)
|
||||
extensionsCmd.AddCommand(extensionsValidateCmd)
|
||||
extensionsCmd.AddCommand(extensionsInitCmd)
|
||||
}
|
||||
+27
-2
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -57,7 +58,9 @@ var (
|
||||
numGPU int32
|
||||
mainGPU int32
|
||||
|
||||
// Hooks control
|
||||
// Extensions control
|
||||
noExtensionsFlag bool
|
||||
extensionPaths []string
|
||||
|
||||
// TLS configuration
|
||||
tlsSkipVerify bool
|
||||
@@ -301,6 +304,10 @@ func init() {
|
||||
BoolVarP(&resumeFlag, "resume", "r", false, "interactive session picker")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noSessionFlag, "no-session", false, "ephemeral mode — no session persistence")
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noExtensionsFlag, "no-extensions", false, "disable all extensions and hooks")
|
||||
rootCmd.PersistentFlags().
|
||||
StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)")
|
||||
|
||||
flags := rootCmd.PersistentFlags()
|
||||
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
|
||||
@@ -338,6 +345,8 @@ func init() {
|
||||
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
|
||||
_ = viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
|
||||
_ = viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
|
||||
_ = viper.BindPFlag("no-extensions", rootCmd.PersistentFlags().Lookup("no-extensions"))
|
||||
_ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension"))
|
||||
|
||||
// Defaults are already set in flag definitions, no need to duplicate in viper
|
||||
|
||||
@@ -542,7 +551,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// Create the app.App instance now that session messages are loaded.
|
||||
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames)
|
||||
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner)
|
||||
appOpts.SessionManager = sessionManager
|
||||
appOpts.TreeSession = treeSession
|
||||
|
||||
@@ -564,6 +573,22 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Emit SessionStart event to extensions.
|
||||
if agentResult.ExtRunner != nil {
|
||||
agentResult.ExtRunner.SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: promptFlag == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
})
|
||||
if agentResult.ExtRunner.HasHandlers(extensions.SessionStart) {
|
||||
_, _ = agentResult.ExtRunner.Emit(extensions.SessionStartEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// Check if running in non-interactive mode
|
||||
if promptFlag != "" {
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, promptFlag, quietFlag, noExitFlag, modelName, parsedProvider, mcpAgent.GetLoadingMessage(), serverNames, toolNames, usageTracker)
|
||||
|
||||
+1
-1
@@ -534,7 +534,7 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
|
||||
DisplayDebugConfig(cli, mcpAgent, mcpConfig, parsedProvider)
|
||||
|
||||
// Build app options.
|
||||
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames)
|
||||
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner)
|
||||
if cli != nil {
|
||||
if tracker := cli.GetUsageTracker(); tracker != nil {
|
||||
appOpts.UsageTracker = tracker
|
||||
|
||||
+72
-1
@@ -5,9 +5,13 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/hooks"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
@@ -70,6 +74,9 @@ type AgentSetupOptions struct {
|
||||
type AgentSetupResult struct {
|
||||
Agent *agent.Agent
|
||||
BufferedLogger *tools.BufferedDebugLogger
|
||||
// ExtRunner is the extension runner (nil when --no-extensions or no
|
||||
// extensions were discovered).
|
||||
ExtRunner *extensions.Runner
|
||||
}
|
||||
|
||||
// SetupAgent creates an agent from the current viper state + the provided
|
||||
@@ -93,6 +100,20 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
}
|
||||
}
|
||||
|
||||
// Load extensions unless --no-extensions is set. Extensions must be loaded
|
||||
// BEFORE agent creation so their tool wrapper and custom tools are included
|
||||
// in the Fantasy agent's tool list.
|
||||
var extRunner *extensions.Runner
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !viper.GetBool("no-extensions") {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = setupExtensions()
|
||||
if extErr != nil {
|
||||
// Extension loading failures are non-fatal.
|
||||
fmt.Printf("Warning: Failed to load extensions: %v\n", extErr)
|
||||
}
|
||||
}
|
||||
|
||||
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
@@ -103,6 +124,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
Quiet: quietFlag,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
ToolWrapper: extCreationOpts.toolWrapper,
|
||||
ExtraTools: extCreationOpts.extraTools,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
@@ -110,10 +133,57 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
|
||||
return &AgentSetupResult{
|
||||
Agent: a,
|
||||
ExtRunner: extRunner,
|
||||
BufferedLogger: bufferedLogger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extensionCreationOpts holds the tool wrapper and extra tools that need to be
|
||||
// passed into agent creation, extracted from loaded extensions.
|
||||
type extensionCreationOpts struct {
|
||||
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
extraTools []fantasy.AgentTool
|
||||
}
|
||||
|
||||
// setupExtensions discovers and loads Yaegi extensions plus legacy hooks.yml,
|
||||
// builds the runner, and returns the tool wrapper/extra tools needed by the
|
||||
// agent factory.
|
||||
func setupExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
loaded, err := extensions.LoadExtensions(extraPaths)
|
||||
if err != nil {
|
||||
return nil, extensionCreationOpts{}, err
|
||||
}
|
||||
|
||||
// Also load legacy hooks.yml as a compat extension.
|
||||
hooksCfg, _ := hooks.LoadHooksConfig()
|
||||
if hooksCfg != nil && len(hooksCfg.Hooks) > 0 {
|
||||
compat := extensions.HooksAsExtension(hooksCfg)
|
||||
if compat != nil {
|
||||
loaded = append([]extensions.LoadedExtension{*compat}, loaded...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(loaded) == 0 {
|
||||
return nil, extensionCreationOpts{}, nil
|
||||
}
|
||||
|
||||
runner := extensions.NewRunner(loaded)
|
||||
|
||||
// Build the tool wrapper that intercepts tool calls through the runner.
|
||||
wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
}
|
||||
|
||||
// Collect custom tools registered by extensions.
|
||||
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools())
|
||||
|
||||
return runner, extensionCreationOpts{
|
||||
toolWrapper: wrapper,
|
||||
extraTools: extTools,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CollectAgentMetadata extracts model display info and tool/server name lists
|
||||
// from the agent. This is used by both root.go and script.go to populate
|
||||
// app.Options and UI setup.
|
||||
@@ -138,7 +208,7 @@ func CollectAgentMetadata(mcpAgent *agent.Agent, mcpConfig *config.Config) (prov
|
||||
|
||||
// BuildAppOptions constructs the app.Options struct from the current state.
|
||||
// Both root.go and script.go converge here after agent creation.
|
||||
func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string) app.Options {
|
||||
func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string, extRunner *extensions.Runner) app.Options {
|
||||
return app.Options{
|
||||
Agent: mcpAgent,
|
||||
MCPConfig: mcpConfig,
|
||||
@@ -149,6 +219,7 @@ func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName
|
||||
Quiet: quietFlag,
|
||||
Debug: viper.GetBool("debug"),
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
Extensions: extRunner,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Init registers handlers that log all tool calls and session lifecycle
|
||||
// events to /tmp/kit-tool-log.txt.
|
||||
func Init(api ext.API) {
|
||||
logFile := "/tmp/kit-tool-log.txt"
|
||||
|
||||
// Log every tool call before execution.
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
fmt.Fprintf(f, "[%s] CALL tool=%s model=%s\n",
|
||||
time.Now().Format(time.RFC3339), tc.ToolName, ctx.Model)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Log tool results after execution.
|
||||
api.OnToolResult(func(tr ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
status := "ok"
|
||||
if tr.IsError {
|
||||
status = "error"
|
||||
}
|
||||
fmt.Fprintf(f, "[%s] RESULT tool=%s status=%s bytes=%d\n",
|
||||
time.Now().Format(time.RFC3339), tr.ToolName, status, len(tr.Content))
|
||||
}
|
||||
return nil // don't modify the result
|
||||
})
|
||||
|
||||
// Log session start/shutdown.
|
||||
api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) {
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
fmt.Fprintf(f, "[%s] SESSION_START cwd=%s\n",
|
||||
time.Now().Format(time.RFC3339), ctx.CWD)
|
||||
}
|
||||
})
|
||||
|
||||
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
fmt.Fprintf(f, "[%s] SESSION_SHUTDOWN\n",
|
||||
time.Now().Format(time.RFC3339))
|
||||
}
|
||||
})
|
||||
|
||||
// "!time" — prints the current time as a styled info block.
|
||||
// "!status" — prints a custom block with green border and subtitle.
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
switch ie.Text {
|
||||
case "!time":
|
||||
ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339))
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
|
||||
case "!status":
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "tool-logger extension",
|
||||
})
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -48,6 +48,7 @@ require (
|
||||
github.com/charmbracelet/colorprofile v0.4.2 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/log v0.4.2 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260223200540-d6a276319c45 // indirect
|
||||
@@ -61,6 +62,7 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
@@ -95,6 +97,7 @@ require (
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/traefik/yaegi v0.16.1 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
@@ -108,6 +111,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.40.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
golang.org/x/oauth2 v0.35.0 // indirect
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
|
||||
@@ -86,6 +86,8 @@ github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
|
||||
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
|
||||
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 h1:Af/L28Xh+pddhouT/6lJ7IAIYfu5tWJOB0iqt+mXsYM=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73/go.mod h1:E6/0abq9uG2SnM8IbLB9Y5SW09uIgfaFETk8aRzgXUQ=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
@@ -132,6 +134,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -251,6 +255,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
|
||||
@@ -24,6 +24,15 @@ type AgentConfig struct {
|
||||
MaxSteps int
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the Fantasy agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ExtraTools are additional tools to include alongside core and MCP tools.
|
||||
// Used by extensions to register custom tools.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -109,6 +118,16 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Append any extra tools provided by extensions.
|
||||
if len(agentConfig.ExtraTools) > 0 {
|
||||
allTools = append(allTools, agentConfig.ExtraTools...)
|
||||
}
|
||||
|
||||
// Apply tool wrapper (extension interception layer) if configured.
|
||||
if agentConfig.ToolWrapper != nil {
|
||||
allTools = agentConfig.ToolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Build fantasy agent options
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
@@ -34,6 +36,10 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -47,6 +53,8 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
@@ -254,6 +255,11 @@ func (a *App) Close() {
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
|
||||
// --- Extension: SessionShutdown ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.SessionShutdown) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.SessionShutdownEvent{})
|
||||
}
|
||||
|
||||
// Cancel any in-flight step and the root context.
|
||||
cancel()
|
||||
a.rootCancel()
|
||||
@@ -362,6 +368,23 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// --- Extension: Input event (can transform or handle the prompt) ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.Input) {
|
||||
result, _ := a.opts.Extensions.Emit(extensions.InputEvent{
|
||||
Text: prompt,
|
||||
Source: a.inputSource(),
|
||||
})
|
||||
if r, ok := result.(extensions.InputResult); ok {
|
||||
switch r.Action {
|
||||
case "transform":
|
||||
prompt = r.Text
|
||||
case "handled":
|
||||
// Extension handled the input; skip the agent entirely.
|
||||
return &agent.GenerateWithLoopResult{}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add user message to the store immediately so history is consistent
|
||||
// even if the step is later cancelled.
|
||||
userMsg := fantasy.NewUserMessage(prompt)
|
||||
@@ -385,9 +408,39 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Track message count before agent runs so we can diff new messages.
|
||||
sentCount := len(msgs)
|
||||
|
||||
// --- Extension: BeforeAgentStart ---
|
||||
// Extensions can inject a system message or prepend context text into the
|
||||
// conversation before the agent runs.
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.BeforeAgentStart) {
|
||||
result, _ := a.opts.Extensions.Emit(extensions.BeforeAgentStartEvent{Prompt: prompt})
|
||||
if r, ok := result.(extensions.BeforeAgentStartResult); ok {
|
||||
if r.SystemPrompt != nil && *r.SystemPrompt != "" {
|
||||
// Prepend a system message so the LLM sees extension-provided
|
||||
// instructions. This supplements (not replaces) the agent's
|
||||
// configured system prompt.
|
||||
msgs = append([]fantasy.Message{fantasy.NewSystemMessage(*r.SystemPrompt)}, msgs...)
|
||||
}
|
||||
if r.InjectText != nil && *r.InjectText != "" {
|
||||
// Prepend a user message with the injected context so it
|
||||
// appears early in the conversation window.
|
||||
msgs = append([]fantasy.Message{fantasy.NewUserMessage(*r.InjectText)}, msgs...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Extension: AgentStart ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentStart) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.AgentStartEvent{Prompt: prompt})
|
||||
}
|
||||
|
||||
// Signal spinner start.
|
||||
sendFn(SpinnerEvent{Show: true})
|
||||
|
||||
// --- Extension: MessageStart ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageStart) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.MessageStartEvent{})
|
||||
}
|
||||
|
||||
result, err := a.opts.Agent.GenerateWithLoopAndStreaming(ctx, msgs,
|
||||
// onToolCall
|
||||
func(toolName, toolArgs string) {
|
||||
@@ -416,14 +469,42 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
},
|
||||
// onStreamingResponse — spinner keeps running alongside streaming text
|
||||
func(chunk string) {
|
||||
// Extension: MessageUpdate (observe streaming chunks)
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageUpdate) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.MessageUpdateEvent{Chunk: chunk})
|
||||
}
|
||||
sendFn(StreamChunkEvent{Content: chunk})
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
// --- Extension: AgentEnd with error ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{
|
||||
Response: "",
|
||||
StopReason: "error",
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// --- Extension: MessageEnd ---
|
||||
responseText := ""
|
||||
if result.FinalResponse != nil {
|
||||
responseText = result.FinalResponse.Content.Text()
|
||||
}
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageEnd) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.MessageEndEvent{Content: responseText})
|
||||
}
|
||||
|
||||
// --- Extension: AgentEnd with success ---
|
||||
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) {
|
||||
_, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{
|
||||
Response: responseText,
|
||||
StopReason: "completed",
|
||||
})
|
||||
}
|
||||
|
||||
// Replace the store with the full updated conversation returned by the agent
|
||||
// (includes tool call/result messages added during the step).
|
||||
a.store.Replace(result.ConversationMessages)
|
||||
@@ -439,6 +520,17 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// inputSource returns a string identifying how the current session receives
|
||||
// input — used by the Input extension event.
|
||||
func (a *App) inputSource() string {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.program != nil {
|
||||
return "interactive"
|
||||
}
|
||||
return "cli"
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Internal: event helpers
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -454,6 +546,45 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// PrintFromExtension outputs text from an extension to the user. The level
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
}
|
||||
|
||||
// PrintBlockFromExtension outputs a custom styled block from an extension.
|
||||
func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(ExtensionPrintEvent{
|
||||
Text: opts.Text,
|
||||
Level: "block",
|
||||
BorderColor: opts.BorderColor,
|
||||
Subtitle: opts.Subtitle,
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// updateUsage records token usage from a completed agent step into the configured
|
||||
// UsageTracker (if any). It uses the actual token counts from the agent result's
|
||||
// TotalUsage field when available; otherwise it falls back to text-based estimation.
|
||||
|
||||
@@ -96,3 +96,23 @@ type MessageCreatedEvent struct {
|
||||
// Message is the fantasy message that was added to the store.
|
||||
Message fantasy.Message
|
||||
}
|
||||
|
||||
// ExtensionPrintEvent is sent when an extension calls ctx.Print, ctx.PrintInfo,
|
||||
// ctx.PrintError, or ctx.PrintBlock. The TUI renders it via the appropriate
|
||||
// renderer and tea.Println (scrollback); the CLI handler uses
|
||||
// DisplayInfo/DisplayError or plain fmt.Println. This exists because BubbleTea
|
||||
// captures stdout, so plain fmt.Println inside extensions would be swallowed.
|
||||
type ExtensionPrintEvent struct {
|
||||
// Text is the content the extension wants to display to the user.
|
||||
Text string
|
||||
// Level controls the rendering style:
|
||||
// "" — plain text (no styling)
|
||||
// "info" — system message block (bordered, themed)
|
||||
// "error" — error block (red border, bold text)
|
||||
// "block" — custom block with BorderColor and Subtitle
|
||||
Level string
|
||||
// BorderColor is a hex color (e.g. "#a6e3a1") for Level="block".
|
||||
BorderColor string
|
||||
// Subtitle is optional muted text below the content for Level="block".
|
||||
Subtitle string
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
@@ -94,4 +95,10 @@ type Options struct {
|
||||
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
|
||||
// agent. Satisfied by *ui.UsageTracker; wired in cmd/root.go.
|
||||
UsageTracker UsageUpdater
|
||||
|
||||
// Extensions is the optional extension runner. When non-nil, lifecycle
|
||||
// events (Input, BeforeAgentStart, AgentEnd, etc.) are emitted through
|
||||
// it. Tool-level events (ToolCall, ToolResult) are handled by wrapper.go
|
||||
// at the tool layer, not here.
|
||||
Extensions *extensions.Runner
|
||||
}
|
||||
|
||||
@@ -0,0 +1,325 @@
|
||||
package extensions
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal types (used by runner, NOT exposed to Yaegi)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Event is the interface satisfied by all event types internally.
|
||||
type Event interface {
|
||||
Type() EventType
|
||||
}
|
||||
|
||||
// Result is the interface satisfied by all result types internally.
|
||||
type Result interface {
|
||||
isResult()
|
||||
}
|
||||
|
||||
// HandlerFunc is the internal handler signature used by the runner.
|
||||
type HandlerFunc func(event Event, ctx Context) Result
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context (exposed to Yaegi — concrete struct, no interfaces)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Context provides runtime information to handlers about the current session.
|
||||
type Context struct {
|
||||
SessionID string
|
||||
CWD string
|
||||
Model string
|
||||
Interactive bool
|
||||
|
||||
// Print outputs plain text to the user. In interactive mode this
|
||||
// routes through BubbleTea's scrollback (tea.Println); in
|
||||
// non-interactive mode it writes to stdout. Extensions must use
|
||||
// this instead of fmt.Println, which is swallowed by BubbleTea.
|
||||
Print func(string)
|
||||
|
||||
// PrintInfo outputs text as a styled system message block (bordered,
|
||||
// themed). Use this for informational notices the user should see.
|
||||
PrintInfo func(string)
|
||||
|
||||
// PrintError outputs text as a styled error block (red border, bold).
|
||||
// Use this for error messages or warnings.
|
||||
PrintError func(string)
|
||||
|
||||
// PrintBlock outputs text as a custom styled block with caller-chosen
|
||||
// border color and optional subtitle. Example:
|
||||
//
|
||||
// ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
// Text: "Deployment complete!",
|
||||
// BorderColor: "#a6e3a1",
|
||||
// Subtitle: "my-extension",
|
||||
// })
|
||||
PrintBlock func(PrintBlockOpts)
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
Text string
|
||||
// BorderColor is a hex color string (e.g. "#a6e3a1") for the left border.
|
||||
// Defaults to the theme's system color if empty.
|
||||
BorderColor string
|
||||
// Subtitle is optional text shown below the content in muted style
|
||||
// (e.g. extension name, timestamp). Empty means no subtitle line.
|
||||
Subtitle string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// API — the object passed to each extension's Init function.
|
||||
//
|
||||
// Instead of a generic On(EventType, HandlerFunc) that uses interfaces,
|
||||
// we expose event-specific methods with concrete function signatures.
|
||||
// This avoids Yaegi's genInterfaceWrapper crash entirely — no interfaces
|
||||
// cross the Yaegi boundary.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// API is passed to each extension's Init function. Extensions use it to
|
||||
// register typed event handlers, custom tools, and slash commands.
|
||||
type API struct {
|
||||
// Event-specific registration functions (wired by the loader).
|
||||
onToolCall func(func(ToolCallEvent, Context) *ToolCallResult)
|
||||
onToolExecStart func(func(ToolExecutionStartEvent, Context))
|
||||
onToolExecEnd func(func(ToolExecutionEndEvent, Context))
|
||||
onToolResult func(func(ToolResultEvent, Context) *ToolResultResult)
|
||||
onInput func(func(InputEvent, Context) *InputResult)
|
||||
onBeforeAgentStart func(func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult)
|
||||
onAgentStart func(func(AgentStartEvent, Context))
|
||||
onAgentEnd func(func(AgentEndEvent, Context))
|
||||
onMessageStart func(func(MessageStartEvent, Context))
|
||||
onMessageUpdate func(func(MessageUpdateEvent, Context))
|
||||
onMessageEnd func(func(MessageEndEvent, Context))
|
||||
onSessionStart func(func(SessionStartEvent, Context))
|
||||
onSessionShutdown func(func(SessionShutdownEvent, Context))
|
||||
registerToolFn func(ToolDef)
|
||||
registerCmdFn func(CommandDef)
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
// Return a non-nil ToolCallResult with Block=true to prevent execution.
|
||||
func (a *API) OnToolCall(handler func(ToolCallEvent, Context) *ToolCallResult) {
|
||||
a.onToolCall(handler)
|
||||
}
|
||||
|
||||
// OnToolExecutionStart registers a handler for tool execution start.
|
||||
func (a *API) OnToolExecutionStart(handler func(ToolExecutionStartEvent, Context)) {
|
||||
a.onToolExecStart(handler)
|
||||
}
|
||||
|
||||
// OnToolExecutionEnd registers a handler for tool execution end.
|
||||
func (a *API) OnToolExecutionEnd(handler func(ToolExecutionEndEvent, Context)) {
|
||||
a.onToolExecEnd(handler)
|
||||
}
|
||||
|
||||
// OnToolResult registers a handler that fires after tool execution.
|
||||
// Return a non-nil ToolResultResult to modify the output.
|
||||
func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnInput registers a handler that fires when user input is received.
|
||||
// Return a non-nil InputResult to transform or handle the input.
|
||||
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
|
||||
a.onInput(handler)
|
||||
}
|
||||
|
||||
// OnBeforeAgentStart registers a handler that fires before the agent loop.
|
||||
func (a *API) OnBeforeAgentStart(handler func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
|
||||
a.onBeforeAgentStart(handler)
|
||||
}
|
||||
|
||||
// OnAgentStart registers a handler for when the agent loop begins.
|
||||
func (a *API) OnAgentStart(handler func(AgentStartEvent, Context)) {
|
||||
a.onAgentStart(handler)
|
||||
}
|
||||
|
||||
// OnAgentEnd registers a handler for when the agent finishes responding.
|
||||
func (a *API) OnAgentEnd(handler func(AgentEndEvent, Context)) {
|
||||
a.onAgentEnd(handler)
|
||||
}
|
||||
|
||||
// OnMessageStart registers a handler for when an assistant message begins.
|
||||
func (a *API) OnMessageStart(handler func(MessageStartEvent, Context)) {
|
||||
a.onMessageStart(handler)
|
||||
}
|
||||
|
||||
// OnMessageUpdate registers a handler for streaming text chunks.
|
||||
func (a *API) OnMessageUpdate(handler func(MessageUpdateEvent, Context)) {
|
||||
a.onMessageUpdate(handler)
|
||||
}
|
||||
|
||||
// OnMessageEnd registers a handler for when the assistant message is complete.
|
||||
func (a *API) OnMessageEnd(handler func(MessageEndEvent, Context)) {
|
||||
a.onMessageEnd(handler)
|
||||
}
|
||||
|
||||
// OnSessionStart registers a handler for when a session is loaded or created.
|
||||
func (a *API) OnSessionStart(handler func(SessionStartEvent, Context)) {
|
||||
a.onSessionStart(handler)
|
||||
}
|
||||
|
||||
// OnSessionShutdown registers a handler for when the application is closing.
|
||||
func (a *API) OnSessionShutdown(handler func(SessionShutdownEvent, Context)) {
|
||||
a.onSessionShutdown(handler)
|
||||
}
|
||||
|
||||
// RegisterTool adds a custom tool that the LLM can invoke.
|
||||
func (a *API) RegisterTool(tool ToolDef) {
|
||||
a.registerToolFn(tool)
|
||||
}
|
||||
|
||||
// RegisterCommand adds a slash command available in interactive mode.
|
||||
func (a *API) RegisterCommand(cmd CommandDef) {
|
||||
a.registerCmdFn(cmd)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ToolDef / CommandDef
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ToolDef describes a custom tool registered by an extension.
|
||||
type ToolDef struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters string // JSON Schema string
|
||||
Execute func(input string) (string, error)
|
||||
}
|
||||
|
||||
// CommandDef describes a slash command registered by an extension.
|
||||
type CommandDef struct {
|
||||
Name string
|
||||
Description string
|
||||
Execute func(args string) (string, error)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Typed events (all concrete structs — safe for Yaegi)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ToolCallEvent fires before a tool executes.
|
||||
type ToolCallEvent struct {
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
Input string // JSON-encoded tool parameters
|
||||
}
|
||||
|
||||
func (e ToolCallEvent) Type() EventType { return ToolCall }
|
||||
|
||||
// ToolCallResult controls whether the tool call proceeds.
|
||||
type ToolCallResult struct {
|
||||
Block bool
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (ToolCallResult) isResult() {}
|
||||
|
||||
// ToolExecutionStartEvent fires when a tool begins executing.
|
||||
type ToolExecutionStartEvent struct {
|
||||
ToolName string
|
||||
}
|
||||
|
||||
func (e ToolExecutionStartEvent) Type() EventType { return ToolExecutionStart }
|
||||
|
||||
// ToolExecutionEndEvent fires when a tool finishes executing.
|
||||
type ToolExecutionEndEvent struct {
|
||||
ToolName string
|
||||
}
|
||||
|
||||
func (e ToolExecutionEndEvent) Type() EventType { return ToolExecutionEnd }
|
||||
|
||||
// ToolResultEvent fires after tool execution with the output.
|
||||
type ToolResultEvent struct {
|
||||
ToolName string
|
||||
Input string
|
||||
Content string
|
||||
IsError bool
|
||||
}
|
||||
|
||||
func (e ToolResultEvent) Type() EventType { return ToolResult }
|
||||
|
||||
// ToolResultResult can modify the tool's output before it reaches the LLM.
|
||||
type ToolResultResult struct {
|
||||
Content *string // nil = unchanged
|
||||
IsError *bool // nil = unchanged
|
||||
}
|
||||
|
||||
func (ToolResultResult) isResult() {}
|
||||
|
||||
// InputEvent fires when user input is received.
|
||||
type InputEvent struct {
|
||||
Text string
|
||||
Source string // "interactive", "cli", "script", "queue"
|
||||
}
|
||||
|
||||
func (e InputEvent) Type() EventType { return Input }
|
||||
|
||||
// InputResult controls what happens with user input.
|
||||
//
|
||||
// Action: "continue" (default), "transform", "handled"
|
||||
type InputResult struct {
|
||||
Action string
|
||||
Text string // replacement text when Action="transform"
|
||||
}
|
||||
|
||||
func (InputResult) isResult() {}
|
||||
|
||||
// BeforeAgentStartEvent fires before the agent loop begins.
|
||||
type BeforeAgentStartEvent struct {
|
||||
Prompt string
|
||||
}
|
||||
|
||||
func (e BeforeAgentStartEvent) Type() EventType { return BeforeAgentStart }
|
||||
|
||||
// BeforeAgentStartResult can inject context before the agent runs.
|
||||
type BeforeAgentStartResult struct {
|
||||
InjectText *string
|
||||
SystemPrompt *string
|
||||
}
|
||||
|
||||
func (BeforeAgentStartResult) isResult() {}
|
||||
|
||||
// AgentStartEvent fires when the agent loop begins.
|
||||
type AgentStartEvent struct {
|
||||
Prompt string
|
||||
}
|
||||
|
||||
func (e AgentStartEvent) Type() EventType { return AgentStart }
|
||||
|
||||
// AgentEndEvent fires when the agent finishes responding.
|
||||
type AgentEndEvent struct {
|
||||
Response string
|
||||
StopReason string // "completed", "cancelled", "error"
|
||||
}
|
||||
|
||||
func (e AgentEndEvent) Type() EventType { return AgentEnd }
|
||||
|
||||
// MessageStartEvent fires when a new assistant message begins.
|
||||
type MessageStartEvent struct{}
|
||||
|
||||
func (e MessageStartEvent) Type() EventType { return MessageStart }
|
||||
|
||||
// MessageUpdateEvent fires for each streaming text chunk.
|
||||
type MessageUpdateEvent struct {
|
||||
Chunk string
|
||||
}
|
||||
|
||||
func (e MessageUpdateEvent) Type() EventType { return MessageUpdate }
|
||||
|
||||
// MessageEndEvent fires when the assistant message is complete.
|
||||
type MessageEndEvent struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
func (e MessageEndEvent) Type() EventType { return MessageEnd }
|
||||
|
||||
// SessionStartEvent fires when a session is loaded or created.
|
||||
type SessionStartEvent struct {
|
||||
SessionID string
|
||||
}
|
||||
|
||||
func (e SessionStartEvent) Type() EventType { return SessionStart }
|
||||
|
||||
// SessionShutdownEvent fires when the application is closing.
|
||||
type SessionShutdownEvent struct{}
|
||||
|
||||
func (e SessionShutdownEvent) Type() EventType { return SessionShutdown }
|
||||
@@ -0,0 +1,111 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/mark3labs/kit/internal/hooks"
|
||||
)
|
||||
|
||||
// HooksAsExtension wraps an existing hooks.HookConfig as a LoadedExtension
|
||||
// so that legacy .kit/hooks.yml configurations continue to work alongside
|
||||
// the new Yaegi extension system. The adapter translates the old event names
|
||||
// and shell-command execution model into extension HandlerFunc handlers.
|
||||
func HooksAsExtension(config *hooks.HookConfig) *LoadedExtension {
|
||||
if config == nil || len(config.Hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ext := &LoadedExtension{
|
||||
Path: "hooks.yml (compat)",
|
||||
Handlers: make(map[EventType][]HandlerFunc),
|
||||
}
|
||||
|
||||
executor := hooks.NewExecutor(config, "", "")
|
||||
|
||||
// Map PreToolUse → ToolCall
|
||||
if matchers, ok := config.Hooks[hooks.PreToolUse]; ok && len(matchers) > 0 {
|
||||
ext.Handlers[ToolCall] = []HandlerFunc{
|
||||
func(event Event, _ Context) Result {
|
||||
tc, ok := event.(ToolCallEvent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
input := &hooks.PreToolUseInput{
|
||||
ToolName: tc.ToolName,
|
||||
ToolInput: json.RawMessage(tc.Input),
|
||||
}
|
||||
output, err := executor.ExecuteHooks(context.Background(), hooks.PreToolUse, input)
|
||||
if err != nil || output == nil {
|
||||
return nil
|
||||
}
|
||||
if output.Decision == "block" {
|
||||
return ToolCallResult{Block: true, Reason: output.Reason}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Map PostToolUse → ToolResult
|
||||
if matchers, ok := config.Hooks[hooks.PostToolUse]; ok && len(matchers) > 0 {
|
||||
ext.Handlers[ToolResult] = []HandlerFunc{
|
||||
func(event Event, _ Context) Result {
|
||||
tr, ok := event.(ToolResultEvent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
input := &hooks.PostToolUseInput{
|
||||
ToolName: tr.ToolName,
|
||||
ToolInput: json.RawMessage(tr.Input),
|
||||
ToolResponse: json.RawMessage(tr.Content),
|
||||
}
|
||||
_, _ = executor.ExecuteHooks(context.Background(), hooks.PostToolUse, input)
|
||||
return nil // legacy hooks don't modify results
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Map UserPromptSubmit → Input
|
||||
if matchers, ok := config.Hooks[hooks.UserPromptSubmit]; ok && len(matchers) > 0 {
|
||||
ext.Handlers[Input] = []HandlerFunc{
|
||||
func(event Event, _ Context) Result {
|
||||
ie, ok := event.(InputEvent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
input := &hooks.UserPromptSubmitInput{
|
||||
Prompt: ie.Text,
|
||||
}
|
||||
output, err := executor.ExecuteHooks(context.Background(), hooks.UserPromptSubmit, input)
|
||||
if err != nil || output == nil {
|
||||
return nil
|
||||
}
|
||||
if output.Decision == "block" {
|
||||
return InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Map Stop → AgentEnd
|
||||
if matchers, ok := config.Hooks[hooks.Stop]; ok && len(matchers) > 0 {
|
||||
ext.Handlers[AgentEnd] = []HandlerFunc{
|
||||
func(event Event, _ Context) Result {
|
||||
ae, ok := event.(AgentEndEvent)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
input := &hooks.StopInput{
|
||||
Response: ae.Response,
|
||||
StopReason: ae.StopReason,
|
||||
}
|
||||
_, _ = executor.ExecuteHooks(context.Background(), hooks.Stop, input)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return ext
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
// Package extensions implements a Pi-style in-process extension system for KIT.
|
||||
// Extensions are plain Go files loaded at runtime via Yaegi (a Go interpreter).
|
||||
// They register event handlers using an API object, enabling tool interception,
|
||||
// input transformation, and lifecycle observation — all without recompilation.
|
||||
package extensions
|
||||
|
||||
// EventType identifies a point in KIT's lifecycle where extensions can hook in.
|
||||
type EventType string
|
||||
|
||||
const (
|
||||
// ToolCall fires before a tool executes. Handlers can block execution.
|
||||
ToolCall EventType = "tool_call"
|
||||
|
||||
// ToolExecutionStart fires when a tool begins executing.
|
||||
ToolExecutionStart EventType = "tool_execution_start"
|
||||
|
||||
// ToolExecutionEnd fires when a tool finishes executing.
|
||||
ToolExecutionEnd EventType = "tool_execution_end"
|
||||
|
||||
// ToolResult fires after a tool executes. Handlers can modify the result.
|
||||
ToolResult EventType = "tool_result"
|
||||
|
||||
// Input fires when user input is received. Handlers can transform or handle it.
|
||||
Input EventType = "input"
|
||||
|
||||
// BeforeAgentStart fires before the agent loop begins for a prompt.
|
||||
BeforeAgentStart EventType = "before_agent_start"
|
||||
|
||||
// AgentStart fires when the agent loop begins processing.
|
||||
AgentStart EventType = "agent_start"
|
||||
|
||||
// AgentEnd fires when the agent finishes responding.
|
||||
AgentEnd EventType = "agent_end"
|
||||
|
||||
// MessageStart fires when a new assistant message begins.
|
||||
MessageStart EventType = "message_start"
|
||||
|
||||
// MessageUpdate fires for each streaming text chunk.
|
||||
MessageUpdate EventType = "message_update"
|
||||
|
||||
// MessageEnd fires when the assistant message is complete.
|
||||
MessageEnd EventType = "message_end"
|
||||
|
||||
// SessionStart fires when a session is loaded or created.
|
||||
SessionStart EventType = "session_start"
|
||||
|
||||
// SessionShutdown fires when the application is closing.
|
||||
SessionShutdown EventType = "session_shutdown"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
func AllEventTypes() []EventType {
|
||||
return []EventType{
|
||||
ToolCall, ToolExecutionStart, ToolExecutionEnd, ToolResult,
|
||||
Input, BeforeAgentStart, AgentStart, AgentEnd,
|
||||
MessageStart, MessageUpdate, MessageEnd,
|
||||
SessionStart, SessionShutdown,
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the event type is a recognised lifecycle event.
|
||||
func (e EventType) IsValid() bool {
|
||||
for _, valid := range AllEventTypes() {
|
||||
if e == valid {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package extensions
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 13 {
|
||||
t.Fatalf("expected 13 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllEventTypes_NoDuplicates(t *testing.T) {
|
||||
seen := make(map[EventType]bool)
|
||||
for _, et := range AllEventTypes() {
|
||||
if seen[et] {
|
||||
t.Fatalf("duplicate event type: %s", et)
|
||||
}
|
||||
seen[et] = true
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventType_IsValid(t *testing.T) {
|
||||
for _, et := range AllEventTypes() {
|
||||
if !et.IsValid() {
|
||||
t.Errorf("expected %s to be valid", et)
|
||||
}
|
||||
}
|
||||
|
||||
invalid := EventType("nonexistent_event")
|
||||
if invalid.IsValid() {
|
||||
t.Error("expected 'nonexistent_event' to be invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventType_TypeMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
event Event
|
||||
want EventType
|
||||
}{
|
||||
{ToolCallEvent{ToolName: "test"}, ToolCall},
|
||||
{ToolExecutionStartEvent{ToolName: "test"}, ToolExecutionStart},
|
||||
{ToolExecutionEndEvent{ToolName: "test"}, ToolExecutionEnd},
|
||||
{ToolResultEvent{ToolName: "test"}, ToolResult},
|
||||
{InputEvent{Text: "hello"}, Input},
|
||||
{BeforeAgentStartEvent{Prompt: "test"}, BeforeAgentStart},
|
||||
{AgentStartEvent{Prompt: "test"}, AgentStart},
|
||||
{AgentEndEvent{Response: "done"}, AgentEnd},
|
||||
{MessageStartEvent{}, MessageStart},
|
||||
{MessageUpdateEvent{Chunk: "hi"}, MessageUpdate},
|
||||
{MessageEndEvent{Content: "done"}, MessageEnd},
|
||||
{SessionStartEvent{SessionID: "abc"}, SessionStart},
|
||||
{SessionShutdownEvent{}, SessionShutdown},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.event.Type(); got != tt.want {
|
||||
t.Errorf("event %T.Type() = %s, want %s", tt.event, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/traefik/yaegi/interp"
|
||||
"github.com/traefik/yaegi/stdlib"
|
||||
)
|
||||
|
||||
// Discovery paths searched in order (lowest to highest precedence):
|
||||
//
|
||||
// ~/.config/kit/extensions/*.go global single files
|
||||
// ~/.config/kit/extensions/*/main.go global subdirectories
|
||||
// .kit/extensions/*.go project-local single files
|
||||
// .kit/extensions/*/main.go project-local subdirectories
|
||||
//
|
||||
// Explicit paths passed via --extension / -e flags are appended last.
|
||||
|
||||
// LoadExtensions discovers and loads extensions from standard locations and
|
||||
// any extra paths. Each extension is loaded into its own Yaegi interpreter
|
||||
// for isolation. Extensions that fail to load are logged and skipped.
|
||||
func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
paths := discoverExtensionPaths(extraPaths)
|
||||
if len(paths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var loaded []LoadedExtension
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Warn("skipping extension", "path", p, "err", err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Debug("loaded extension", "path", p,
|
||||
"handlers", countHandlers(ext),
|
||||
"tools", len(ext.Tools),
|
||||
"commands", len(ext.Commands))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// discoverExtensionPaths returns deduplicated paths to extension files in
|
||||
// load-order (global first, then project-local, then explicit).
|
||||
func discoverExtensionPaths(extraPaths []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var paths []string
|
||||
|
||||
add := func(p string) {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
seen[abs] = true
|
||||
paths = append(paths, abs)
|
||||
}
|
||||
|
||||
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
||||
globalDir := globalExtensionsDir()
|
||||
for _, p := range findExtensionsInDir(globalDir) {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
for _, found := range findExtensionsInDir(p) {
|
||||
add(found)
|
||||
}
|
||||
} else if strings.HasSuffix(p, ".go") {
|
||||
add(p)
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
||||
|
||||
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
|
||||
func findExtensionsInDir(dir string) []string {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil || !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var results []string
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
if _, err := os.Stat(main); err == nil {
|
||||
results = append(results, main)
|
||||
}
|
||||
}
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
// globalExtensionsDir returns the global extensions directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/extensions.
|
||||
func globalExtensionsDir() string {
|
||||
base := os.Getenv("XDG_CONFIG_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(base, "kit", "extensions")
|
||||
}
|
||||
|
||||
// loadSingleExtension loads one .go file into a fresh Yaegi interpreter,
|
||||
// calls the Init(ext.API) function, and returns the registered handlers.
|
||||
func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
ext := &LoadedExtension{
|
||||
Path: path,
|
||||
Handlers: make(map[EventType][]HandlerFunc),
|
||||
}
|
||||
|
||||
// Create a fresh interpreter.
|
||||
i := interp.New(interp.Options{})
|
||||
|
||||
// Expose a safe subset of the Go stdlib.
|
||||
if err := i.Use(stdlib.Symbols); err != nil {
|
||||
return nil, fmt.Errorf("loading stdlib symbols: %w", err)
|
||||
}
|
||||
|
||||
// Expose KIT's extension API types so the extension can
|
||||
// import "kit/ext" and use ext.ToolCall, ext.API, etc.
|
||||
if err := i.Use(Symbols()); err != nil {
|
||||
return nil, fmt.Errorf("loading extension symbols: %w", err)
|
||||
}
|
||||
|
||||
// Read and evaluate the extension source file.
|
||||
src, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading file: %w", err)
|
||||
}
|
||||
|
||||
if _, err := i.Eval(string(src)); err != nil {
|
||||
return nil, fmt.Errorf("evaluating source: %w", err)
|
||||
}
|
||||
|
||||
// Extract the Init function. Extensions must export:
|
||||
// func Init(api ext.API)
|
||||
initVal, err := i.Eval("Init")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no Init function: %w", err)
|
||||
}
|
||||
|
||||
initFn, ok := initVal.Interface().(func(API))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Init has wrong signature (want func(ext.API), got %T)", initVal.Interface())
|
||||
}
|
||||
|
||||
// Build the API object that wires typed registration methods back to
|
||||
// the extension's internal handler map. Each method wraps the concrete
|
||||
// handler into the internal HandlerFunc type.
|
||||
reg := func(event EventType, fn HandlerFunc) {
|
||||
ext.Handlers[event] = append(ext.Handlers[event], fn)
|
||||
}
|
||||
|
||||
api := API{
|
||||
onToolCall: func(h func(ToolCallEvent, Context) *ToolCallResult) {
|
||||
reg(ToolCall, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolCallEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
|
||||
reg(ToolExecutionStart, func(e Event, c Context) Result {
|
||||
h(e.(ToolExecutionStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolExecEnd: func(h func(ToolExecutionEndEvent, Context)) {
|
||||
reg(ToolExecutionEnd, func(e Event, c Context) Result {
|
||||
h(e.(ToolExecutionEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
|
||||
reg(ToolResult, func(e Event, c Context) Result {
|
||||
r := h(e.(ToolResultEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onInput: func(h func(InputEvent, Context) *InputResult) {
|
||||
reg(Input, func(e Event, c Context) Result {
|
||||
r := h(e.(InputEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onBeforeAgentStart: func(h func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
|
||||
reg(BeforeAgentStart, func(e Event, c Context) Result {
|
||||
r := h(e.(BeforeAgentStartEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onAgentStart: func(h func(AgentStartEvent, Context)) {
|
||||
reg(AgentStart, func(e Event, c Context) Result {
|
||||
h(e.(AgentStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onAgentEnd: func(h func(AgentEndEvent, Context)) {
|
||||
reg(AgentEnd, func(e Event, c Context) Result {
|
||||
h(e.(AgentEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageStart: func(h func(MessageStartEvent, Context)) {
|
||||
reg(MessageStart, func(e Event, c Context) Result {
|
||||
h(e.(MessageStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageUpdate: func(h func(MessageUpdateEvent, Context)) {
|
||||
reg(MessageUpdate, func(e Event, c Context) Result {
|
||||
h(e.(MessageUpdateEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onMessageEnd: func(h func(MessageEndEvent, Context)) {
|
||||
reg(MessageEnd, func(e Event, c Context) Result {
|
||||
h(e.(MessageEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSessionStart: func(h func(SessionStartEvent, Context)) {
|
||||
reg(SessionStart, func(e Event, c Context) Result {
|
||||
h(e.(SessionStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSessionShutdown: func(h func(SessionShutdownEvent, Context)) {
|
||||
reg(SessionShutdown, func(e Event, c Context) Result {
|
||||
h(e.(SessionShutdownEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
registerToolFn: func(tool ToolDef) {
|
||||
ext.Tools = append(ext.Tools, tool)
|
||||
},
|
||||
registerCmdFn: func(cmd CommandDef) {
|
||||
ext.Commands = append(ext.Commands, cmd)
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
initFn(api)
|
||||
|
||||
return ext, nil
|
||||
}
|
||||
|
||||
// countHandlers returns the total number of registered handlers across all events.
|
||||
func countHandlers(ext *LoadedExtension) int {
|
||||
n := 0
|
||||
for _, handlers := range ext.Handlers {
|
||||
n += len(handlers)
|
||||
}
|
||||
return n
|
||||
}
|
||||
@@ -0,0 +1,604 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDiscoverExtensionPaths_ExplicitFile(t *testing.T) {
|
||||
// Create a temp dir with a .go file.
|
||||
dir := t.TempDir()
|
||||
f := filepath.Join(dir, "my-ext.go")
|
||||
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := discoverExtensionPaths([]string{f})
|
||||
if len(paths) == 0 {
|
||||
t.Fatal("expected at least 1 path")
|
||||
}
|
||||
|
||||
abs, _ := filepath.Abs(f)
|
||||
found := false
|
||||
for _, p := range paths {
|
||||
if p == abs {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected %q in discovered paths %v", abs, paths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverExtensionPaths_ExplicitDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
f := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := discoverExtensionPaths([]string{dir})
|
||||
abs, _ := filepath.Abs(f)
|
||||
found := false
|
||||
for _, p := range paths {
|
||||
if p == abs {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected %q in discovered paths %v", abs, paths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverExtensionPaths_SubdirMainGo(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
subdir := filepath.Join(dir, "my-plugin")
|
||||
if err := os.MkdirAll(subdir, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
main := filepath.Join(subdir, "main.go")
|
||||
if err := os.WriteFile(main, []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := discoverExtensionPaths([]string{dir})
|
||||
abs, _ := filepath.Abs(main)
|
||||
found := false
|
||||
for _, p := range paths {
|
||||
if p == abs {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected %q in discovered paths %v", abs, paths)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverExtensionPaths_Dedup(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
f := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Pass the same file twice.
|
||||
paths := discoverExtensionPaths([]string{f, f})
|
||||
count := 0
|
||||
abs, _ := filepath.Abs(f)
|
||||
for _, p := range paths {
|
||||
if p == abs {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected dedup to 1, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverExtensionPaths_NonGoFileIgnored(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
f := filepath.Join(dir, "readme.txt")
|
||||
if err := os.WriteFile(f, []byte("hello"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
paths := discoverExtensionPaths([]string{f})
|
||||
for _, p := range paths {
|
||||
abs, _ := filepath.Abs(f)
|
||||
if p == abs {
|
||||
t.Error("non-.go file should not be discovered")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverExtensionPaths_NonexistentIgnored(t *testing.T) {
|
||||
paths := discoverExtensionPaths([]string{"/nonexistent/path/ext.go"})
|
||||
for _, p := range paths {
|
||||
if p == "/nonexistent/path/ext.go" {
|
||||
t.Error("nonexistent path should not be discovered")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindExtensionsInDir_EmptyDir(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
results := findExtensionsInDir(dir)
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 results, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindExtensionsInDir_NonexistentDir(t *testing.T) {
|
||||
results := findExtensionsInDir("/nonexistent/dir")
|
||||
if len(results) != 0 {
|
||||
t.Errorf("expected 0 results, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindExtensionsInDir_MixedContent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// .go file at top level
|
||||
if err := os.WriteFile(filepath.Join(dir, "ext.go"), []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// non-.go file (should be ignored)
|
||||
if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// subdir with main.go
|
||||
sub := filepath.Join(dir, "plugin")
|
||||
if err := os.MkdirAll(sub, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(sub, "main.go"), []byte("package main"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// subdir without main.go (should be ignored)
|
||||
empty := filepath.Join(dir, "empty")
|
||||
if err := os.MkdirAll(empty, 0755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
results := findExtensionsInDir(dir)
|
||||
if len(results) != 2 {
|
||||
t.Fatalf("expected 2 results, got %d: %v", len(results), results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_ValidExtension(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
return nil
|
||||
})
|
||||
api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) {
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "valid.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
if ext.Path != f {
|
||||
t.Errorf("expected path %q, got %q", f, ext.Path)
|
||||
}
|
||||
if len(ext.Handlers[ToolCall]) != 1 {
|
||||
t.Errorf("expected 1 ToolCall handler, got %d", len(ext.Handlers[ToolCall]))
|
||||
}
|
||||
if len(ext.Handlers[SessionStart]) != 1 {
|
||||
t.Errorf("expected 1 SessionStart handler, got %d", len(ext.Handlers[SessionStart]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_NoInitFunction(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
func Hello() string { return "hi" }
|
||||
`
|
||||
f := filepath.Join(dir, "noinit.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := loadSingleExtension(f)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing Init function")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_SyntaxError(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
func Init( { broken }
|
||||
`
|
||||
f := filepath.Join(dir, "broken.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := loadSingleExtension(f)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for syntax error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_WrongSignature(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
func Init(s string) {}
|
||||
`
|
||||
f := filepath.Join(dir, "wrongsig.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err := loadSingleExtension(f)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for wrong Init signature")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_RegistersTool(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.RegisterTool(ext.ToolDef{
|
||||
Name: "my_tool",
|
||||
Description: "does stuff",
|
||||
Parameters: "{\"type\":\"object\"}",
|
||||
Execute: func(input string) (string, error) {
|
||||
return "result: " + input, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "toolreg.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
if len(ext.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(ext.Tools))
|
||||
}
|
||||
if ext.Tools[0].Name != "my_tool" {
|
||||
t.Errorf("expected tool name 'my_tool', got %q", ext.Tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_RegistersCommand(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "hello",
|
||||
Description: "says hello",
|
||||
Execute: func(args string) (string, error) {
|
||||
return "hello " + args, nil
|
||||
},
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "cmdreg.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
if len(ext.Commands) != 1 {
|
||||
t.Fatalf("expected 1 command, got %d", len(ext.Commands))
|
||||
}
|
||||
if ext.Commands[0].Name != "hello" {
|
||||
t.Errorf("expected command name 'hello', got %q", ext.Commands[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadExtensions_SkipsBadFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Good extension
|
||||
good := `package main
|
||||
import "kit/ext"
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, _ ext.Context) {})
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(dir, "good.go"), []byte(good), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Bad extension (syntax error)
|
||||
bad := `package main
|
||||
func Init( { broken }
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(dir, "bad.go"), []byte(bad), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
loaded, err := LoadExtensions([]string{dir})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have loaded the good one and skipped the bad one.
|
||||
if len(loaded) != 1 {
|
||||
t.Fatalf("expected 1 loaded extension, got %d", len(loaded))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_HandlerExecution(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName == "banned" {
|
||||
return &ext.ToolCallResult{Block: true, Reason: "tool is banned"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "blocker.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
|
||||
// Build a runner and test the handler actually works.
|
||||
r := NewRunner([]LoadedExtension{*ext})
|
||||
result, err := r.Emit(ToolCallEvent{ToolName: "banned", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
tcr, ok := result.(ToolCallResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolCallResult, got %T", result)
|
||||
}
|
||||
if !tcr.Block {
|
||||
t.Error("expected Block=true for banned tool")
|
||||
}
|
||||
if tcr.Reason != "tool is banned" {
|
||||
t.Errorf("expected reason 'tool is banned', got %q", tcr.Reason)
|
||||
}
|
||||
|
||||
// Non-banned tool should pass through.
|
||||
result2, _ := r.Emit(ToolCallEvent{ToolName: "allowed", Input: "{}"})
|
||||
if result2 != nil {
|
||||
t.Errorf("expected nil result for allowed tool, got %v", result2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalExtensionsDir_XDG(t *testing.T) {
|
||||
// Save and restore XDG_CONFIG_HOME.
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
os.Setenv("XDG_CONFIG_HOME", "/custom/config")
|
||||
dir := globalExtensionsDir()
|
||||
expected := "/custom/config/kit/extensions"
|
||||
if dir != expected {
|
||||
t.Errorf("expected %q, got %q", expected, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalExtensionsDir_Default(t *testing.T) {
|
||||
orig := os.Getenv("XDG_CONFIG_HOME")
|
||||
defer os.Setenv("XDG_CONFIG_HOME", orig)
|
||||
|
||||
os.Setenv("XDG_CONFIG_HOME", "")
|
||||
dir := globalExtensionsDir()
|
||||
home, _ := os.UserHomeDir()
|
||||
expected := filepath.Join(home, ".config", "kit", "extensions")
|
||||
if dir != expected {
|
||||
t.Errorf("expected %q, got %q", expected, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_ContextPrint(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if ie.Text == "!hello" && ctx.Print != nil {
|
||||
ctx.Print("Hello from extension!")
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "printer.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
|
||||
// Wire up a Print function and verify it's called.
|
||||
var printed []string
|
||||
r := NewRunner([]LoadedExtension{*ext})
|
||||
r.SetContext(Context{
|
||||
Print: func(text string) {
|
||||
printed = append(printed, text)
|
||||
},
|
||||
})
|
||||
|
||||
result, err := r.Emit(InputEvent{Text: "!hello", Source: "interactive"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ir, ok := result.(InputResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected InputResult, got %T", result)
|
||||
}
|
||||
if ir.Action != "handled" {
|
||||
t.Errorf("expected Action 'handled', got %q", ir.Action)
|
||||
}
|
||||
if len(printed) != 1 || printed[0] != "Hello from extension!" {
|
||||
t.Errorf("expected Print to capture 'Hello from extension!', got %v", printed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_ContextPrintInfo(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if ie.Text == "!info" && ctx.PrintInfo != nil {
|
||||
ctx.PrintInfo("Styled info from extension")
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
if ie.Text == "!error" && ctx.PrintError != nil {
|
||||
ctx.PrintError("Styled error from extension")
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "styled.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
|
||||
var infos, errors []string
|
||||
r := NewRunner([]LoadedExtension{*ext})
|
||||
r.SetContext(Context{
|
||||
PrintInfo: func(text string) { infos = append(infos, text) },
|
||||
PrintError: func(text string) { errors = append(errors, text) },
|
||||
})
|
||||
|
||||
result, _ := r.Emit(InputEvent{Text: "!info"})
|
||||
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
|
||||
t.Fatal("expected handled result for !info")
|
||||
}
|
||||
if len(infos) != 1 || infos[0] != "Styled info from extension" {
|
||||
t.Errorf("expected PrintInfo capture, got %v", infos)
|
||||
}
|
||||
|
||||
result, _ = r.Emit(InputEvent{Text: "!error"})
|
||||
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
|
||||
t.Fatal("expected handled result for !error")
|
||||
}
|
||||
if len(errors) != 1 || errors[0] != "Styled error from extension" {
|
||||
t.Errorf("expected PrintError capture, got %v", errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadSingleExtension_ContextPrintBlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
src := `package main
|
||||
|
||||
import "kit/ext"
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if ie.Text == "!status" && ctx.PrintBlock != nil {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: "All systems go\nModel: " + ctx.Model,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "test-ext",
|
||||
})
|
||||
return &ext.InputResult{Action: "handled"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
`
|
||||
f := filepath.Join(dir, "block.go")
|
||||
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ext, err := loadSingleExtension(f)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load extension: %v", err)
|
||||
}
|
||||
|
||||
var captured []PrintBlockOpts
|
||||
r := NewRunner([]LoadedExtension{*ext})
|
||||
r.SetContext(Context{
|
||||
Model: "claude-4",
|
||||
PrintBlock: func(opts PrintBlockOpts) {
|
||||
captured = append(captured, opts)
|
||||
},
|
||||
})
|
||||
|
||||
result, _ := r.Emit(InputEvent{Text: "!status", Source: "interactive"})
|
||||
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
|
||||
t.Fatal("expected handled result for !status")
|
||||
}
|
||||
if len(captured) != 1 {
|
||||
t.Fatalf("expected 1 PrintBlock call, got %d", len(captured))
|
||||
}
|
||||
if captured[0].BorderColor != "#a6e3a1" {
|
||||
t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor)
|
||||
}
|
||||
if captured[0].Subtitle != "test-ext" {
|
||||
t.Errorf("expected subtitle 'test-ext', got %q", captured[0].Subtitle)
|
||||
}
|
||||
// Verify the text includes the model from context.
|
||||
if captured[0].Text != "All systems go\nModel: claude-4" {
|
||||
t.Errorf("unexpected text: %q", captured[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountHandlers(t *testing.T) {
|
||||
ext := &LoadedExtension{
|
||||
Handlers: map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(Event, Context) Result { return nil }, func(Event, Context) Result { return nil }},
|
||||
SessionStart: {func(Event, Context) Result { return nil }},
|
||||
},
|
||||
}
|
||||
if n := countHandlers(ext); n != 3 {
|
||||
t.Errorf("expected 3 handlers, got %d", n)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
// Runner manages loaded extensions and dispatches events to their handlers
|
||||
// sequentially, mirroring Pi's ExtensionRunner. Handlers execute in extension
|
||||
// load order; for cancellable events the first blocking result wins.
|
||||
type Runner struct {
|
||||
extensions []LoadedExtension
|
||||
ctx Context
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// LoadedExtension represents a single extension that has been discovered,
|
||||
// loaded, and initialised. It holds the registered handlers and any custom
|
||||
// tools or commands the extension provided.
|
||||
type LoadedExtension struct {
|
||||
Path string
|
||||
Handlers map[EventType][]HandlerFunc
|
||||
Tools []ToolDef
|
||||
Commands []CommandDef
|
||||
}
|
||||
|
||||
// NewRunner creates a Runner from a set of loaded extensions.
|
||||
func NewRunner(exts []LoadedExtension) *Runner {
|
||||
return &Runner{extensions: exts}
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
}
|
||||
|
||||
// HasHandlers returns true if any loaded extension has at least one handler
|
||||
// registered for the given event type.
|
||||
func (r *Runner) HasHandlers(event EventType) bool {
|
||||
for i := range r.extensions {
|
||||
if len(r.extensions[i].Handlers[event]) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Emit dispatches an event to all matching handlers sequentially. It returns
|
||||
// the accumulated result from all handlers, or nil if no handler responded.
|
||||
//
|
||||
// For blocking events (ToolCall, Input), the first blocking result short-circuits:
|
||||
// - ToolCallResult{Block: true} stops iteration and returns immediately.
|
||||
// - InputResult{Action: "handled"} stops iteration and returns immediately.
|
||||
//
|
||||
// For chainable events (ToolResult), each handler sees the accumulated result
|
||||
// from previous handlers. The final merged result is returned.
|
||||
//
|
||||
// Panics in handlers are recovered and logged; they do not crash the process.
|
||||
func (r *Runner) Emit(event Event) (Result, error) {
|
||||
r.mu.RLock()
|
||||
ctx := r.ctx
|
||||
r.mu.RUnlock()
|
||||
|
||||
var accumulated Result
|
||||
|
||||
for i := range r.extensions {
|
||||
ext := &r.extensions[i]
|
||||
handlers := ext.Handlers[event.Type()]
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
log.Warn("extension handler error",
|
||||
"path", ext.Path,
|
||||
"event", event.Type(),
|
||||
"err", err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for blocking/short-circuit results.
|
||||
if isBlocking(result) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Chain: keep the latest non-nil result. For ToolResultResult
|
||||
// the caller is responsible for applying the modifications.
|
||||
accumulated = result
|
||||
}
|
||||
}
|
||||
return accumulated, nil
|
||||
}
|
||||
|
||||
// RegisteredTools returns all custom tools registered by loaded extensions.
|
||||
func (r *Runner) RegisteredTools() []ToolDef {
|
||||
var tools []ToolDef
|
||||
for i := range r.extensions {
|
||||
tools = append(tools, r.extensions[i].Tools...)
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// RegisteredCommands returns all slash commands registered by loaded extensions.
|
||||
func (r *Runner) RegisteredCommands() []CommandDef {
|
||||
var cmds []CommandDef
|
||||
for i := range r.extensions {
|
||||
cmds = append(cmds, r.extensions[i].Commands...)
|
||||
}
|
||||
return cmds
|
||||
}
|
||||
|
||||
// Extensions returns the loaded extensions for inspection (e.g. CLI list).
|
||||
func (r *Runner) Extensions() []LoadedExtension {
|
||||
return r.extensions
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// safeCall invokes a handler, recovering from panics.
|
||||
func safeCall(handler HandlerFunc, event Event, ctx Context) (result Result, err error) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
err = fmt.Errorf("extension panicked: %v", rec)
|
||||
}
|
||||
}()
|
||||
return handler(event, ctx), nil
|
||||
}
|
||||
|
||||
// isBlocking returns true if the result should short-circuit further handlers.
|
||||
func isBlocking(result Result) bool {
|
||||
switch r := result.(type) {
|
||||
case ToolCallResult:
|
||||
return r.Block
|
||||
case InputResult:
|
||||
return r.Action == "handled"
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,573 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// makeRunner builds a Runner with the given extensions for testing.
|
||||
func makeRunner(exts ...LoadedExtension) *Runner {
|
||||
return NewRunner(exts)
|
||||
}
|
||||
|
||||
// makeHandlerExt creates a LoadedExtension with handlers registered for the given events.
|
||||
func makeHandlerExt(path string, handlers map[EventType][]HandlerFunc) LoadedExtension {
|
||||
return LoadedExtension{
|
||||
Path: path,
|
||||
Handlers: handlers,
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitNoHandlers(t *testing.T) {
|
||||
r := makeRunner()
|
||||
result, err := r.Emit(ToolCallEvent{ToolName: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != nil {
|
||||
t.Fatalf("expected nil result, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitSequentialOrder(t *testing.T) {
|
||||
var order []int
|
||||
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result { order = append(order, 1); return nil },
|
||||
},
|
||||
})
|
||||
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result { order = append(order, 2); return nil },
|
||||
},
|
||||
})
|
||||
ext3 := makeHandlerExt("ext3.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result { order = append(order, 3); return nil },
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext1, ext2, ext3)
|
||||
_, err := r.Emit(SessionStartEvent{SessionID: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 {
|
||||
t.Fatalf("expected sequential order [1,2,3], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitMultipleHandlersPerExtension(t *testing.T) {
|
||||
var calls int
|
||||
ext := makeHandlerExt("multi.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result { calls++; return nil },
|
||||
func(e Event, c Context) Result { calls++; return nil },
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Fatalf("expected 2 calls, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitToolCallBlocking(t *testing.T) {
|
||||
var secondCalled bool
|
||||
ext1 := makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result {
|
||||
return ToolCallResult{Block: true, Reason: "denied"}
|
||||
},
|
||||
},
|
||||
})
|
||||
ext2 := makeHandlerExt("second.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result {
|
||||
secondCalled = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext1, ext2)
|
||||
result, err := r.Emit(ToolCallEvent{ToolName: "bash", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if secondCalled {
|
||||
t.Error("second handler should not have been called after block")
|
||||
}
|
||||
tcr, ok := result.(ToolCallResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolCallResult, got %T", result)
|
||||
}
|
||||
if !tcr.Block {
|
||||
t.Error("expected Block=true")
|
||||
}
|
||||
if tcr.Reason != "denied" {
|
||||
t.Errorf("expected reason 'denied', got %q", tcr.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitToolCallNonBlocking(t *testing.T) {
|
||||
ext := makeHandlerExt("allow.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result {
|
||||
return ToolCallResult{Block: false}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(ToolCallEvent{ToolName: "bash"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
tcr, ok := result.(ToolCallResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolCallResult, got %T", result)
|
||||
}
|
||||
if tcr.Block {
|
||||
t.Error("expected Block=false for non-blocking result")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitInputBlocking(t *testing.T) {
|
||||
ext := makeHandlerExt("input-handler.go", map[EventType][]HandlerFunc{
|
||||
Input: {
|
||||
func(e Event, c Context) Result {
|
||||
return InputResult{Action: "handled"}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(InputEvent{Text: "secret", Source: "interactive"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ir, ok := result.(InputResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected InputResult, got %T", result)
|
||||
}
|
||||
if ir.Action != "handled" {
|
||||
t.Errorf("expected Action 'handled', got %q", ir.Action)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitInputTransform(t *testing.T) {
|
||||
ext := makeHandlerExt("transform.go", map[EventType][]HandlerFunc{
|
||||
Input: {
|
||||
func(e Event, c Context) Result {
|
||||
ie := e.(InputEvent)
|
||||
return InputResult{Action: "transform", Text: ie.Text + " transformed"}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(InputEvent{Text: "hello", Source: "cli"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
ir, ok := result.(InputResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected InputResult, got %T", result)
|
||||
}
|
||||
if ir.Action != "transform" {
|
||||
t.Errorf("expected Action 'transform', got %q", ir.Action)
|
||||
}
|
||||
if ir.Text != "hello transformed" {
|
||||
t.Errorf("expected transformed text, got %q", ir.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitToolResultChaining(t *testing.T) {
|
||||
modified := "modified content"
|
||||
ext := makeHandlerExt("modifier.go", map[EventType][]HandlerFunc{
|
||||
ToolResult: {
|
||||
func(e Event, c Context) Result {
|
||||
return ToolResultResult{Content: &modified}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(ToolResultEvent{ToolName: "read", Content: "original"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
trr, ok := result.(ToolResultResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected ToolResultResult, got %T", result)
|
||||
}
|
||||
if trr.Content == nil || *trr.Content != "modified content" {
|
||||
t.Error("expected content to be modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitPanicRecovery(t *testing.T) {
|
||||
var secondCalled bool
|
||||
ext := makeHandlerExt("panicker.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result { panic("boom") },
|
||||
func(e Event, c Context) Result { secondCalled = true; return nil },
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// After a panic, the runner should continue to the next handler.
|
||||
if !secondCalled {
|
||||
t.Error("second handler should still be called after panic in first")
|
||||
}
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitEventPassedCorrectly(t *testing.T) {
|
||||
var receivedName string
|
||||
var receivedInput string
|
||||
ext := makeHandlerExt("inspect.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result {
|
||||
tc := e.(ToolCallEvent)
|
||||
receivedName = tc.ToolName
|
||||
receivedInput = tc.Input
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, _ = r.Emit(ToolCallEvent{ToolName: "bash", ToolCallID: "123", Input: `{"cmd":"ls"}`})
|
||||
if receivedName != "bash" {
|
||||
t.Errorf("expected tool name 'bash', got %q", receivedName)
|
||||
}
|
||||
if receivedInput != `{"cmd":"ls"}` {
|
||||
t.Errorf("expected input '{\"cmd\":\"ls\"}', got %q", receivedInput)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_SetContext(t *testing.T) {
|
||||
var receivedCtx Context
|
||||
ext := makeHandlerExt("ctx.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
receivedCtx = c
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
r.SetContext(Context{
|
||||
SessionID: "sess-123",
|
||||
CWD: "/tmp",
|
||||
Model: "claude-4",
|
||||
Interactive: true,
|
||||
})
|
||||
|
||||
_, _ = r.Emit(SessionStartEvent{})
|
||||
if receivedCtx.SessionID != "sess-123" {
|
||||
t.Errorf("expected SessionID 'sess-123', got %q", receivedCtx.SessionID)
|
||||
}
|
||||
if receivedCtx.CWD != "/tmp" {
|
||||
t.Errorf("expected CWD '/tmp', got %q", receivedCtx.CWD)
|
||||
}
|
||||
if receivedCtx.Model != "claude-4" {
|
||||
t.Errorf("expected Model 'claude-4', got %q", receivedCtx.Model)
|
||||
}
|
||||
if !receivedCtx.Interactive {
|
||||
t.Error("expected Interactive=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_HasHandlers(t *testing.T) {
|
||||
ext := makeHandlerExt("test.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result { return nil },
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
if !r.HasHandlers(ToolCall) {
|
||||
t.Error("expected HasHandlers(ToolCall) = true")
|
||||
}
|
||||
if r.HasHandlers(SessionStart) {
|
||||
t.Error("expected HasHandlers(SessionStart) = false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_RegisteredTools(t *testing.T) {
|
||||
ext := LoadedExtension{
|
||||
Path: "tools.go",
|
||||
Handlers: make(map[EventType][]HandlerFunc),
|
||||
Tools: []ToolDef{
|
||||
{Name: "tool1", Description: "first"},
|
||||
{Name: "tool2", Description: "second"},
|
||||
},
|
||||
}
|
||||
|
||||
r := makeRunner(ext)
|
||||
tools := r.RegisteredTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Name != "tool1" || tools[1].Name != "tool2" {
|
||||
t.Error("tools not returned in expected order")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_RegisteredCommands(t *testing.T) {
|
||||
ext := LoadedExtension{
|
||||
Path: "cmds.go",
|
||||
Handlers: make(map[EventType][]HandlerFunc),
|
||||
Commands: []CommandDef{
|
||||
{Name: "cmd1", Description: "first"},
|
||||
},
|
||||
}
|
||||
|
||||
r := makeRunner(ext)
|
||||
cmds := r.RegisteredCommands()
|
||||
if len(cmds) != 1 {
|
||||
t.Fatalf("expected 1 command, got %d", len(cmds))
|
||||
}
|
||||
if cmds[0].Name != "cmd1" {
|
||||
t.Errorf("expected command name 'cmd1', got %q", cmds[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_Extensions(t *testing.T) {
|
||||
ext1 := makeHandlerExt("a.go", map[EventType][]HandlerFunc{})
|
||||
ext2 := makeHandlerExt("b.go", map[EventType][]HandlerFunc{})
|
||||
r := makeRunner(ext1, ext2)
|
||||
if len(r.Extensions()) != 2 {
|
||||
t.Fatalf("expected 2 extensions, got %d", len(r.Extensions()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitOnlyMatchingEvent(t *testing.T) {
|
||||
var called bool
|
||||
ext := makeHandlerExt("mismatch.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {
|
||||
func(e Event, c Context) Result { called = true; return nil },
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, _ = r.Emit(SessionStartEvent{}) // different event type
|
||||
if called {
|
||||
t.Error("ToolCall handler should not be called for SessionStart event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_EmitBeforeAgentStartResult(t *testing.T) {
|
||||
injected := "extra context"
|
||||
ext := makeHandlerExt("inject.go", map[EventType][]HandlerFunc{
|
||||
BeforeAgentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
return BeforeAgentStartResult{InjectText: &injected}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, err := r.Emit(BeforeAgentStartEvent{Prompt: "hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
bar, ok := result.(BeforeAgentStartResult)
|
||||
if !ok {
|
||||
t.Fatalf("expected BeforeAgentStartResult, got %T", result)
|
||||
}
|
||||
if bar.InjectText == nil || *bar.InjectText != "extra context" {
|
||||
t.Error("expected InjectText to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_LastResultWins(t *testing.T) {
|
||||
// When multiple handlers return non-nil, non-blocking results,
|
||||
// the last one should be returned (accumulated).
|
||||
first := "first"
|
||||
second := "second"
|
||||
ext := makeHandlerExt("chain.go", map[EventType][]HandlerFunc{
|
||||
ToolResult: {
|
||||
func(e Event, c Context) Result {
|
||||
return ToolResultResult{Content: &first}
|
||||
},
|
||||
func(e Event, c Context) Result {
|
||||
return ToolResultResult{Content: &second}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
result, _ := r.Emit(ToolResultEvent{ToolName: "test", Content: "orig"})
|
||||
trr := result.(ToolResultResult)
|
||||
if trr.Content == nil || *trr.Content != "second" {
|
||||
t.Errorf("expected last result to win, got %v", trr.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ContextPrint(t *testing.T) {
|
||||
var printed []string
|
||||
var receivedCtx Context
|
||||
ext := makeHandlerExt("print.go", map[EventType][]HandlerFunc{
|
||||
Input: {
|
||||
func(e Event, c Context) Result {
|
||||
receivedCtx = c
|
||||
if c.Print != nil {
|
||||
c.Print("hello from extension")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
r.SetContext(Context{
|
||||
Print: func(text string) {
|
||||
printed = append(printed, text)
|
||||
},
|
||||
})
|
||||
|
||||
_, _ = r.Emit(InputEvent{Text: "test"})
|
||||
if receivedCtx.Print == nil {
|
||||
t.Fatal("expected Print to be non-nil in context")
|
||||
}
|
||||
if len(printed) != 1 || printed[0] != "hello from extension" {
|
||||
t.Errorf("expected Print to capture 'hello from extension', got %v", printed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ContextPrintInfo(t *testing.T) {
|
||||
var infos []string
|
||||
ext := makeHandlerExt("info.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
if c.PrintInfo != nil {
|
||||
c.PrintInfo("extension loaded successfully")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
r.SetContext(Context{
|
||||
PrintInfo: func(text string) {
|
||||
infos = append(infos, text)
|
||||
},
|
||||
})
|
||||
|
||||
_, _ = r.Emit(SessionStartEvent{})
|
||||
if len(infos) != 1 || infos[0] != "extension loaded successfully" {
|
||||
t.Errorf("expected PrintInfo to capture message, got %v", infos)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ContextPrintError(t *testing.T) {
|
||||
var errors []string
|
||||
ext := makeHandlerExt("err.go", map[EventType][]HandlerFunc{
|
||||
ToolResult: {
|
||||
func(e Event, c Context) Result {
|
||||
tr := e.(ToolResultEvent)
|
||||
if tr.IsError && c.PrintError != nil {
|
||||
c.PrintError("tool failed: " + tr.ToolName)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
r.SetContext(Context{
|
||||
PrintError: func(text string) {
|
||||
errors = append(errors, text)
|
||||
},
|
||||
})
|
||||
|
||||
_, _ = r.Emit(ToolResultEvent{ToolName: "bash", IsError: true, Content: "exit 1"})
|
||||
if len(errors) != 1 || errors[0] != "tool failed: bash" {
|
||||
t.Errorf("expected PrintError to capture message, got %v", errors)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ContextPrintBlock(t *testing.T) {
|
||||
var captured []PrintBlockOpts
|
||||
ext := makeHandlerExt("block.go", map[EventType][]HandlerFunc{
|
||||
Input: {
|
||||
func(e Event, c Context) Result {
|
||||
if c.PrintBlock != nil {
|
||||
c.PrintBlock(PrintBlockOpts{
|
||||
Text: "deploy complete",
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "deploy-ext",
|
||||
})
|
||||
}
|
||||
return InputResult{Action: "handled"}
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
r.SetContext(Context{
|
||||
PrintBlock: func(opts PrintBlockOpts) {
|
||||
captured = append(captured, opts)
|
||||
},
|
||||
})
|
||||
|
||||
_, _ = r.Emit(InputEvent{Text: "!deploy"})
|
||||
if len(captured) != 1 {
|
||||
t.Fatalf("expected 1 PrintBlock call, got %d", len(captured))
|
||||
}
|
||||
if captured[0].Text != "deploy complete" {
|
||||
t.Errorf("expected text 'deploy complete', got %q", captured[0].Text)
|
||||
}
|
||||
if captured[0].BorderColor != "#a6e3a1" {
|
||||
t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor)
|
||||
}
|
||||
if captured[0].Subtitle != "deploy-ext" {
|
||||
t.Errorf("expected subtitle 'deploy-ext', got %q", captured[0].Subtitle)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ContextPrintNilSafe(t *testing.T) {
|
||||
// When Print/PrintInfo/PrintError/PrintBlock are not set (nil), guarded calls should not panic.
|
||||
ext := makeHandlerExt("nilprint.go", map[EventType][]HandlerFunc{
|
||||
Input: {
|
||||
func(e Event, c Context) Result {
|
||||
if c.Print != nil {
|
||||
c.Print("should not happen")
|
||||
}
|
||||
if c.PrintInfo != nil {
|
||||
c.PrintInfo("should not happen")
|
||||
}
|
||||
if c.PrintError != nil {
|
||||
c.PrintError("should not happen")
|
||||
}
|
||||
if c.PrintBlock != nil {
|
||||
c.PrintBlock(PrintBlockOpts{Text: "nope"})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
// Context without any Print functions set.
|
||||
r.SetContext(Context{Model: "test"})
|
||||
_, err := r.Emit(InputEvent{Text: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/traefik/yaegi/interp"
|
||||
)
|
||||
|
||||
// Symbols returns the Yaegi export table that makes KIT's extension API
|
||||
// available to interpreted Go code. Extensions import these types as:
|
||||
//
|
||||
// import "kit/ext"
|
||||
//
|
||||
// IMPORTANT: Only concrete types (structs, constants) are exported. Interfaces
|
||||
// (Event, Result) and the HandlerFunc type are NOT exported because Yaegi
|
||||
// cannot generate interface wrappers for them. Instead, extensions use
|
||||
// event-specific methods like api.OnToolCall() which accept concrete function
|
||||
// signatures.
|
||||
func Symbols() interp.Exports {
|
||||
return interp.Exports{
|
||||
"kit/ext/ext": map[string]reflect.Value{
|
||||
// Struct types (nil pointer trick for type registration)
|
||||
"API": reflect.ValueOf((*API)(nil)),
|
||||
"Context": reflect.ValueOf((*Context)(nil)),
|
||||
"ToolDef": reflect.ValueOf((*ToolDef)(nil)),
|
||||
"CommandDef": reflect.ValueOf((*CommandDef)(nil)),
|
||||
"PrintBlockOpts": reflect.ValueOf((*PrintBlockOpts)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
"ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)),
|
||||
"ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)),
|
||||
"ToolResultEvent": reflect.ValueOf((*ToolResultEvent)(nil)),
|
||||
"ToolResultResult": reflect.ValueOf((*ToolResultResult)(nil)),
|
||||
"InputEvent": reflect.ValueOf((*InputEvent)(nil)),
|
||||
"InputResult": reflect.ValueOf((*InputResult)(nil)),
|
||||
"BeforeAgentStartEvent": reflect.ValueOf((*BeforeAgentStartEvent)(nil)),
|
||||
"BeforeAgentStartResult": reflect.ValueOf((*BeforeAgentStartResult)(nil)),
|
||||
"AgentStartEvent": reflect.ValueOf((*AgentStartEvent)(nil)),
|
||||
"AgentEndEvent": reflect.ValueOf((*AgentEndEvent)(nil)),
|
||||
"MessageStartEvent": reflect.ValueOf((*MessageStartEvent)(nil)),
|
||||
"MessageUpdateEvent": reflect.ValueOf((*MessageUpdateEvent)(nil)),
|
||||
"MessageEndEvent": reflect.ValueOf((*MessageEndEvent)(nil)),
|
||||
"SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)),
|
||||
"SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// WrapToolsWithExtensions wraps each tool so that ToolCall and ToolResult
|
||||
// events are emitted through the extension runner before and after execution.
|
||||
// This is the Go equivalent of Pi's wrapper.ts pattern.
|
||||
//
|
||||
// If the runner has no relevant handlers the original tools are returned
|
||||
// unchanged (zero overhead).
|
||||
func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantasy.AgentTool {
|
||||
if runner == nil {
|
||||
return tools
|
||||
}
|
||||
if !runner.HasHandlers(ToolCall) && !runner.HasHandlers(ToolResult) &&
|
||||
!runner.HasHandlers(ToolExecutionStart) && !runner.HasHandlers(ToolExecutionEnd) {
|
||||
return tools
|
||||
}
|
||||
|
||||
wrapped := make([]fantasy.AgentTool, len(tools))
|
||||
for i, tool := range tools {
|
||||
wrapped[i] = &wrappedTool{inner: tool, runner: runner}
|
||||
}
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
|
||||
// into fantasy.AgentTool implementations so the LLM can invoke them.
|
||||
func ExtensionToolsAsFantasy(defs []ToolDef) []fantasy.AgentTool {
|
||||
tools := make([]fantasy.AgentTool, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
tools = append(tools, &extensionTool{def: def})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// wrappedTool — intercepts tool calls through the extension runner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type wrappedTool struct {
|
||||
inner fantasy.AgentTool
|
||||
runner *Runner
|
||||
}
|
||||
|
||||
func (w *wrappedTool) Info() fantasy.ToolInfo { return w.inner.Info() }
|
||||
func (w *wrappedTool) ProviderOptions() fantasy.ProviderOptions { return w.inner.ProviderOptions() }
|
||||
func (w *wrappedTool) SetProviderOptions(o fantasy.ProviderOptions) { w.inner.SetProviderOptions(o) }
|
||||
|
||||
func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolName := w.inner.Info().Name
|
||||
|
||||
// 1. Emit ToolCall — extensions can block execution.
|
||||
if w.runner.HasHandlers(ToolCall) {
|
||||
result, _ := w.runner.Emit(ToolCallEvent{
|
||||
ToolName: toolName,
|
||||
ToolCallID: call.ID,
|
||||
Input: call.Input,
|
||||
})
|
||||
if r, ok := result.(ToolCallResult); ok && r.Block {
|
||||
reason := r.Reason
|
||||
if reason == "" {
|
||||
reason = "blocked by extension"
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
fmt.Errorf("tool blocked by extension: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Emit ToolExecutionStart.
|
||||
if w.runner.HasHandlers(ToolExecutionStart) {
|
||||
_, _ = w.runner.Emit(ToolExecutionStartEvent{ToolName: toolName})
|
||||
}
|
||||
|
||||
// 3. Execute the actual tool.
|
||||
resp, err := w.inner.Run(ctx, call)
|
||||
|
||||
// 4. Emit ToolExecutionEnd.
|
||||
if w.runner.HasHandlers(ToolExecutionEnd) {
|
||||
_, _ = w.runner.Emit(ToolExecutionEndEvent{ToolName: toolName})
|
||||
}
|
||||
|
||||
// 5. Emit ToolResult — extensions can modify output.
|
||||
if w.runner.HasHandlers(ToolResult) {
|
||||
result, _ := w.runner.Emit(ToolResultEvent{
|
||||
ToolName: toolName,
|
||||
Input: call.Input,
|
||||
Content: resp.Content,
|
||||
IsError: err != nil || resp.IsError,
|
||||
})
|
||||
if r, ok := result.(ToolResultResult); ok {
|
||||
if r.Content != nil {
|
||||
resp.Content = *r.Content
|
||||
}
|
||||
if r.IsError != nil {
|
||||
resp.IsError = *r.IsError
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type extensionTool struct {
|
||||
def ToolDef
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
func (t *extensionTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: t.def.Name,
|
||||
Description: t.def.Description,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *extensionTool) ProviderOptions() fantasy.ProviderOptions { return t.providerOptions }
|
||||
func (t *extensionTool) SetProviderOptions(o fantasy.ProviderOptions) { t.providerOptions = o }
|
||||
|
||||
func (t *extensionTool) Run(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.def.Execute(call.Input)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), err
|
||||
}
|
||||
return fantasy.NewTextResponse(result), nil
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// mockTool implements fantasy.AgentTool for testing.
|
||||
type mockTool struct {
|
||||
name string
|
||||
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
|
||||
provOpt fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
func (m *mockTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
|
||||
}
|
||||
func (m *mockTool) ProviderOptions() fantasy.ProviderOptions { return m.provOpt }
|
||||
func (m *mockTool) SetProviderOptions(o fantasy.ProviderOptions) { m.provOpt = o }
|
||||
func (m *mockTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
if m.runFn != nil {
|
||||
return m.runFn(ctx, call)
|
||||
}
|
||||
return fantasy.NewTextResponse("ok"), nil
|
||||
}
|
||||
|
||||
func newMockTool(name string) *mockTool {
|
||||
return &mockTool{name: name}
|
||||
}
|
||||
|
||||
func TestWrapToolsWithExtensions_NilRunner(t *testing.T) {
|
||||
tools := []fantasy.AgentTool{newMockTool("test")}
|
||||
result := WrapToolsWithExtensions(tools, nil)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(result))
|
||||
}
|
||||
// Should be the same pointer (unwrapped).
|
||||
if result[0] != tools[0] {
|
||||
t.Error("expected original tool when runner is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapToolsWithExtensions_NoRelevantHandlers(t *testing.T) {
|
||||
r := makeRunner(makeHandlerExt("other.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {func(e Event, c Context) Result { return nil }},
|
||||
}))
|
||||
tools := []fantasy.AgentTool{newMockTool("test")}
|
||||
result := WrapToolsWithExtensions(tools, r)
|
||||
if result[0] != tools[0] {
|
||||
t.Error("expected original tool when no tool handlers exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapToolsWithExtensions_WrapsWhenHandlersExist(t *testing.T) {
|
||||
r := makeRunner(makeHandlerExt("tc.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(e Event, c Context) Result { return nil }},
|
||||
}))
|
||||
tools := []fantasy.AgentTool{newMockTool("test")}
|
||||
result := WrapToolsWithExtensions(tools, r)
|
||||
if result[0] == tools[0] {
|
||||
t.Error("expected wrapped tool when ToolCall handlers exist")
|
||||
}
|
||||
// Verify Info() is passed through.
|
||||
if result[0].Info().Name != "test" {
|
||||
t.Errorf("expected name 'test', got %q", result[0].Info().Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrappedTool_NormalExecution(t *testing.T) {
|
||||
var toolCallSeen, toolResultSeen bool
|
||||
r := makeRunner(makeHandlerExt("observe.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(e Event, c Context) Result {
|
||||
toolCallSeen = true
|
||||
return nil
|
||||
}},
|
||||
ToolResult: {func(e Event, c Context) Result {
|
||||
toolResultSeen = true
|
||||
return nil
|
||||
}},
|
||||
}))
|
||||
|
||||
mock := newMockTool("bash")
|
||||
mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("output"), nil
|
||||
}
|
||||
|
||||
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1", Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "output" {
|
||||
t.Errorf("expected 'output', got %q", resp.Content)
|
||||
}
|
||||
if !toolCallSeen {
|
||||
t.Error("ToolCall handler was not invoked")
|
||||
}
|
||||
if !toolResultSeen {
|
||||
t.Error("ToolResult handler was not invoked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrappedTool_BlockExecution(t *testing.T) {
|
||||
var toolRan bool
|
||||
r := makeRunner(makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(e Event, c Context) Result {
|
||||
return ToolCallResult{Block: true, Reason: "forbidden"}
|
||||
}},
|
||||
}))
|
||||
|
||||
mock := newMockTool("danger")
|
||||
mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
toolRan = true
|
||||
return fantasy.NewTextResponse("bad"), nil
|
||||
}
|
||||
|
||||
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
|
||||
if toolRan {
|
||||
t.Error("tool should not have run after block")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("expected error from blocked tool")
|
||||
}
|
||||
if resp.IsError != true {
|
||||
t.Error("expected IsError=true from blocked response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrappedTool_ModifyResult(t *testing.T) {
|
||||
modified := "redacted"
|
||||
r := makeRunner(makeHandlerExt("redactor.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(e Event, c Context) Result { return nil }},
|
||||
ToolResult: {func(e Event, c Context) Result {
|
||||
return ToolResultResult{Content: &modified}
|
||||
}},
|
||||
}))
|
||||
|
||||
mock := newMockTool("read")
|
||||
mock.runFn = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("secret data"), nil
|
||||
}
|
||||
|
||||
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "redacted" {
|
||||
t.Errorf("expected 'redacted', got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
|
||||
var startSeen, endSeen bool
|
||||
r := makeRunner(makeHandlerExt("lifecycle.go", map[EventType][]HandlerFunc{
|
||||
ToolCall: {func(e Event, c Context) Result { return nil }},
|
||||
ToolExecutionStart: {func(e Event, c Context) Result { startSeen = true; return nil }},
|
||||
ToolExecutionEnd: {func(e Event, c Context) Result { endSeen = true; return nil }},
|
||||
}))
|
||||
|
||||
tools := WrapToolsWithExtensions([]fantasy.AgentTool{newMockTool("test")}, r)
|
||||
_, _ = tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
|
||||
if !startSeen {
|
||||
t.Error("ToolExecutionStart not emitted")
|
||||
}
|
||||
if !endSeen {
|
||||
t.Error("ToolExecutionEnd not emitted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
defs := []ToolDef{
|
||||
{
|
||||
Name: "greet",
|
||||
Description: "greets someone",
|
||||
Parameters: `{"type":"object"}`,
|
||||
Execute: func(input string) (string, error) { return "hello " + input, nil },
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
|
||||
info := tools[0].Info()
|
||||
if info.Name != "greet" {
|
||||
t.Errorf("expected name 'greet', got %q", info.Name)
|
||||
}
|
||||
if info.Description != "greets someone" {
|
||||
t.Errorf("expected description 'greets someone', got %q", info.Description)
|
||||
}
|
||||
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "world"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.Content != "hello world" {
|
||||
t.Errorf("expected 'hello world', got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionTool_Error(t *testing.T) {
|
||||
defs := []ToolDef{
|
||||
{
|
||||
Name: "fail",
|
||||
Execute: func(input string) (string, error) { return "", context.DeadlineExceeded },
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionTool_ProviderOptions(t *testing.T) {
|
||||
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
|
||||
tools := ExtensionToolsAsFantasy(defs)
|
||||
|
||||
// Initially nil.
|
||||
opts := tools[0].ProviderOptions()
|
||||
if opts != nil {
|
||||
t.Error("expected nil ProviderOptions initially")
|
||||
}
|
||||
|
||||
// SetProviderOptions round-trips.
|
||||
po := fantasy.ProviderOptions{}
|
||||
tools[0].SetProviderOptions(po)
|
||||
got := tools[0].ProviderOptions()
|
||||
if got == nil {
|
||||
t.Error("expected non-nil ProviderOptions after set")
|
||||
}
|
||||
}
|
||||
@@ -177,6 +177,32 @@ func (c *CLI) DisplayInfo(message string) {
|
||||
c.displayContainer()
|
||||
}
|
||||
|
||||
// DisplayExtensionBlock renders a custom styled block with the given border
|
||||
// color and optional subtitle. Used by extensions via ctx.PrintBlock.
|
||||
func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
theme := GetTheme()
|
||||
|
||||
var borderClr = lipgloss.Color("#89b4fa")
|
||||
if borderColor != "" {
|
||||
borderClr = lipgloss.Color(borderColor)
|
||||
}
|
||||
|
||||
content := text
|
||||
if subtitle != "" {
|
||||
sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + subtitle)
|
||||
content = content + "\n" + sub
|
||||
}
|
||||
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
c.messageRenderer.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(borderClr),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
|
||||
@@ -129,6 +129,19 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
|
||||
h.lastDisplayed = e.Content
|
||||
}
|
||||
|
||||
case app.ExtensionPrintEvent:
|
||||
h.stopSpinner()
|
||||
switch e.Level {
|
||||
case "info":
|
||||
h.cli.DisplayInfo(e.Text)
|
||||
case "error":
|
||||
h.cli.DisplayError(fmt.Errorf("%s", e.Text))
|
||||
case "block":
|
||||
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
|
||||
default:
|
||||
fmt.Println(e.Text)
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
h.stopSpinner()
|
||||
|
||||
|
||||
@@ -530,6 +530,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.state = stateInput
|
||||
m.canceling = false
|
||||
|
||||
case app.ExtensionPrintEvent:
|
||||
// Extension output — route through styled renderers when a level is set.
|
||||
switch msg.Level {
|
||||
case "info":
|
||||
cmds = append(cmds, m.printSystemMessage(msg.Text))
|
||||
case "error":
|
||||
cmds = append(cmds, m.printErrorResponse(app.StepErrorEvent{
|
||||
Err: fmt.Errorf("%s", msg.Text),
|
||||
}))
|
||||
case "block":
|
||||
cmds = append(cmds, m.printExtensionBlock(msg))
|
||||
default:
|
||||
cmds = append(cmds, tea.Println(msg.Text))
|
||||
}
|
||||
|
||||
default:
|
||||
// Pass unrecognised messages to all children.
|
||||
if m.input != nil {
|
||||
@@ -791,6 +806,34 @@ func (m *AppModel) printSystemMessage(text string) tea.Cmd {
|
||||
return tea.Println(rendered)
|
||||
}
|
||||
|
||||
// printExtensionBlock renders a custom styled block from an extension with
|
||||
// caller-chosen border color and optional subtitle, then emits it to scrollback.
|
||||
func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) tea.Cmd {
|
||||
theme := GetTheme()
|
||||
|
||||
// Resolve border color: use the extension's hex value, fall back to theme accent.
|
||||
var borderClr = lipgloss.Color("#89b4fa") // default blue
|
||||
if evt.BorderColor != "" {
|
||||
borderClr = lipgloss.Color(evt.BorderColor)
|
||||
}
|
||||
|
||||
// Build content: main text + optional subtitle line.
|
||||
content := evt.Text
|
||||
if evt.Subtitle != "" {
|
||||
sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + evt.Subtitle)
|
||||
content = strings.TrimSuffix(content, "\n") + "\n" + sub
|
||||
}
|
||||
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
m.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(borderClr),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
return tea.Println(rendered)
|
||||
}
|
||||
|
||||
// printHelpMessage renders the help text listing all available slash commands.
|
||||
func (m *AppModel) printHelpMessage() tea.Cmd {
|
||||
help := "## Available Commands\n\n" +
|
||||
|
||||
Reference in New Issue
Block a user