add Yaegi-based in-process extension system with event handlers, tool/command registration, and styled output

Implement a Pi-style extension system where plain .go files are loaded at
runtime via Yaegi. Extensions register typed event handlers against 13
lifecycle events (tool_call, tool_result, input, before_agent_start, etc.)
using concrete-type-only API methods to avoid Yaegi interface panics.

Key capabilities:
- Tool interception: block calls, modify results (wrapper pattern)
- Input handling: transform or fully handle user input (skip agent)
- System prompt injection via BeforeAgentStartResult
- Custom tool and slash command registration
- Styled output: ctx.Print, PrintInfo, PrintError, PrintBlock
- Legacy hooks.yml compatibility via adapter
- Auto-discovery from ~/.config/kit/extensions/ and .kit/extensions/
- CLI: kit extensions list|validate|init, --no-extensions, -e flags
- 58 unit tests covering runner, loader (Yaegi), wrapper, events
This commit is contained in:
Ed Zynda
2026-02-27 00:08:48 +03:00
parent dd018b65ec
commit f42d487214
26 changed files with 3246 additions and 4 deletions
+175
View File
@@ -0,0 +1,175 @@
package cmd
import (
"fmt"
"os"
"text/tabwriter"
"github.com/mark3labs/kit/internal/extensions"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
var extensionsCmd = &cobra.Command{
Use: "extensions",
Short: "Manage KIT extensions",
Long: "Commands for listing, validating, and scaffolding KIT extensions",
}
var extensionsListCmd = &cobra.Command{
Use: "list",
Short: "List discovered extensions and their handlers",
RunE: func(cmd *cobra.Command, args []string) error {
loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension"))
if err != nil {
return fmt.Errorf("loading extensions: %w", err)
}
if len(loaded) == 0 {
fmt.Println("No extensions found.")
fmt.Println()
fmt.Println("Extension search paths:")
fmt.Println(" ~/.config/kit/extensions/*.go (global)")
fmt.Println(" .kit/extensions/*.go (project)")
fmt.Println()
fmt.Println("Run 'kit extensions init' to create an example extension.")
return nil
}
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
_, _ = fmt.Fprintln(w, "EXTENSION\tEVENT\tHANDLERS\tTOOLS\tCOMMANDS")
for _, ext := range loaded {
totalHandlers := 0
for _, handlers := range ext.Handlers {
totalHandlers += len(handlers)
}
first := true
for event, handlers := range ext.Handlers {
if first {
_, _ = fmt.Fprintf(w, "%s\t%s\t%d\t%d\t%d\n",
ext.Path, event, len(handlers), len(ext.Tools), len(ext.Commands))
first = false
} else {
_, _ = fmt.Fprintf(w, "\t%s\t%d\t\t\n",
event, len(handlers))
}
}
if first {
// Extension loaded but registered no handlers
_, _ = fmt.Fprintf(w, "%s\t(none)\t0\t%d\t%d\n",
ext.Path, len(ext.Tools), len(ext.Commands))
}
}
return w.Flush()
},
}
var extensionsValidateCmd = &cobra.Command{
Use: "validate",
Short: "Validate all extension files can be loaded",
RunE: func(cmd *cobra.Command, args []string) error {
loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension"))
if err != nil {
return fmt.Errorf("validation failed: %w", err)
}
fmt.Printf("Loaded %d extension(s) successfully\n", len(loaded))
for _, ext := range loaded {
total := 0
for _, h := range ext.Handlers {
total += len(h)
}
fmt.Printf(" %s (%d handlers, %d tools, %d commands)\n",
ext.Path, total, len(ext.Tools), len(ext.Commands))
}
return nil
},
}
var extensionsInitCmd = &cobra.Command{
Use: "init",
Short: "Generate an example extension file",
RunE: func(cmd *cobra.Command, args []string) error {
dir := ".kit/extensions"
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("creating extensions directory: %w", err)
}
example := `package main
import (
"fmt"
"os"
"strings"
"time"
"kit/ext"
)
// Init is called when the extension is loaded. Register handlers here.
func Init(api ext.API) {
// Log every tool call to a file.
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
f, err := os.OpenFile("/tmp/kit-tool-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
defer f.Close()
fmt.Fprintf(f, "[%s] tool=%s\n", time.Now().Format(time.RFC3339), tc.ToolName)
}
return nil // don't block
})
// Block dangerous bash commands.
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
if tc.ToolName == "bash" && strings.Contains(tc.Input, "rm -rf /") {
return &ext.ToolCallResult{Block: true, Reason: "Blocked: dangerous command"}
}
return nil
})
// Handle custom ! commands. Use ctx.Print/PrintInfo/PrintError/PrintBlock
// instead of fmt.Println — BubbleTea captures stdout in interactive mode.
//
// ctx.Print("text") — plain text
// ctx.PrintInfo("text") — styled system message block
// ctx.PrintError("text") — styled error block
// ctx.PrintBlock(opts) — custom block with border color and subtitle
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
switch ie.Text {
case "!time":
ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339))
return &ext.InputResult{Action: "handled"}
case "!status":
ctx.PrintBlock(ext.PrintBlockOpts{
Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD,
BorderColor: "#a6e3a1",
Subtitle: "my-extension",
})
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
path := dir + "/example.go"
if err := os.WriteFile(path, []byte(example), 0644); err != nil {
return fmt.Errorf("writing example: %w", err)
}
fmt.Printf("Created %s with example extension\n", path)
fmt.Println()
fmt.Println("The extension will be auto-loaded on the next kit run.")
fmt.Println("Use --no-extensions to disable all extensions.")
return nil
},
}
func init() {
rootCmd.AddCommand(extensionsCmd)
extensionsCmd.AddCommand(extensionsListCmd)
extensionsCmd.AddCommand(extensionsValidateCmd)
extensionsCmd.AddCommand(extensionsInitCmd)
}
+27 -2
View File
@@ -14,6 +14,7 @@ import (
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/ui"
"github.com/spf13/cobra"
@@ -57,7 +58,9 @@ var (
numGPU int32
mainGPU int32
// Hooks control
// Extensions control
noExtensionsFlag bool
extensionPaths []string
// TLS configuration
tlsSkipVerify bool
@@ -301,6 +304,10 @@ func init() {
BoolVarP(&resumeFlag, "resume", "r", false, "interactive session picker")
rootCmd.PersistentFlags().
BoolVar(&noSessionFlag, "no-session", false, "ephemeral mode — no session persistence")
rootCmd.PersistentFlags().
BoolVar(&noExtensionsFlag, "no-extensions", false, "disable all extensions and hooks")
rootCmd.PersistentFlags().
StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)")
flags := rootCmd.PersistentFlags()
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
@@ -338,6 +345,8 @@ func init() {
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
_ = viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu"))
_ = viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify"))
_ = viper.BindPFlag("no-extensions", rootCmd.PersistentFlags().Lookup("no-extensions"))
_ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension"))
// Defaults are already set in flag definitions, no need to duplicate in viper
@@ -542,7 +551,7 @@ func runNormalMode(ctx context.Context) error {
}
// Create the app.App instance now that session messages are loaded.
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames)
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner)
appOpts.SessionManager = sessionManager
appOpts.TreeSession = treeSession
@@ -564,6 +573,22 @@ func runNormalMode(ctx context.Context) error {
appInstance := app.New(appOpts, messages)
defer appInstance.Close()
// Emit SessionStart event to extensions.
if agentResult.ExtRunner != nil {
agentResult.ExtRunner.SetContext(extensions.Context{
CWD: cwd,
Model: modelName,
Interactive: promptFlag == "",
Print: func(text string) { appInstance.PrintFromExtension("", text) },
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
PrintBlock: appInstance.PrintBlockFromExtension,
})
if agentResult.ExtRunner.HasHandlers(extensions.SessionStart) {
_, _ = agentResult.ExtRunner.Emit(extensions.SessionStartEvent{})
}
}
// Check if running in non-interactive mode
if promptFlag != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, promptFlag, quietFlag, noExitFlag, modelName, parsedProvider, mcpAgent.GetLoadingMessage(), serverNames, toolNames, usageTracker)
+1 -1
View File
@@ -534,7 +534,7 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string,
DisplayDebugConfig(cli, mcpAgent, mcpConfig, parsedProvider)
// Build app options.
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames)
appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner)
if cli != nil {
if tracker := cli.GetUsageTracker(); tracker != nil {
appOpts.UsageTracker = tracker
+72 -1
View File
@@ -5,9 +5,13 @@ import (
"fmt"
"strings"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/hooks"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/tools"
"github.com/mark3labs/kit/internal/ui"
@@ -70,6 +74,9 @@ type AgentSetupOptions struct {
type AgentSetupResult struct {
Agent *agent.Agent
BufferedLogger *tools.BufferedDebugLogger
// ExtRunner is the extension runner (nil when --no-extensions or no
// extensions were discovered).
ExtRunner *extensions.Runner
}
// SetupAgent creates an agent from the current viper state + the provided
@@ -93,6 +100,20 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
}
}
// Load extensions unless --no-extensions is set. Extensions must be loaded
// BEFORE agent creation so their tool wrapper and custom tools are included
// in the Fantasy agent's tool list.
var extRunner *extensions.Runner
var extCreationOpts extensionCreationOpts
if !viper.GetBool("no-extensions") {
var extErr error
extRunner, extCreationOpts, extErr = setupExtensions()
if extErr != nil {
// Extension loading failures are non-fatal.
fmt.Printf("Warning: Failed to load extensions: %v\n", extErr)
}
}
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
@@ -103,6 +124,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
Quiet: quietFlag,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
ToolWrapper: extCreationOpts.toolWrapper,
ExtraTools: extCreationOpts.extraTools,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)
@@ -110,10 +133,57 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
return &AgentSetupResult{
Agent: a,
ExtRunner: extRunner,
BufferedLogger: bufferedLogger,
}, nil
}
// extensionCreationOpts holds the tool wrapper and extra tools that need to be
// passed into agent creation, extracted from loaded extensions.
type extensionCreationOpts struct {
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
extraTools []fantasy.AgentTool
}
// setupExtensions discovers and loads Yaegi extensions plus legacy hooks.yml,
// builds the runner, and returns the tool wrapper/extra tools needed by the
// agent factory.
func setupExtensions() (*extensions.Runner, extensionCreationOpts, error) {
extraPaths := viper.GetStringSlice("extension")
loaded, err := extensions.LoadExtensions(extraPaths)
if err != nil {
return nil, extensionCreationOpts{}, err
}
// Also load legacy hooks.yml as a compat extension.
hooksCfg, _ := hooks.LoadHooksConfig()
if hooksCfg != nil && len(hooksCfg.Hooks) > 0 {
compat := extensions.HooksAsExtension(hooksCfg)
if compat != nil {
loaded = append([]extensions.LoadedExtension{*compat}, loaded...)
}
}
if len(loaded) == 0 {
return nil, extensionCreationOpts{}, nil
}
runner := extensions.NewRunner(loaded)
// Build the tool wrapper that intercepts tool calls through the runner.
wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool {
return extensions.WrapToolsWithExtensions(tools, runner)
}
// Collect custom tools registered by extensions.
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools())
return runner, extensionCreationOpts{
toolWrapper: wrapper,
extraTools: extTools,
}, nil
}
// CollectAgentMetadata extracts model display info and tool/server name lists
// from the agent. This is used by both root.go and script.go to populate
// app.Options and UI setup.
@@ -138,7 +208,7 @@ func CollectAgentMetadata(mcpAgent *agent.Agent, mcpConfig *config.Config) (prov
// BuildAppOptions constructs the app.Options struct from the current state.
// Both root.go and script.go converge here after agent creation.
func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string) app.Options {
func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string, extRunner *extensions.Runner) app.Options {
return app.Options{
Agent: mcpAgent,
MCPConfig: mcpConfig,
@@ -149,6 +219,7 @@ func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName
Quiet: quietFlag,
Debug: viper.GetBool("debug"),
CompactMode: viper.GetBool("compact"),
Extensions: extRunner,
}
}
+81
View File
@@ -0,0 +1,81 @@
//go:build ignore
package main
import (
"fmt"
"os"
"time"
"kit/ext"
)
// Init registers handlers that log all tool calls and session lifecycle
// events to /tmp/kit-tool-log.txt.
func Init(api ext.API) {
logFile := "/tmp/kit-tool-log.txt"
// Log every tool call before execution.
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
defer f.Close()
fmt.Fprintf(f, "[%s] CALL tool=%s model=%s\n",
time.Now().Format(time.RFC3339), tc.ToolName, ctx.Model)
}
return nil
})
// Log tool results after execution.
api.OnToolResult(func(tr ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
defer f.Close()
status := "ok"
if tr.IsError {
status = "error"
}
fmt.Fprintf(f, "[%s] RESULT tool=%s status=%s bytes=%d\n",
time.Now().Format(time.RFC3339), tr.ToolName, status, len(tr.Content))
}
return nil // don't modify the result
})
// Log session start/shutdown.
api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) {
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
defer f.Close()
fmt.Fprintf(f, "[%s] SESSION_START cwd=%s\n",
time.Now().Format(time.RFC3339), ctx.CWD)
}
})
api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) {
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err == nil {
defer f.Close()
fmt.Fprintf(f, "[%s] SESSION_SHUTDOWN\n",
time.Now().Format(time.RFC3339))
}
})
// "!time" — prints the current time as a styled info block.
// "!status" — prints a custom block with green border and subtitle.
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
switch ie.Text {
case "!time":
ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339))
return &ext.InputResult{Action: "handled"}
case "!status":
ctx.PrintBlock(ext.PrintBlockOpts{
Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD,
BorderColor: "#a6e3a1",
Subtitle: "tool-logger extension",
})
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
+4
View File
@@ -48,6 +48,7 @@ require (
github.com/charmbracelet/colorprofile v0.4.2 // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/log v0.4.2 // indirect
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260223200540-d6a276319c45 // indirect
@@ -61,6 +62,7 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
@@ -95,6 +97,7 @@ require (
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/traefik/yaegi v0.16.1 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
@@ -108,6 +111,7 @@ require (
go.opentelemetry.io/otel/trace v1.40.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/oauth2 v0.35.0 // indirect
golang.org/x/time v0.14.0 // indirect
+6
View File
@@ -86,6 +86,8 @@ github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA=
github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig=
github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw=
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 h1:Af/L28Xh+pddhouT/6lJ7IAIYfu5tWJOB0iqt+mXsYM=
github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73/go.mod h1:E6/0abq9uG2SnM8IbLB9Y5SW09uIgfaFETk8aRzgXUQ=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
@@ -132,6 +134,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -251,6 +255,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
+19
View File
@@ -24,6 +24,15 @@ type AgentConfig struct {
MaxSteps int
StreamingEnabled bool
DebugLogger tools.DebugLogger
// ToolWrapper is an optional function that wraps the combined tool list
// before it is passed to the Fantasy agent. Used by the extensions system
// to intercept tool calls/results.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include alongside core and MCP tools.
// Used by extensions to register custom tools.
ExtraTools []fantasy.AgentTool
}
// ToolCallHandler is a function type for handling tool calls as they happen.
@@ -109,6 +118,16 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
}
}
// Append any extra tools provided by extensions.
if len(agentConfig.ExtraTools) > 0 {
allTools = append(allTools, agentConfig.ExtraTools...)
}
// Apply tool wrapper (extension interception layer) if configured.
if agentConfig.ToolWrapper != nil {
allTools = agentConfig.ToolWrapper(allTools)
}
// Build fantasy agent options
var agentOpts []fantasy.AgentOption
+8
View File
@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/tools"
@@ -34,6 +36,10 @@ type AgentCreationOptions struct {
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
// DebugLogger is an optional logger for debugging MCP communications
DebugLogger tools.DebugLogger // Optional debug logger
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include (e.g. from extensions).
ExtraTools []fantasy.AgentTool
}
// CreateAgent creates an agent with optional spinner for Ollama models.
@@ -47,6 +53,8 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
}
var agent *Agent
+131
View File
@@ -9,6 +9,7 @@ import (
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/session"
)
@@ -254,6 +255,11 @@ func (a *App) Close() {
cancel := a.cancelStep
a.mu.Unlock()
// --- Extension: SessionShutdown ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.SessionShutdown) {
_, _ = a.opts.Extensions.Emit(extensions.SessionShutdownEvent{})
}
// Cancel any in-flight step and the root context.
cancel()
a.rootCancel()
@@ -362,6 +368,23 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
}
}
// --- Extension: Input event (can transform or handle the prompt) ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.Input) {
result, _ := a.opts.Extensions.Emit(extensions.InputEvent{
Text: prompt,
Source: a.inputSource(),
})
if r, ok := result.(extensions.InputResult); ok {
switch r.Action {
case "transform":
prompt = r.Text
case "handled":
// Extension handled the input; skip the agent entirely.
return &agent.GenerateWithLoopResult{}, nil
}
}
}
// Add user message to the store immediately so history is consistent
// even if the step is later cancelled.
userMsg := fantasy.NewUserMessage(prompt)
@@ -385,9 +408,39 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
// Track message count before agent runs so we can diff new messages.
sentCount := len(msgs)
// --- Extension: BeforeAgentStart ---
// Extensions can inject a system message or prepend context text into the
// conversation before the agent runs.
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.BeforeAgentStart) {
result, _ := a.opts.Extensions.Emit(extensions.BeforeAgentStartEvent{Prompt: prompt})
if r, ok := result.(extensions.BeforeAgentStartResult); ok {
if r.SystemPrompt != nil && *r.SystemPrompt != "" {
// Prepend a system message so the LLM sees extension-provided
// instructions. This supplements (not replaces) the agent's
// configured system prompt.
msgs = append([]fantasy.Message{fantasy.NewSystemMessage(*r.SystemPrompt)}, msgs...)
}
if r.InjectText != nil && *r.InjectText != "" {
// Prepend a user message with the injected context so it
// appears early in the conversation window.
msgs = append([]fantasy.Message{fantasy.NewUserMessage(*r.InjectText)}, msgs...)
}
}
}
// --- Extension: AgentStart ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentStart) {
_, _ = a.opts.Extensions.Emit(extensions.AgentStartEvent{Prompt: prompt})
}
// Signal spinner start.
sendFn(SpinnerEvent{Show: true})
// --- Extension: MessageStart ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageStart) {
_, _ = a.opts.Extensions.Emit(extensions.MessageStartEvent{})
}
result, err := a.opts.Agent.GenerateWithLoopAndStreaming(ctx, msgs,
// onToolCall
func(toolName, toolArgs string) {
@@ -416,14 +469,42 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
},
// onStreamingResponse — spinner keeps running alongside streaming text
func(chunk string) {
// Extension: MessageUpdate (observe streaming chunks)
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageUpdate) {
_, _ = a.opts.Extensions.Emit(extensions.MessageUpdateEvent{Chunk: chunk})
}
sendFn(StreamChunkEvent{Content: chunk})
},
)
if err != nil {
// --- Extension: AgentEnd with error ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) {
_, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{
Response: "",
StopReason: "error",
})
}
return nil, err
}
// --- Extension: MessageEnd ---
responseText := ""
if result.FinalResponse != nil {
responseText = result.FinalResponse.Content.Text()
}
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageEnd) {
_, _ = a.opts.Extensions.Emit(extensions.MessageEndEvent{Content: responseText})
}
// --- Extension: AgentEnd with success ---
if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) {
_, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{
Response: responseText,
StopReason: "completed",
})
}
// Replace the store with the full updated conversation returned by the agent
// (includes tool call/result messages added during the step).
a.store.Replace(result.ConversationMessages)
@@ -439,6 +520,17 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
return result, nil
}
// inputSource returns a string identifying how the current session receives
// input — used by the Input extension event.
func (a *App) inputSource() string {
a.mu.Lock()
defer a.mu.Unlock()
if a.program != nil {
return "interactive"
}
return "cli"
}
// --------------------------------------------------------------------------
// Internal: event helpers
// --------------------------------------------------------------------------
@@ -454,6 +546,45 @@ func (a *App) sendEvent(msg tea.Msg) {
}
}
// PrintFromExtension outputs text from an extension to the user. The level
// controls styling: "" for plain text, "info" for a system message block,
// "error" for an error block. In interactive mode it sends an
// ExtensionPrintEvent through the program so the TUI can render it with the
// appropriate renderer. In non-interactive mode it falls back to stdout.
func (a *App) PrintFromExtension(level, text string) {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
return
}
// Non-interactive fallback: write directly to stdout.
fmt.Println(text)
}
// PrintBlockFromExtension outputs a custom styled block from an extension.
func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(ExtensionPrintEvent{
Text: opts.Text,
Level: "block",
BorderColor: opts.BorderColor,
Subtitle: opts.Subtitle,
})
return
}
// Non-interactive fallback.
if opts.Subtitle != "" {
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
} else {
fmt.Println(opts.Text)
}
}
// updateUsage records token usage from a completed agent step into the configured
// UsageTracker (if any). It uses the actual token counts from the agent result's
// TotalUsage field when available; otherwise it falls back to text-based estimation.
+20
View File
@@ -96,3 +96,23 @@ type MessageCreatedEvent struct {
// Message is the fantasy message that was added to the store.
Message fantasy.Message
}
// ExtensionPrintEvent is sent when an extension calls ctx.Print, ctx.PrintInfo,
// ctx.PrintError, or ctx.PrintBlock. The TUI renders it via the appropriate
// renderer and tea.Println (scrollback); the CLI handler uses
// DisplayInfo/DisplayError or plain fmt.Println. This exists because BubbleTea
// captures stdout, so plain fmt.Println inside extensions would be swallowed.
type ExtensionPrintEvent struct {
// Text is the content the extension wants to display to the user.
Text string
// Level controls the rendering style:
// "" — plain text (no styling)
// "info" — system message block (bordered, themed)
// "error" — error block (red border, bold text)
// "block" — custom block with BorderColor and Subtitle
Level string
// BorderColor is a hex color (e.g. "#a6e3a1") for Level="block".
BorderColor string
// Subtitle is optional muted text below the content for Level="block".
Subtitle string
}
+7
View File
@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/session"
)
@@ -94,4 +95,10 @@ type Options struct {
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
// agent. Satisfied by *ui.UsageTracker; wired in cmd/root.go.
UsageTracker UsageUpdater
// Extensions is the optional extension runner. When non-nil, lifecycle
// events (Input, BeforeAgentStart, AgentEnd, etc.) are emitted through
// it. Tool-level events (ToolCall, ToolResult) are handled by wrapper.go
// at the tool layer, not here.
Extensions *extensions.Runner
}
+325
View File
@@ -0,0 +1,325 @@
package extensions
// ---------------------------------------------------------------------------
// Internal types (used by runner, NOT exposed to Yaegi)
// ---------------------------------------------------------------------------
// Event is the interface satisfied by all event types internally.
type Event interface {
Type() EventType
}
// Result is the interface satisfied by all result types internally.
type Result interface {
isResult()
}
// HandlerFunc is the internal handler signature used by the runner.
type HandlerFunc func(event Event, ctx Context) Result
// ---------------------------------------------------------------------------
// Context (exposed to Yaegi — concrete struct, no interfaces)
// ---------------------------------------------------------------------------
// Context provides runtime information to handlers about the current session.
type Context struct {
SessionID string
CWD string
Model string
Interactive bool
// Print outputs plain text to the user. In interactive mode this
// routes through BubbleTea's scrollback (tea.Println); in
// non-interactive mode it writes to stdout. Extensions must use
// this instead of fmt.Println, which is swallowed by BubbleTea.
Print func(string)
// PrintInfo outputs text as a styled system message block (bordered,
// themed). Use this for informational notices the user should see.
PrintInfo func(string)
// PrintError outputs text as a styled error block (red border, bold).
// Use this for error messages or warnings.
PrintError func(string)
// PrintBlock outputs text as a custom styled block with caller-chosen
// border color and optional subtitle. Example:
//
// ctx.PrintBlock(ext.PrintBlockOpts{
// Text: "Deployment complete!",
// BorderColor: "#a6e3a1",
// Subtitle: "my-extension",
// })
PrintBlock func(PrintBlockOpts)
}
// PrintBlockOpts configures a custom styled block for PrintBlock.
type PrintBlockOpts struct {
// Text is the main content to display.
Text string
// BorderColor is a hex color string (e.g. "#a6e3a1") for the left border.
// Defaults to the theme's system color if empty.
BorderColor string
// Subtitle is optional text shown below the content in muted style
// (e.g. extension name, timestamp). Empty means no subtitle line.
Subtitle string
}
// ---------------------------------------------------------------------------
// API — the object passed to each extension's Init function.
//
// Instead of a generic On(EventType, HandlerFunc) that uses interfaces,
// we expose event-specific methods with concrete function signatures.
// This avoids Yaegi's genInterfaceWrapper crash entirely — no interfaces
// cross the Yaegi boundary.
// ---------------------------------------------------------------------------
// API is passed to each extension's Init function. Extensions use it to
// register typed event handlers, custom tools, and slash commands.
type API struct {
// Event-specific registration functions (wired by the loader).
onToolCall func(func(ToolCallEvent, Context) *ToolCallResult)
onToolExecStart func(func(ToolExecutionStartEvent, Context))
onToolExecEnd func(func(ToolExecutionEndEvent, Context))
onToolResult func(func(ToolResultEvent, Context) *ToolResultResult)
onInput func(func(InputEvent, Context) *InputResult)
onBeforeAgentStart func(func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult)
onAgentStart func(func(AgentStartEvent, Context))
onAgentEnd func(func(AgentEndEvent, Context))
onMessageStart func(func(MessageStartEvent, Context))
onMessageUpdate func(func(MessageUpdateEvent, Context))
onMessageEnd func(func(MessageEndEvent, Context))
onSessionStart func(func(SessionStartEvent, Context))
onSessionShutdown func(func(SessionShutdownEvent, Context))
registerToolFn func(ToolDef)
registerCmdFn func(CommandDef)
}
// OnToolCall registers a handler that fires before a tool executes.
// Return a non-nil ToolCallResult with Block=true to prevent execution.
func (a *API) OnToolCall(handler func(ToolCallEvent, Context) *ToolCallResult) {
a.onToolCall(handler)
}
// OnToolExecutionStart registers a handler for tool execution start.
func (a *API) OnToolExecutionStart(handler func(ToolExecutionStartEvent, Context)) {
a.onToolExecStart(handler)
}
// OnToolExecutionEnd registers a handler for tool execution end.
func (a *API) OnToolExecutionEnd(handler func(ToolExecutionEndEvent, Context)) {
a.onToolExecEnd(handler)
}
// OnToolResult registers a handler that fires after tool execution.
// Return a non-nil ToolResultResult to modify the output.
func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultResult) {
a.onToolResult(handler)
}
// OnInput registers a handler that fires when user input is received.
// Return a non-nil InputResult to transform or handle the input.
func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) {
a.onInput(handler)
}
// OnBeforeAgentStart registers a handler that fires before the agent loop.
func (a *API) OnBeforeAgentStart(handler func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
a.onBeforeAgentStart(handler)
}
// OnAgentStart registers a handler for when the agent loop begins.
func (a *API) OnAgentStart(handler func(AgentStartEvent, Context)) {
a.onAgentStart(handler)
}
// OnAgentEnd registers a handler for when the agent finishes responding.
func (a *API) OnAgentEnd(handler func(AgentEndEvent, Context)) {
a.onAgentEnd(handler)
}
// OnMessageStart registers a handler for when an assistant message begins.
func (a *API) OnMessageStart(handler func(MessageStartEvent, Context)) {
a.onMessageStart(handler)
}
// OnMessageUpdate registers a handler for streaming text chunks.
func (a *API) OnMessageUpdate(handler func(MessageUpdateEvent, Context)) {
a.onMessageUpdate(handler)
}
// OnMessageEnd registers a handler for when the assistant message is complete.
func (a *API) OnMessageEnd(handler func(MessageEndEvent, Context)) {
a.onMessageEnd(handler)
}
// OnSessionStart registers a handler for when a session is loaded or created.
func (a *API) OnSessionStart(handler func(SessionStartEvent, Context)) {
a.onSessionStart(handler)
}
// OnSessionShutdown registers a handler for when the application is closing.
func (a *API) OnSessionShutdown(handler func(SessionShutdownEvent, Context)) {
a.onSessionShutdown(handler)
}
// RegisterTool adds a custom tool that the LLM can invoke.
func (a *API) RegisterTool(tool ToolDef) {
a.registerToolFn(tool)
}
// RegisterCommand adds a slash command available in interactive mode.
func (a *API) RegisterCommand(cmd CommandDef) {
a.registerCmdFn(cmd)
}
// ---------------------------------------------------------------------------
// ToolDef / CommandDef
// ---------------------------------------------------------------------------
// ToolDef describes a custom tool registered by an extension.
type ToolDef struct {
Name string
Description string
Parameters string // JSON Schema string
Execute func(input string) (string, error)
}
// CommandDef describes a slash command registered by an extension.
type CommandDef struct {
Name string
Description string
Execute func(args string) (string, error)
}
// ---------------------------------------------------------------------------
// Typed events (all concrete structs — safe for Yaegi)
// ---------------------------------------------------------------------------
// ToolCallEvent fires before a tool executes.
type ToolCallEvent struct {
ToolName string
ToolCallID string
Input string // JSON-encoded tool parameters
}
func (e ToolCallEvent) Type() EventType { return ToolCall }
// ToolCallResult controls whether the tool call proceeds.
type ToolCallResult struct {
Block bool
Reason string
}
func (ToolCallResult) isResult() {}
// ToolExecutionStartEvent fires when a tool begins executing.
type ToolExecutionStartEvent struct {
ToolName string
}
func (e ToolExecutionStartEvent) Type() EventType { return ToolExecutionStart }
// ToolExecutionEndEvent fires when a tool finishes executing.
type ToolExecutionEndEvent struct {
ToolName string
}
func (e ToolExecutionEndEvent) Type() EventType { return ToolExecutionEnd }
// ToolResultEvent fires after tool execution with the output.
type ToolResultEvent struct {
ToolName string
Input string
Content string
IsError bool
}
func (e ToolResultEvent) Type() EventType { return ToolResult }
// ToolResultResult can modify the tool's output before it reaches the LLM.
type ToolResultResult struct {
Content *string // nil = unchanged
IsError *bool // nil = unchanged
}
func (ToolResultResult) isResult() {}
// InputEvent fires when user input is received.
type InputEvent struct {
Text string
Source string // "interactive", "cli", "script", "queue"
}
func (e InputEvent) Type() EventType { return Input }
// InputResult controls what happens with user input.
//
// Action: "continue" (default), "transform", "handled"
type InputResult struct {
Action string
Text string // replacement text when Action="transform"
}
func (InputResult) isResult() {}
// BeforeAgentStartEvent fires before the agent loop begins.
type BeforeAgentStartEvent struct {
Prompt string
}
func (e BeforeAgentStartEvent) Type() EventType { return BeforeAgentStart }
// BeforeAgentStartResult can inject context before the agent runs.
type BeforeAgentStartResult struct {
InjectText *string
SystemPrompt *string
}
func (BeforeAgentStartResult) isResult() {}
// AgentStartEvent fires when the agent loop begins.
type AgentStartEvent struct {
Prompt string
}
func (e AgentStartEvent) Type() EventType { return AgentStart }
// AgentEndEvent fires when the agent finishes responding.
type AgentEndEvent struct {
Response string
StopReason string // "completed", "cancelled", "error"
}
func (e AgentEndEvent) Type() EventType { return AgentEnd }
// MessageStartEvent fires when a new assistant message begins.
type MessageStartEvent struct{}
func (e MessageStartEvent) Type() EventType { return MessageStart }
// MessageUpdateEvent fires for each streaming text chunk.
type MessageUpdateEvent struct {
Chunk string
}
func (e MessageUpdateEvent) Type() EventType { return MessageUpdate }
// MessageEndEvent fires when the assistant message is complete.
type MessageEndEvent struct {
Content string
}
func (e MessageEndEvent) Type() EventType { return MessageEnd }
// SessionStartEvent fires when a session is loaded or created.
type SessionStartEvent struct {
SessionID string
}
func (e SessionStartEvent) Type() EventType { return SessionStart }
// SessionShutdownEvent fires when the application is closing.
type SessionShutdownEvent struct{}
func (e SessionShutdownEvent) Type() EventType { return SessionShutdown }
+111
View File
@@ -0,0 +1,111 @@
package extensions
import (
"context"
"encoding/json"
"github.com/mark3labs/kit/internal/hooks"
)
// HooksAsExtension wraps an existing hooks.HookConfig as a LoadedExtension
// so that legacy .kit/hooks.yml configurations continue to work alongside
// the new Yaegi extension system. The adapter translates the old event names
// and shell-command execution model into extension HandlerFunc handlers.
func HooksAsExtension(config *hooks.HookConfig) *LoadedExtension {
if config == nil || len(config.Hooks) == 0 {
return nil
}
ext := &LoadedExtension{
Path: "hooks.yml (compat)",
Handlers: make(map[EventType][]HandlerFunc),
}
executor := hooks.NewExecutor(config, "", "")
// Map PreToolUse → ToolCall
if matchers, ok := config.Hooks[hooks.PreToolUse]; ok && len(matchers) > 0 {
ext.Handlers[ToolCall] = []HandlerFunc{
func(event Event, _ Context) Result {
tc, ok := event.(ToolCallEvent)
if !ok {
return nil
}
input := &hooks.PreToolUseInput{
ToolName: tc.ToolName,
ToolInput: json.RawMessage(tc.Input),
}
output, err := executor.ExecuteHooks(context.Background(), hooks.PreToolUse, input)
if err != nil || output == nil {
return nil
}
if output.Decision == "block" {
return ToolCallResult{Block: true, Reason: output.Reason}
}
return nil
},
}
}
// Map PostToolUse → ToolResult
if matchers, ok := config.Hooks[hooks.PostToolUse]; ok && len(matchers) > 0 {
ext.Handlers[ToolResult] = []HandlerFunc{
func(event Event, _ Context) Result {
tr, ok := event.(ToolResultEvent)
if !ok {
return nil
}
input := &hooks.PostToolUseInput{
ToolName: tr.ToolName,
ToolInput: json.RawMessage(tr.Input),
ToolResponse: json.RawMessage(tr.Content),
}
_, _ = executor.ExecuteHooks(context.Background(), hooks.PostToolUse, input)
return nil // legacy hooks don't modify results
},
}
}
// Map UserPromptSubmit → Input
if matchers, ok := config.Hooks[hooks.UserPromptSubmit]; ok && len(matchers) > 0 {
ext.Handlers[Input] = []HandlerFunc{
func(event Event, _ Context) Result {
ie, ok := event.(InputEvent)
if !ok {
return nil
}
input := &hooks.UserPromptSubmitInput{
Prompt: ie.Text,
}
output, err := executor.ExecuteHooks(context.Background(), hooks.UserPromptSubmit, input)
if err != nil || output == nil {
return nil
}
if output.Decision == "block" {
return InputResult{Action: "handled"}
}
return nil
},
}
}
// Map Stop → AgentEnd
if matchers, ok := config.Hooks[hooks.Stop]; ok && len(matchers) > 0 {
ext.Handlers[AgentEnd] = []HandlerFunc{
func(event Event, _ Context) Result {
ae, ok := event.(AgentEndEvent)
if !ok {
return nil
}
input := &hooks.StopInput{
Response: ae.Response,
StopReason: ae.StopReason,
}
_, _ = executor.ExecuteHooks(context.Background(), hooks.Stop, input)
return nil
},
}
}
return ext
}
+69
View File
@@ -0,0 +1,69 @@
// Package extensions implements a Pi-style in-process extension system for KIT.
// Extensions are plain Go files loaded at runtime via Yaegi (a Go interpreter).
// They register event handlers using an API object, enabling tool interception,
// input transformation, and lifecycle observation — all without recompilation.
package extensions
// EventType identifies a point in KIT's lifecycle where extensions can hook in.
type EventType string
const (
// ToolCall fires before a tool executes. Handlers can block execution.
ToolCall EventType = "tool_call"
// ToolExecutionStart fires when a tool begins executing.
ToolExecutionStart EventType = "tool_execution_start"
// ToolExecutionEnd fires when a tool finishes executing.
ToolExecutionEnd EventType = "tool_execution_end"
// ToolResult fires after a tool executes. Handlers can modify the result.
ToolResult EventType = "tool_result"
// Input fires when user input is received. Handlers can transform or handle it.
Input EventType = "input"
// BeforeAgentStart fires before the agent loop begins for a prompt.
BeforeAgentStart EventType = "before_agent_start"
// AgentStart fires when the agent loop begins processing.
AgentStart EventType = "agent_start"
// AgentEnd fires when the agent finishes responding.
AgentEnd EventType = "agent_end"
// MessageStart fires when a new assistant message begins.
MessageStart EventType = "message_start"
// MessageUpdate fires for each streaming text chunk.
MessageUpdate EventType = "message_update"
// MessageEnd fires when the assistant message is complete.
MessageEnd EventType = "message_end"
// SessionStart fires when a session is loaded or created.
SessionStart EventType = "session_start"
// SessionShutdown fires when the application is closing.
SessionShutdown EventType = "session_shutdown"
)
// AllEventTypes returns every supported event type.
func AllEventTypes() []EventType {
return []EventType{
ToolCall, ToolExecutionStart, ToolExecutionEnd, ToolResult,
Input, BeforeAgentStart, AgentStart, AgentEnd,
MessageStart, MessageUpdate, MessageEnd,
SessionStart, SessionShutdown,
}
}
// IsValid returns true if the event type is a recognised lifecycle event.
func (e EventType) IsValid() bool {
for _, valid := range AllEventTypes() {
if e == valid {
return true
}
}
return false
}
+60
View File
@@ -0,0 +1,60 @@
package extensions
import "testing"
func TestAllEventTypes_Count(t *testing.T) {
all := AllEventTypes()
if len(all) != 13 {
t.Fatalf("expected 13 event types, got %d", len(all))
}
}
func TestAllEventTypes_NoDuplicates(t *testing.T) {
seen := make(map[EventType]bool)
for _, et := range AllEventTypes() {
if seen[et] {
t.Fatalf("duplicate event type: %s", et)
}
seen[et] = true
}
}
func TestEventType_IsValid(t *testing.T) {
for _, et := range AllEventTypes() {
if !et.IsValid() {
t.Errorf("expected %s to be valid", et)
}
}
invalid := EventType("nonexistent_event")
if invalid.IsValid() {
t.Error("expected 'nonexistent_event' to be invalid")
}
}
func TestEventType_TypeMethod(t *testing.T) {
tests := []struct {
event Event
want EventType
}{
{ToolCallEvent{ToolName: "test"}, ToolCall},
{ToolExecutionStartEvent{ToolName: "test"}, ToolExecutionStart},
{ToolExecutionEndEvent{ToolName: "test"}, ToolExecutionEnd},
{ToolResultEvent{ToolName: "test"}, ToolResult},
{InputEvent{Text: "hello"}, Input},
{BeforeAgentStartEvent{Prompt: "test"}, BeforeAgentStart},
{AgentStartEvent{Prompt: "test"}, AgentStart},
{AgentEndEvent{Response: "done"}, AgentEnd},
{MessageStartEvent{}, MessageStart},
{MessageUpdateEvent{Chunk: "hi"}, MessageUpdate},
{MessageEndEvent{Content: "done"}, MessageEnd},
{SessionStartEvent{SessionID: "abc"}, SessionStart},
{SessionShutdownEvent{}, SessionShutdown},
}
for _, tt := range tests {
if got := tt.event.Type(); got != tt.want {
t.Errorf("event %T.Type() = %s, want %s", tt.event, got, tt.want)
}
}
}
+301
View File
@@ -0,0 +1,301 @@
package extensions
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
"github.com/traefik/yaegi/interp"
"github.com/traefik/yaegi/stdlib"
)
// Discovery paths searched in order (lowest to highest precedence):
//
// ~/.config/kit/extensions/*.go global single files
// ~/.config/kit/extensions/*/main.go global subdirectories
// .kit/extensions/*.go project-local single files
// .kit/extensions/*/main.go project-local subdirectories
//
// Explicit paths passed via --extension / -e flags are appended last.
// LoadExtensions discovers and loads extensions from standard locations and
// any extra paths. Each extension is loaded into its own Yaegi interpreter
// for isolation. Extensions that fail to load are logged and skipped.
func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
paths := discoverExtensionPaths(extraPaths)
if len(paths) == 0 {
return nil, nil
}
var loaded []LoadedExtension
for _, p := range paths {
ext, err := loadSingleExtension(p)
if err != nil {
log.Warn("skipping extension", "path", p, "err", err)
continue
}
loaded = append(loaded, *ext)
log.Debug("loaded extension", "path", p,
"handlers", countHandlers(ext),
"tools", len(ext.Tools),
"commands", len(ext.Commands))
}
return loaded, nil
}
// discoverExtensionPaths returns deduplicated paths to extension files in
// load-order (global first, then project-local, then explicit).
func discoverExtensionPaths(extraPaths []string) []string {
seen := make(map[string]bool)
var paths []string
add := func(p string) {
abs, err := filepath.Abs(p)
if err != nil {
return
}
if seen[abs] {
return
}
seen[abs] = true
paths = append(paths, abs)
}
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
globalDir := globalExtensionsDir()
for _, p := range findExtensionsInDir(globalDir) {
add(p)
}
// Project-local extensions: .kit/extensions/
localDir := filepath.Join(".kit", "extensions")
for _, p := range findExtensionsInDir(localDir) {
add(p)
}
// Explicit paths (highest precedence)
for _, p := range extraPaths {
info, err := os.Stat(p)
if err != nil {
continue
}
if info.IsDir() {
for _, found := range findExtensionsInDir(p) {
add(found)
}
} else if strings.HasSuffix(p, ".go") {
add(p)
}
}
return paths
}
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
func findExtensionsInDir(dir string) []string {
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
return nil
}
var results []string
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
for _, entry := range entries {
full := filepath.Join(dir, entry.Name())
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
results = append(results, full)
} else if entry.IsDir() {
main := filepath.Join(full, "main.go")
if _, err := os.Stat(main); err == nil {
results = append(results, main)
}
}
}
return results
}
// globalExtensionsDir returns the global extensions directory, respecting
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/extensions.
func globalExtensionsDir() string {
base := os.Getenv("XDG_CONFIG_HOME")
if base == "" {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
base = filepath.Join(home, ".config")
}
return filepath.Join(base, "kit", "extensions")
}
// loadSingleExtension loads one .go file into a fresh Yaegi interpreter,
// calls the Init(ext.API) function, and returns the registered handlers.
func loadSingleExtension(path string) (*LoadedExtension, error) {
ext := &LoadedExtension{
Path: path,
Handlers: make(map[EventType][]HandlerFunc),
}
// Create a fresh interpreter.
i := interp.New(interp.Options{})
// Expose a safe subset of the Go stdlib.
if err := i.Use(stdlib.Symbols); err != nil {
return nil, fmt.Errorf("loading stdlib symbols: %w", err)
}
// Expose KIT's extension API types so the extension can
// import "kit/ext" and use ext.ToolCall, ext.API, etc.
if err := i.Use(Symbols()); err != nil {
return nil, fmt.Errorf("loading extension symbols: %w", err)
}
// Read and evaluate the extension source file.
src, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading file: %w", err)
}
if _, err := i.Eval(string(src)); err != nil {
return nil, fmt.Errorf("evaluating source: %w", err)
}
// Extract the Init function. Extensions must export:
// func Init(api ext.API)
initVal, err := i.Eval("Init")
if err != nil {
return nil, fmt.Errorf("no Init function: %w", err)
}
initFn, ok := initVal.Interface().(func(API))
if !ok {
return nil, fmt.Errorf("Init has wrong signature (want func(ext.API), got %T)", initVal.Interface())
}
// Build the API object that wires typed registration methods back to
// the extension's internal handler map. Each method wraps the concrete
// handler into the internal HandlerFunc type.
reg := func(event EventType, fn HandlerFunc) {
ext.Handlers[event] = append(ext.Handlers[event], fn)
}
api := API{
onToolCall: func(h func(ToolCallEvent, Context) *ToolCallResult) {
reg(ToolCall, func(e Event, c Context) Result {
r := h(e.(ToolCallEvent), c)
if r == nil {
return nil
}
return *r
})
},
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
reg(ToolExecutionStart, func(e Event, c Context) Result {
h(e.(ToolExecutionStartEvent), c)
return nil
})
},
onToolExecEnd: func(h func(ToolExecutionEndEvent, Context)) {
reg(ToolExecutionEnd, func(e Event, c Context) Result {
h(e.(ToolExecutionEndEvent), c)
return nil
})
},
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
reg(ToolResult, func(e Event, c Context) Result {
r := h(e.(ToolResultEvent), c)
if r == nil {
return nil
}
return *r
})
},
onInput: func(h func(InputEvent, Context) *InputResult) {
reg(Input, func(e Event, c Context) Result {
r := h(e.(InputEvent), c)
if r == nil {
return nil
}
return *r
})
},
onBeforeAgentStart: func(h func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
reg(BeforeAgentStart, func(e Event, c Context) Result {
r := h(e.(BeforeAgentStartEvent), c)
if r == nil {
return nil
}
return *r
})
},
onAgentStart: func(h func(AgentStartEvent, Context)) {
reg(AgentStart, func(e Event, c Context) Result {
h(e.(AgentStartEvent), c)
return nil
})
},
onAgentEnd: func(h func(AgentEndEvent, Context)) {
reg(AgentEnd, func(e Event, c Context) Result {
h(e.(AgentEndEvent), c)
return nil
})
},
onMessageStart: func(h func(MessageStartEvent, Context)) {
reg(MessageStart, func(e Event, c Context) Result {
h(e.(MessageStartEvent), c)
return nil
})
},
onMessageUpdate: func(h func(MessageUpdateEvent, Context)) {
reg(MessageUpdate, func(e Event, c Context) Result {
h(e.(MessageUpdateEvent), c)
return nil
})
},
onMessageEnd: func(h func(MessageEndEvent, Context)) {
reg(MessageEnd, func(e Event, c Context) Result {
h(e.(MessageEndEvent), c)
return nil
})
},
onSessionStart: func(h func(SessionStartEvent, Context)) {
reg(SessionStart, func(e Event, c Context) Result {
h(e.(SessionStartEvent), c)
return nil
})
},
onSessionShutdown: func(h func(SessionShutdownEvent, Context)) {
reg(SessionShutdown, func(e Event, c Context) Result {
h(e.(SessionShutdownEvent), c)
return nil
})
},
registerToolFn: func(tool ToolDef) {
ext.Tools = append(ext.Tools, tool)
},
registerCmdFn: func(cmd CommandDef) {
ext.Commands = append(ext.Commands, cmd)
},
}
// Call Init — the extension registers its handlers, tools, commands.
initFn(api)
return ext, nil
}
// countHandlers returns the total number of registered handlers across all events.
func countHandlers(ext *LoadedExtension) int {
n := 0
for _, handlers := range ext.Handlers {
n += len(handlers)
}
return n
}
+604
View File
@@ -0,0 +1,604 @@
package extensions
import (
"os"
"path/filepath"
"testing"
)
func TestDiscoverExtensionPaths_ExplicitFile(t *testing.T) {
// Create a temp dir with a .go file.
dir := t.TempDir()
f := filepath.Join(dir, "my-ext.go")
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
paths := discoverExtensionPaths([]string{f})
if len(paths) == 0 {
t.Fatal("expected at least 1 path")
}
abs, _ := filepath.Abs(f)
found := false
for _, p := range paths {
if p == abs {
found = true
break
}
}
if !found {
t.Errorf("expected %q in discovered paths %v", abs, paths)
}
}
func TestDiscoverExtensionPaths_ExplicitDir(t *testing.T) {
dir := t.TempDir()
f := filepath.Join(dir, "ext.go")
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
paths := discoverExtensionPaths([]string{dir})
abs, _ := filepath.Abs(f)
found := false
for _, p := range paths {
if p == abs {
found = true
break
}
}
if !found {
t.Errorf("expected %q in discovered paths %v", abs, paths)
}
}
func TestDiscoverExtensionPaths_SubdirMainGo(t *testing.T) {
dir := t.TempDir()
subdir := filepath.Join(dir, "my-plugin")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
main := filepath.Join(subdir, "main.go")
if err := os.WriteFile(main, []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
paths := discoverExtensionPaths([]string{dir})
abs, _ := filepath.Abs(main)
found := false
for _, p := range paths {
if p == abs {
found = true
break
}
}
if !found {
t.Errorf("expected %q in discovered paths %v", abs, paths)
}
}
func TestDiscoverExtensionPaths_Dedup(t *testing.T) {
dir := t.TempDir()
f := filepath.Join(dir, "ext.go")
if err := os.WriteFile(f, []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
// Pass the same file twice.
paths := discoverExtensionPaths([]string{f, f})
count := 0
abs, _ := filepath.Abs(f)
for _, p := range paths {
if p == abs {
count++
}
}
if count != 1 {
t.Errorf("expected dedup to 1, got %d", count)
}
}
func TestDiscoverExtensionPaths_NonGoFileIgnored(t *testing.T) {
dir := t.TempDir()
f := filepath.Join(dir, "readme.txt")
if err := os.WriteFile(f, []byte("hello"), 0644); err != nil {
t.Fatal(err)
}
paths := discoverExtensionPaths([]string{f})
for _, p := range paths {
abs, _ := filepath.Abs(f)
if p == abs {
t.Error("non-.go file should not be discovered")
}
}
}
func TestDiscoverExtensionPaths_NonexistentIgnored(t *testing.T) {
paths := discoverExtensionPaths([]string{"/nonexistent/path/ext.go"})
for _, p := range paths {
if p == "/nonexistent/path/ext.go" {
t.Error("nonexistent path should not be discovered")
}
}
}
func TestFindExtensionsInDir_EmptyDir(t *testing.T) {
dir := t.TempDir()
results := findExtensionsInDir(dir)
if len(results) != 0 {
t.Errorf("expected 0 results, got %d", len(results))
}
}
func TestFindExtensionsInDir_NonexistentDir(t *testing.T) {
results := findExtensionsInDir("/nonexistent/dir")
if len(results) != 0 {
t.Errorf("expected 0 results, got %d", len(results))
}
}
func TestFindExtensionsInDir_MixedContent(t *testing.T) {
dir := t.TempDir()
// .go file at top level
if err := os.WriteFile(filepath.Join(dir, "ext.go"), []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
// non-.go file (should be ignored)
if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0644); err != nil {
t.Fatal(err)
}
// subdir with main.go
sub := filepath.Join(dir, "plugin")
if err := os.MkdirAll(sub, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(sub, "main.go"), []byte("package main"), 0644); err != nil {
t.Fatal(err)
}
// subdir without main.go (should be ignored)
empty := filepath.Join(dir, "empty")
if err := os.MkdirAll(empty, 0755); err != nil {
t.Fatal(err)
}
results := findExtensionsInDir(dir)
if len(results) != 2 {
t.Fatalf("expected 2 results, got %d: %v", len(results), results)
}
}
func TestLoadSingleExtension_ValidExtension(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
return nil
})
api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) {
})
}
`
f := filepath.Join(dir, "valid.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
if ext.Path != f {
t.Errorf("expected path %q, got %q", f, ext.Path)
}
if len(ext.Handlers[ToolCall]) != 1 {
t.Errorf("expected 1 ToolCall handler, got %d", len(ext.Handlers[ToolCall]))
}
if len(ext.Handlers[SessionStart]) != 1 {
t.Errorf("expected 1 SessionStart handler, got %d", len(ext.Handlers[SessionStart]))
}
}
func TestLoadSingleExtension_NoInitFunction(t *testing.T) {
dir := t.TempDir()
src := `package main
func Hello() string { return "hi" }
`
f := filepath.Join(dir, "noinit.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
_, err := loadSingleExtension(f)
if err == nil {
t.Fatal("expected error for missing Init function")
}
}
func TestLoadSingleExtension_SyntaxError(t *testing.T) {
dir := t.TempDir()
src := `package main
func Init( { broken }
`
f := filepath.Join(dir, "broken.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
_, err := loadSingleExtension(f)
if err == nil {
t.Fatal("expected error for syntax error")
}
}
func TestLoadSingleExtension_WrongSignature(t *testing.T) {
dir := t.TempDir()
src := `package main
func Init(s string) {}
`
f := filepath.Join(dir, "wrongsig.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
_, err := loadSingleExtension(f)
if err == nil {
t.Fatal("expected error for wrong Init signature")
}
}
func TestLoadSingleExtension_RegistersTool(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.RegisterTool(ext.ToolDef{
Name: "my_tool",
Description: "does stuff",
Parameters: "{\"type\":\"object\"}",
Execute: func(input string) (string, error) {
return "result: " + input, nil
},
})
}
`
f := filepath.Join(dir, "toolreg.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
if len(ext.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(ext.Tools))
}
if ext.Tools[0].Name != "my_tool" {
t.Errorf("expected tool name 'my_tool', got %q", ext.Tools[0].Name)
}
}
func TestLoadSingleExtension_RegistersCommand(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.RegisterCommand(ext.CommandDef{
Name: "hello",
Description: "says hello",
Execute: func(args string) (string, error) {
return "hello " + args, nil
},
})
}
`
f := filepath.Join(dir, "cmdreg.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
if len(ext.Commands) != 1 {
t.Fatalf("expected 1 command, got %d", len(ext.Commands))
}
if ext.Commands[0].Name != "hello" {
t.Errorf("expected command name 'hello', got %q", ext.Commands[0].Name)
}
}
func TestLoadExtensions_SkipsBadFiles(t *testing.T) {
dir := t.TempDir()
// Good extension
good := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnSessionStart(func(_ ext.SessionStartEvent, _ ext.Context) {})
}
`
if err := os.WriteFile(filepath.Join(dir, "good.go"), []byte(good), 0644); err != nil {
t.Fatal(err)
}
// Bad extension (syntax error)
bad := `package main
func Init( { broken }
`
if err := os.WriteFile(filepath.Join(dir, "bad.go"), []byte(bad), 0644); err != nil {
t.Fatal(err)
}
loaded, err := LoadExtensions([]string{dir})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Should have loaded the good one and skipped the bad one.
if len(loaded) != 1 {
t.Fatalf("expected 1 loaded extension, got %d", len(loaded))
}
}
func TestLoadSingleExtension_HandlerExecution(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
if tc.ToolName == "banned" {
return &ext.ToolCallResult{Block: true, Reason: "tool is banned"}
}
return nil
})
}
`
f := filepath.Join(dir, "blocker.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
// Build a runner and test the handler actually works.
r := NewRunner([]LoadedExtension{*ext})
result, err := r.Emit(ToolCallEvent{ToolName: "banned", Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tcr, ok := result.(ToolCallResult)
if !ok {
t.Fatalf("expected ToolCallResult, got %T", result)
}
if !tcr.Block {
t.Error("expected Block=true for banned tool")
}
if tcr.Reason != "tool is banned" {
t.Errorf("expected reason 'tool is banned', got %q", tcr.Reason)
}
// Non-banned tool should pass through.
result2, _ := r.Emit(ToolCallEvent{ToolName: "allowed", Input: "{}"})
if result2 != nil {
t.Errorf("expected nil result for allowed tool, got %v", result2)
}
}
func TestGlobalExtensionsDir_XDG(t *testing.T) {
// Save and restore XDG_CONFIG_HOME.
orig := os.Getenv("XDG_CONFIG_HOME")
defer os.Setenv("XDG_CONFIG_HOME", orig)
os.Setenv("XDG_CONFIG_HOME", "/custom/config")
dir := globalExtensionsDir()
expected := "/custom/config/kit/extensions"
if dir != expected {
t.Errorf("expected %q, got %q", expected, dir)
}
}
func TestGlobalExtensionsDir_Default(t *testing.T) {
orig := os.Getenv("XDG_CONFIG_HOME")
defer os.Setenv("XDG_CONFIG_HOME", orig)
os.Setenv("XDG_CONFIG_HOME", "")
dir := globalExtensionsDir()
home, _ := os.UserHomeDir()
expected := filepath.Join(home, ".config", "kit", "extensions")
if dir != expected {
t.Errorf("expected %q, got %q", expected, dir)
}
}
func TestLoadSingleExtension_ContextPrint(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!hello" && ctx.Print != nil {
ctx.Print("Hello from extension!")
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
f := filepath.Join(dir, "printer.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
// Wire up a Print function and verify it's called.
var printed []string
r := NewRunner([]LoadedExtension{*ext})
r.SetContext(Context{
Print: func(text string) {
printed = append(printed, text)
},
})
result, err := r.Emit(InputEvent{Text: "!hello", Source: "interactive"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ir, ok := result.(InputResult)
if !ok {
t.Fatalf("expected InputResult, got %T", result)
}
if ir.Action != "handled" {
t.Errorf("expected Action 'handled', got %q", ir.Action)
}
if len(printed) != 1 || printed[0] != "Hello from extension!" {
t.Errorf("expected Print to capture 'Hello from extension!', got %v", printed)
}
}
func TestLoadSingleExtension_ContextPrintInfo(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!info" && ctx.PrintInfo != nil {
ctx.PrintInfo("Styled info from extension")
return &ext.InputResult{Action: "handled"}
}
if ie.Text == "!error" && ctx.PrintError != nil {
ctx.PrintError("Styled error from extension")
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
f := filepath.Join(dir, "styled.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
var infos, errors []string
r := NewRunner([]LoadedExtension{*ext})
r.SetContext(Context{
PrintInfo: func(text string) { infos = append(infos, text) },
PrintError: func(text string) { errors = append(errors, text) },
})
result, _ := r.Emit(InputEvent{Text: "!info"})
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
t.Fatal("expected handled result for !info")
}
if len(infos) != 1 || infos[0] != "Styled info from extension" {
t.Errorf("expected PrintInfo capture, got %v", infos)
}
result, _ = r.Emit(InputEvent{Text: "!error"})
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
t.Fatal("expected handled result for !error")
}
if len(errors) != 1 || errors[0] != "Styled error from extension" {
t.Errorf("expected PrintError capture, got %v", errors)
}
}
func TestLoadSingleExtension_ContextPrintBlock(t *testing.T) {
dir := t.TempDir()
src := `package main
import "kit/ext"
func Init(api ext.API) {
api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult {
if ie.Text == "!status" && ctx.PrintBlock != nil {
ctx.PrintBlock(ext.PrintBlockOpts{
Text: "All systems go\nModel: " + ctx.Model,
BorderColor: "#a6e3a1",
Subtitle: "test-ext",
})
return &ext.InputResult{Action: "handled"}
}
return nil
})
}
`
f := filepath.Join(dir, "block.go")
if err := os.WriteFile(f, []byte(src), 0644); err != nil {
t.Fatal(err)
}
ext, err := loadSingleExtension(f)
if err != nil {
t.Fatalf("failed to load extension: %v", err)
}
var captured []PrintBlockOpts
r := NewRunner([]LoadedExtension{*ext})
r.SetContext(Context{
Model: "claude-4",
PrintBlock: func(opts PrintBlockOpts) {
captured = append(captured, opts)
},
})
result, _ := r.Emit(InputEvent{Text: "!status", Source: "interactive"})
if ir, ok := result.(InputResult); !ok || ir.Action != "handled" {
t.Fatal("expected handled result for !status")
}
if len(captured) != 1 {
t.Fatalf("expected 1 PrintBlock call, got %d", len(captured))
}
if captured[0].BorderColor != "#a6e3a1" {
t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor)
}
if captured[0].Subtitle != "test-ext" {
t.Errorf("expected subtitle 'test-ext', got %q", captured[0].Subtitle)
}
// Verify the text includes the model from context.
if captured[0].Text != "All systems go\nModel: claude-4" {
t.Errorf("unexpected text: %q", captured[0].Text)
}
}
func TestCountHandlers(t *testing.T) {
ext := &LoadedExtension{
Handlers: map[EventType][]HandlerFunc{
ToolCall: {func(Event, Context) Result { return nil }, func(Event, Context) Result { return nil }},
SessionStart: {func(Event, Context) Result { return nil }},
},
}
if n := countHandlers(ext); n != 3 {
t.Errorf("expected 3 handlers, got %d", n)
}
}
+146
View File
@@ -0,0 +1,146 @@
package extensions
import (
"fmt"
"sync"
"github.com/charmbracelet/log"
)
// Runner manages loaded extensions and dispatches events to their handlers
// sequentially, mirroring Pi's ExtensionRunner. Handlers execute in extension
// load order; for cancellable events the first blocking result wins.
type Runner struct {
extensions []LoadedExtension
ctx Context
mu sync.RWMutex
}
// LoadedExtension represents a single extension that has been discovered,
// loaded, and initialised. It holds the registered handlers and any custom
// tools or commands the extension provided.
type LoadedExtension struct {
Path string
Handlers map[EventType][]HandlerFunc
Tools []ToolDef
Commands []CommandDef
}
// NewRunner creates a Runner from a set of loaded extensions.
func NewRunner(exts []LoadedExtension) *Runner {
return &Runner{extensions: exts}
}
// SetContext updates the runtime context (session ID, model, etc.) that is
// passed to every handler invocation. Thread-safe.
func (r *Runner) SetContext(ctx Context) {
r.mu.Lock()
defer r.mu.Unlock()
r.ctx = ctx
}
// HasHandlers returns true if any loaded extension has at least one handler
// registered for the given event type.
func (r *Runner) HasHandlers(event EventType) bool {
for i := range r.extensions {
if len(r.extensions[i].Handlers[event]) > 0 {
return true
}
}
return false
}
// Emit dispatches an event to all matching handlers sequentially. It returns
// the accumulated result from all handlers, or nil if no handler responded.
//
// For blocking events (ToolCall, Input), the first blocking result short-circuits:
// - ToolCallResult{Block: true} stops iteration and returns immediately.
// - InputResult{Action: "handled"} stops iteration and returns immediately.
//
// For chainable events (ToolResult), each handler sees the accumulated result
// from previous handlers. The final merged result is returned.
//
// Panics in handlers are recovered and logged; they do not crash the process.
func (r *Runner) Emit(event Event) (Result, error) {
r.mu.RLock()
ctx := r.ctx
r.mu.RUnlock()
var accumulated Result
for i := range r.extensions {
ext := &r.extensions[i]
handlers := ext.Handlers[event.Type()]
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
log.Warn("extension handler error",
"path", ext.Path,
"event", event.Type(),
"err", err)
continue
}
if result == nil {
continue
}
// Check for blocking/short-circuit results.
if isBlocking(result) {
return result, nil
}
// Chain: keep the latest non-nil result. For ToolResultResult
// the caller is responsible for applying the modifications.
accumulated = result
}
}
return accumulated, nil
}
// RegisteredTools returns all custom tools registered by loaded extensions.
func (r *Runner) RegisteredTools() []ToolDef {
var tools []ToolDef
for i := range r.extensions {
tools = append(tools, r.extensions[i].Tools...)
}
return tools
}
// RegisteredCommands returns all slash commands registered by loaded extensions.
func (r *Runner) RegisteredCommands() []CommandDef {
var cmds []CommandDef
for i := range r.extensions {
cmds = append(cmds, r.extensions[i].Commands...)
}
return cmds
}
// Extensions returns the loaded extensions for inspection (e.g. CLI list).
func (r *Runner) Extensions() []LoadedExtension {
return r.extensions
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
// safeCall invokes a handler, recovering from panics.
func safeCall(handler HandlerFunc, event Event, ctx Context) (result Result, err error) {
defer func() {
if rec := recover(); rec != nil {
err = fmt.Errorf("extension panicked: %v", rec)
}
}()
return handler(event, ctx), nil
}
// isBlocking returns true if the result should short-circuit further handlers.
func isBlocking(result Result) bool {
switch r := result.(type) {
case ToolCallResult:
return r.Block
case InputResult:
return r.Action == "handled"
}
return false
}
+573
View File
@@ -0,0 +1,573 @@
package extensions
import (
"testing"
)
// makeRunner builds a Runner with the given extensions for testing.
func makeRunner(exts ...LoadedExtension) *Runner {
return NewRunner(exts)
}
// makeHandlerExt creates a LoadedExtension with handlers registered for the given events.
func makeHandlerExt(path string, handlers map[EventType][]HandlerFunc) LoadedExtension {
return LoadedExtension{
Path: path,
Handlers: handlers,
}
}
func TestRunner_EmitNoHandlers(t *testing.T) {
r := makeRunner()
result, err := r.Emit(ToolCallEvent{ToolName: "test"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != nil {
t.Fatalf("expected nil result, got %v", result)
}
}
func TestRunner_EmitSequentialOrder(t *testing.T) {
var order []int
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result { order = append(order, 1); return nil },
},
})
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result { order = append(order, 2); return nil },
},
})
ext3 := makeHandlerExt("ext3.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result { order = append(order, 3); return nil },
},
})
r := makeRunner(ext1, ext2, ext3)
_, err := r.Emit(SessionStartEvent{SessionID: "test"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Fatalf("expected sequential order [1,2,3], got %v", order)
}
}
func TestRunner_EmitMultipleHandlersPerExtension(t *testing.T) {
var calls int
ext := makeHandlerExt("multi.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result { calls++; return nil },
func(e Event, c Context) Result { calls++; return nil },
},
})
r := makeRunner(ext)
_, err := r.Emit(SessionStartEvent{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if calls != 2 {
t.Fatalf("expected 2 calls, got %d", calls)
}
}
func TestRunner_EmitToolCallBlocking(t *testing.T) {
var secondCalled bool
ext1 := makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result {
return ToolCallResult{Block: true, Reason: "denied"}
},
},
})
ext2 := makeHandlerExt("second.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result {
secondCalled = true
return nil
},
},
})
r := makeRunner(ext1, ext2)
result, err := r.Emit(ToolCallEvent{ToolName: "bash", Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if secondCalled {
t.Error("second handler should not have been called after block")
}
tcr, ok := result.(ToolCallResult)
if !ok {
t.Fatalf("expected ToolCallResult, got %T", result)
}
if !tcr.Block {
t.Error("expected Block=true")
}
if tcr.Reason != "denied" {
t.Errorf("expected reason 'denied', got %q", tcr.Reason)
}
}
func TestRunner_EmitToolCallNonBlocking(t *testing.T) {
ext := makeHandlerExt("allow.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result {
return ToolCallResult{Block: false}
},
},
})
r := makeRunner(ext)
result, err := r.Emit(ToolCallEvent{ToolName: "bash"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tcr, ok := result.(ToolCallResult)
if !ok {
t.Fatalf("expected ToolCallResult, got %T", result)
}
if tcr.Block {
t.Error("expected Block=false for non-blocking result")
}
}
func TestRunner_EmitInputBlocking(t *testing.T) {
ext := makeHandlerExt("input-handler.go", map[EventType][]HandlerFunc{
Input: {
func(e Event, c Context) Result {
return InputResult{Action: "handled"}
},
},
})
r := makeRunner(ext)
result, err := r.Emit(InputEvent{Text: "secret", Source: "interactive"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ir, ok := result.(InputResult)
if !ok {
t.Fatalf("expected InputResult, got %T", result)
}
if ir.Action != "handled" {
t.Errorf("expected Action 'handled', got %q", ir.Action)
}
}
func TestRunner_EmitInputTransform(t *testing.T) {
ext := makeHandlerExt("transform.go", map[EventType][]HandlerFunc{
Input: {
func(e Event, c Context) Result {
ie := e.(InputEvent)
return InputResult{Action: "transform", Text: ie.Text + " transformed"}
},
},
})
r := makeRunner(ext)
result, err := r.Emit(InputEvent{Text: "hello", Source: "cli"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
ir, ok := result.(InputResult)
if !ok {
t.Fatalf("expected InputResult, got %T", result)
}
if ir.Action != "transform" {
t.Errorf("expected Action 'transform', got %q", ir.Action)
}
if ir.Text != "hello transformed" {
t.Errorf("expected transformed text, got %q", ir.Text)
}
}
func TestRunner_EmitToolResultChaining(t *testing.T) {
modified := "modified content"
ext := makeHandlerExt("modifier.go", map[EventType][]HandlerFunc{
ToolResult: {
func(e Event, c Context) Result {
return ToolResultResult{Content: &modified}
},
},
})
r := makeRunner(ext)
result, err := r.Emit(ToolResultEvent{ToolName: "read", Content: "original"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
trr, ok := result.(ToolResultResult)
if !ok {
t.Fatalf("expected ToolResultResult, got %T", result)
}
if trr.Content == nil || *trr.Content != "modified content" {
t.Error("expected content to be modified")
}
}
func TestRunner_EmitPanicRecovery(t *testing.T) {
var secondCalled bool
ext := makeHandlerExt("panicker.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result { panic("boom") },
func(e Event, c Context) Result { secondCalled = true; return nil },
},
})
r := makeRunner(ext)
result, err := r.Emit(SessionStartEvent{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// After a panic, the runner should continue to the next handler.
if !secondCalled {
t.Error("second handler should still be called after panic in first")
}
if result != nil {
t.Errorf("expected nil result, got %v", result)
}
}
func TestRunner_EmitEventPassedCorrectly(t *testing.T) {
var receivedName string
var receivedInput string
ext := makeHandlerExt("inspect.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result {
tc := e.(ToolCallEvent)
receivedName = tc.ToolName
receivedInput = tc.Input
return nil
},
},
})
r := makeRunner(ext)
_, _ = r.Emit(ToolCallEvent{ToolName: "bash", ToolCallID: "123", Input: `{"cmd":"ls"}`})
if receivedName != "bash" {
t.Errorf("expected tool name 'bash', got %q", receivedName)
}
if receivedInput != `{"cmd":"ls"}` {
t.Errorf("expected input '{\"cmd\":\"ls\"}', got %q", receivedInput)
}
}
func TestRunner_SetContext(t *testing.T) {
var receivedCtx Context
ext := makeHandlerExt("ctx.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result {
receivedCtx = c
return nil
},
},
})
r := makeRunner(ext)
r.SetContext(Context{
SessionID: "sess-123",
CWD: "/tmp",
Model: "claude-4",
Interactive: true,
})
_, _ = r.Emit(SessionStartEvent{})
if receivedCtx.SessionID != "sess-123" {
t.Errorf("expected SessionID 'sess-123', got %q", receivedCtx.SessionID)
}
if receivedCtx.CWD != "/tmp" {
t.Errorf("expected CWD '/tmp', got %q", receivedCtx.CWD)
}
if receivedCtx.Model != "claude-4" {
t.Errorf("expected Model 'claude-4', got %q", receivedCtx.Model)
}
if !receivedCtx.Interactive {
t.Error("expected Interactive=true")
}
}
func TestRunner_HasHandlers(t *testing.T) {
ext := makeHandlerExt("test.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result { return nil },
},
})
r := makeRunner(ext)
if !r.HasHandlers(ToolCall) {
t.Error("expected HasHandlers(ToolCall) = true")
}
if r.HasHandlers(SessionStart) {
t.Error("expected HasHandlers(SessionStart) = false")
}
}
func TestRunner_RegisteredTools(t *testing.T) {
ext := LoadedExtension{
Path: "tools.go",
Handlers: make(map[EventType][]HandlerFunc),
Tools: []ToolDef{
{Name: "tool1", Description: "first"},
{Name: "tool2", Description: "second"},
},
}
r := makeRunner(ext)
tools := r.RegisteredTools()
if len(tools) != 2 {
t.Fatalf("expected 2 tools, got %d", len(tools))
}
if tools[0].Name != "tool1" || tools[1].Name != "tool2" {
t.Error("tools not returned in expected order")
}
}
func TestRunner_RegisteredCommands(t *testing.T) {
ext := LoadedExtension{
Path: "cmds.go",
Handlers: make(map[EventType][]HandlerFunc),
Commands: []CommandDef{
{Name: "cmd1", Description: "first"},
},
}
r := makeRunner(ext)
cmds := r.RegisteredCommands()
if len(cmds) != 1 {
t.Fatalf("expected 1 command, got %d", len(cmds))
}
if cmds[0].Name != "cmd1" {
t.Errorf("expected command name 'cmd1', got %q", cmds[0].Name)
}
}
func TestRunner_Extensions(t *testing.T) {
ext1 := makeHandlerExt("a.go", map[EventType][]HandlerFunc{})
ext2 := makeHandlerExt("b.go", map[EventType][]HandlerFunc{})
r := makeRunner(ext1, ext2)
if len(r.Extensions()) != 2 {
t.Fatalf("expected 2 extensions, got %d", len(r.Extensions()))
}
}
func TestRunner_EmitOnlyMatchingEvent(t *testing.T) {
var called bool
ext := makeHandlerExt("mismatch.go", map[EventType][]HandlerFunc{
ToolCall: {
func(e Event, c Context) Result { called = true; return nil },
},
})
r := makeRunner(ext)
_, _ = r.Emit(SessionStartEvent{}) // different event type
if called {
t.Error("ToolCall handler should not be called for SessionStart event")
}
}
func TestRunner_EmitBeforeAgentStartResult(t *testing.T) {
injected := "extra context"
ext := makeHandlerExt("inject.go", map[EventType][]HandlerFunc{
BeforeAgentStart: {
func(e Event, c Context) Result {
return BeforeAgentStartResult{InjectText: &injected}
},
},
})
r := makeRunner(ext)
result, err := r.Emit(BeforeAgentStartEvent{Prompt: "hello"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
bar, ok := result.(BeforeAgentStartResult)
if !ok {
t.Fatalf("expected BeforeAgentStartResult, got %T", result)
}
if bar.InjectText == nil || *bar.InjectText != "extra context" {
t.Error("expected InjectText to be set")
}
}
func TestRunner_LastResultWins(t *testing.T) {
// When multiple handlers return non-nil, non-blocking results,
// the last one should be returned (accumulated).
first := "first"
second := "second"
ext := makeHandlerExt("chain.go", map[EventType][]HandlerFunc{
ToolResult: {
func(e Event, c Context) Result {
return ToolResultResult{Content: &first}
},
func(e Event, c Context) Result {
return ToolResultResult{Content: &second}
},
},
})
r := makeRunner(ext)
result, _ := r.Emit(ToolResultEvent{ToolName: "test", Content: "orig"})
trr := result.(ToolResultResult)
if trr.Content == nil || *trr.Content != "second" {
t.Errorf("expected last result to win, got %v", trr.Content)
}
}
func TestRunner_ContextPrint(t *testing.T) {
var printed []string
var receivedCtx Context
ext := makeHandlerExt("print.go", map[EventType][]HandlerFunc{
Input: {
func(e Event, c Context) Result {
receivedCtx = c
if c.Print != nil {
c.Print("hello from extension")
}
return nil
},
},
})
r := makeRunner(ext)
r.SetContext(Context{
Print: func(text string) {
printed = append(printed, text)
},
})
_, _ = r.Emit(InputEvent{Text: "test"})
if receivedCtx.Print == nil {
t.Fatal("expected Print to be non-nil in context")
}
if len(printed) != 1 || printed[0] != "hello from extension" {
t.Errorf("expected Print to capture 'hello from extension', got %v", printed)
}
}
func TestRunner_ContextPrintInfo(t *testing.T) {
var infos []string
ext := makeHandlerExt("info.go", map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result {
if c.PrintInfo != nil {
c.PrintInfo("extension loaded successfully")
}
return nil
},
},
})
r := makeRunner(ext)
r.SetContext(Context{
PrintInfo: func(text string) {
infos = append(infos, text)
},
})
_, _ = r.Emit(SessionStartEvent{})
if len(infos) != 1 || infos[0] != "extension loaded successfully" {
t.Errorf("expected PrintInfo to capture message, got %v", infos)
}
}
func TestRunner_ContextPrintError(t *testing.T) {
var errors []string
ext := makeHandlerExt("err.go", map[EventType][]HandlerFunc{
ToolResult: {
func(e Event, c Context) Result {
tr := e.(ToolResultEvent)
if tr.IsError && c.PrintError != nil {
c.PrintError("tool failed: " + tr.ToolName)
}
return nil
},
},
})
r := makeRunner(ext)
r.SetContext(Context{
PrintError: func(text string) {
errors = append(errors, text)
},
})
_, _ = r.Emit(ToolResultEvent{ToolName: "bash", IsError: true, Content: "exit 1"})
if len(errors) != 1 || errors[0] != "tool failed: bash" {
t.Errorf("expected PrintError to capture message, got %v", errors)
}
}
func TestRunner_ContextPrintBlock(t *testing.T) {
var captured []PrintBlockOpts
ext := makeHandlerExt("block.go", map[EventType][]HandlerFunc{
Input: {
func(e Event, c Context) Result {
if c.PrintBlock != nil {
c.PrintBlock(PrintBlockOpts{
Text: "deploy complete",
BorderColor: "#a6e3a1",
Subtitle: "deploy-ext",
})
}
return InputResult{Action: "handled"}
},
},
})
r := makeRunner(ext)
r.SetContext(Context{
PrintBlock: func(opts PrintBlockOpts) {
captured = append(captured, opts)
},
})
_, _ = r.Emit(InputEvent{Text: "!deploy"})
if len(captured) != 1 {
t.Fatalf("expected 1 PrintBlock call, got %d", len(captured))
}
if captured[0].Text != "deploy complete" {
t.Errorf("expected text 'deploy complete', got %q", captured[0].Text)
}
if captured[0].BorderColor != "#a6e3a1" {
t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor)
}
if captured[0].Subtitle != "deploy-ext" {
t.Errorf("expected subtitle 'deploy-ext', got %q", captured[0].Subtitle)
}
}
func TestRunner_ContextPrintNilSafe(t *testing.T) {
// When Print/PrintInfo/PrintError/PrintBlock are not set (nil), guarded calls should not panic.
ext := makeHandlerExt("nilprint.go", map[EventType][]HandlerFunc{
Input: {
func(e Event, c Context) Result {
if c.Print != nil {
c.Print("should not happen")
}
if c.PrintInfo != nil {
c.PrintInfo("should not happen")
}
if c.PrintError != nil {
c.PrintError("should not happen")
}
if c.PrintBlock != nil {
c.PrintBlock(PrintBlockOpts{Text: "nope"})
}
return nil
},
},
})
r := makeRunner(ext)
// Context without any Print functions set.
r.SetContext(Context{Model: "test"})
_, err := r.Emit(InputEvent{Text: "test"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
+49
View File
@@ -0,0 +1,49 @@
package extensions
import (
"reflect"
"github.com/traefik/yaegi/interp"
)
// Symbols returns the Yaegi export table that makes KIT's extension API
// available to interpreted Go code. Extensions import these types as:
//
// import "kit/ext"
//
// IMPORTANT: Only concrete types (structs, constants) are exported. Interfaces
// (Event, Result) and the HandlerFunc type are NOT exported because Yaegi
// cannot generate interface wrappers for them. Instead, extensions use
// event-specific methods like api.OnToolCall() which accept concrete function
// signatures.
func Symbols() interp.Exports {
return interp.Exports{
"kit/ext/ext": map[string]reflect.Value{
// Struct types (nil pointer trick for type registration)
"API": reflect.ValueOf((*API)(nil)),
"Context": reflect.ValueOf((*Context)(nil)),
"ToolDef": reflect.ValueOf((*ToolDef)(nil)),
"CommandDef": reflect.ValueOf((*CommandDef)(nil)),
"PrintBlockOpts": reflect.ValueOf((*PrintBlockOpts)(nil)),
// Event structs
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
"ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)),
"ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)),
"ToolResultEvent": reflect.ValueOf((*ToolResultEvent)(nil)),
"ToolResultResult": reflect.ValueOf((*ToolResultResult)(nil)),
"InputEvent": reflect.ValueOf((*InputEvent)(nil)),
"InputResult": reflect.ValueOf((*InputResult)(nil)),
"BeforeAgentStartEvent": reflect.ValueOf((*BeforeAgentStartEvent)(nil)),
"BeforeAgentStartResult": reflect.ValueOf((*BeforeAgentStartResult)(nil)),
"AgentStartEvent": reflect.ValueOf((*AgentStartEvent)(nil)),
"AgentEndEvent": reflect.ValueOf((*AgentEndEvent)(nil)),
"MessageStartEvent": reflect.ValueOf((*MessageStartEvent)(nil)),
"MessageUpdateEvent": reflect.ValueOf((*MessageUpdateEvent)(nil)),
"MessageEndEvent": reflect.ValueOf((*MessageEndEvent)(nil)),
"SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)),
"SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)),
},
}
}
+134
View File
@@ -0,0 +1,134 @@
package extensions
import (
"context"
"fmt"
"charm.land/fantasy"
)
// WrapToolsWithExtensions wraps each tool so that ToolCall and ToolResult
// events are emitted through the extension runner before and after execution.
// This is the Go equivalent of Pi's wrapper.ts pattern.
//
// If the runner has no relevant handlers the original tools are returned
// unchanged (zero overhead).
func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantasy.AgentTool {
if runner == nil {
return tools
}
if !runner.HasHandlers(ToolCall) && !runner.HasHandlers(ToolResult) &&
!runner.HasHandlers(ToolExecutionStart) && !runner.HasHandlers(ToolExecutionEnd) {
return tools
}
wrapped := make([]fantasy.AgentTool, len(tools))
for i, tool := range tools {
wrapped[i] = &wrappedTool{inner: tool, runner: runner}
}
return wrapped
}
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
// into fantasy.AgentTool implementations so the LLM can invoke them.
func ExtensionToolsAsFantasy(defs []ToolDef) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(defs))
for _, def := range defs {
tools = append(tools, &extensionTool{def: def})
}
return tools
}
// ---------------------------------------------------------------------------
// wrappedTool — intercepts tool calls through the extension runner
// ---------------------------------------------------------------------------
type wrappedTool struct {
inner fantasy.AgentTool
runner *Runner
}
func (w *wrappedTool) Info() fantasy.ToolInfo { return w.inner.Info() }
func (w *wrappedTool) ProviderOptions() fantasy.ProviderOptions { return w.inner.ProviderOptions() }
func (w *wrappedTool) SetProviderOptions(o fantasy.ProviderOptions) { w.inner.SetProviderOptions(o) }
func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolName := w.inner.Info().Name
// 1. Emit ToolCall — extensions can block execution.
if w.runner.HasHandlers(ToolCall) {
result, _ := w.runner.Emit(ToolCallEvent{
ToolName: toolName,
ToolCallID: call.ID,
Input: call.Input,
})
if r, ok := result.(ToolCallResult); ok && r.Block {
reason := r.Reason
if reason == "" {
reason = "blocked by extension"
}
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
fmt.Errorf("tool blocked by extension: %s", reason)
}
}
// 2. Emit ToolExecutionStart.
if w.runner.HasHandlers(ToolExecutionStart) {
_, _ = w.runner.Emit(ToolExecutionStartEvent{ToolName: toolName})
}
// 3. Execute the actual tool.
resp, err := w.inner.Run(ctx, call)
// 4. Emit ToolExecutionEnd.
if w.runner.HasHandlers(ToolExecutionEnd) {
_, _ = w.runner.Emit(ToolExecutionEndEvent{ToolName: toolName})
}
// 5. Emit ToolResult — extensions can modify output.
if w.runner.HasHandlers(ToolResult) {
result, _ := w.runner.Emit(ToolResultEvent{
ToolName: toolName,
Input: call.Input,
Content: resp.Content,
IsError: err != nil || resp.IsError,
})
if r, ok := result.(ToolResultResult); ok {
if r.Content != nil {
resp.Content = *r.Content
}
if r.IsError != nil {
resp.IsError = *r.IsError
}
}
}
return resp, err
}
// ---------------------------------------------------------------------------
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
// ---------------------------------------------------------------------------
type extensionTool struct {
def ToolDef
providerOptions fantasy.ProviderOptions
}
func (t *extensionTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{
Name: t.def.Name,
Description: t.def.Description,
}
}
func (t *extensionTool) ProviderOptions() fantasy.ProviderOptions { return t.providerOptions }
func (t *extensionTool) SetProviderOptions(o fantasy.ProviderOptions) { t.providerOptions = o }
func (t *extensionTool) Run(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
result, err := t.def.Execute(call.Input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), err
}
return fantasy.NewTextResponse(result), nil
}
+241
View File
@@ -0,0 +1,241 @@
package extensions
import (
"context"
"testing"
"charm.land/fantasy"
)
// mockTool implements fantasy.AgentTool for testing.
type mockTool struct {
name string
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
provOpt fantasy.ProviderOptions
}
func (m *mockTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
}
func (m *mockTool) ProviderOptions() fantasy.ProviderOptions { return m.provOpt }
func (m *mockTool) SetProviderOptions(o fantasy.ProviderOptions) { m.provOpt = o }
func (m *mockTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
if m.runFn != nil {
return m.runFn(ctx, call)
}
return fantasy.NewTextResponse("ok"), nil
}
func newMockTool(name string) *mockTool {
return &mockTool{name: name}
}
func TestWrapToolsWithExtensions_NilRunner(t *testing.T) {
tools := []fantasy.AgentTool{newMockTool("test")}
result := WrapToolsWithExtensions(tools, nil)
if len(result) != 1 {
t.Fatalf("expected 1 tool, got %d", len(result))
}
// Should be the same pointer (unwrapped).
if result[0] != tools[0] {
t.Error("expected original tool when runner is nil")
}
}
func TestWrapToolsWithExtensions_NoRelevantHandlers(t *testing.T) {
r := makeRunner(makeHandlerExt("other.go", map[EventType][]HandlerFunc{
SessionStart: {func(e Event, c Context) Result { return nil }},
}))
tools := []fantasy.AgentTool{newMockTool("test")}
result := WrapToolsWithExtensions(tools, r)
if result[0] != tools[0] {
t.Error("expected original tool when no tool handlers exist")
}
}
func TestWrapToolsWithExtensions_WrapsWhenHandlersExist(t *testing.T) {
r := makeRunner(makeHandlerExt("tc.go", map[EventType][]HandlerFunc{
ToolCall: {func(e Event, c Context) Result { return nil }},
}))
tools := []fantasy.AgentTool{newMockTool("test")}
result := WrapToolsWithExtensions(tools, r)
if result[0] == tools[0] {
t.Error("expected wrapped tool when ToolCall handlers exist")
}
// Verify Info() is passed through.
if result[0].Info().Name != "test" {
t.Errorf("expected name 'test', got %q", result[0].Info().Name)
}
}
func TestWrappedTool_NormalExecution(t *testing.T) {
var toolCallSeen, toolResultSeen bool
r := makeRunner(makeHandlerExt("observe.go", map[EventType][]HandlerFunc{
ToolCall: {func(e Event, c Context) Result {
toolCallSeen = true
return nil
}},
ToolResult: {func(e Event, c Context) Result {
toolResultSeen = true
return nil
}},
}))
mock := newMockTool("bash")
mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("output"), nil
}
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1", Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "output" {
t.Errorf("expected 'output', got %q", resp.Content)
}
if !toolCallSeen {
t.Error("ToolCall handler was not invoked")
}
if !toolResultSeen {
t.Error("ToolResult handler was not invoked")
}
}
func TestWrappedTool_BlockExecution(t *testing.T) {
var toolRan bool
r := makeRunner(makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{
ToolCall: {func(e Event, c Context) Result {
return ToolCallResult{Block: true, Reason: "forbidden"}
}},
}))
mock := newMockTool("danger")
mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
toolRan = true
return fantasy.NewTextResponse("bad"), nil
}
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
if toolRan {
t.Error("tool should not have run after block")
}
if err == nil {
t.Error("expected error from blocked tool")
}
if resp.IsError != true {
t.Error("expected IsError=true from blocked response")
}
}
func TestWrappedTool_ModifyResult(t *testing.T) {
modified := "redacted"
r := makeRunner(makeHandlerExt("redactor.go", map[EventType][]HandlerFunc{
ToolCall: {func(e Event, c Context) Result { return nil }},
ToolResult: {func(e Event, c Context) Result {
return ToolResultResult{Content: &modified}
}},
}))
mock := newMockTool("read")
mock.runFn = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("secret data"), nil
}
tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "redacted" {
t.Errorf("expected 'redacted', got %q", resp.Content)
}
}
func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
var startSeen, endSeen bool
r := makeRunner(makeHandlerExt("lifecycle.go", map[EventType][]HandlerFunc{
ToolCall: {func(e Event, c Context) Result { return nil }},
ToolExecutionStart: {func(e Event, c Context) Result { startSeen = true; return nil }},
ToolExecutionEnd: {func(e Event, c Context) Result { endSeen = true; return nil }},
}))
tools := WrapToolsWithExtensions([]fantasy.AgentTool{newMockTool("test")}, r)
_, _ = tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"})
if !startSeen {
t.Error("ToolExecutionStart not emitted")
}
if !endSeen {
t.Error("ToolExecutionEnd not emitted")
}
}
func TestExtensionToolsAsFantasy(t *testing.T) {
defs := []ToolDef{
{
Name: "greet",
Description: "greets someone",
Parameters: `{"type":"object"}`,
Execute: func(input string) (string, error) { return "hello " + input, nil },
},
}
tools := ExtensionToolsAsFantasy(defs)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
info := tools[0].Info()
if info.Name != "greet" {
t.Errorf("expected name 'greet', got %q", info.Name)
}
if info.Description != "greets someone" {
t.Errorf("expected description 'greets someone', got %q", info.Description)
}
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "world"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.Content != "hello world" {
t.Errorf("expected 'hello world', got %q", resp.Content)
}
}
func TestExtensionTool_Error(t *testing.T) {
defs := []ToolDef{
{
Name: "fail",
Execute: func(input string) (string, error) { return "", context.DeadlineExceeded },
},
}
tools := ExtensionToolsAsFantasy(defs)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
if err == nil {
t.Error("expected error")
}
if !resp.IsError {
t.Error("expected IsError=true")
}
}
func TestExtensionTool_ProviderOptions(t *testing.T) {
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
tools := ExtensionToolsAsFantasy(defs)
// Initially nil.
opts := tools[0].ProviderOptions()
if opts != nil {
t.Error("expected nil ProviderOptions initially")
}
// SetProviderOptions round-trips.
po := fantasy.ProviderOptions{}
tools[0].SetProviderOptions(po)
got := tools[0].ProviderOptions()
if got == nil {
t.Error("expected non-nil ProviderOptions after set")
}
}
+26
View File
@@ -177,6 +177,32 @@ func (c *CLI) DisplayInfo(message string) {
c.displayContainer()
}
// DisplayExtensionBlock renders a custom styled block with the given border
// color and optional subtitle. Used by extensions via ctx.PrintBlock.
func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
theme := GetTheme()
var borderClr = lipgloss.Color("#89b4fa")
if borderColor != "" {
borderClr = lipgloss.Color(borderColor)
}
content := text
if subtitle != "" {
sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + subtitle)
content = content + "\n" + sub
}
rendered := renderContentBlock(
content,
c.messageRenderer.width,
WithAlign(lipgloss.Left),
WithBorderColor(borderClr),
WithMarginBottom(1),
)
fmt.Println(rendered)
}
// DisplayCancellation displays a system message indicating that the current
// AI generation has been cancelled by the user (typically via ESC key).
func (c *CLI) DisplayCancellation() {
+13
View File
@@ -129,6 +129,19 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
h.lastDisplayed = e.Content
}
case app.ExtensionPrintEvent:
h.stopSpinner()
switch e.Level {
case "info":
h.cli.DisplayInfo(e.Text)
case "error":
h.cli.DisplayError(fmt.Errorf("%s", e.Text))
case "block":
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
default:
fmt.Println(e.Text)
}
case app.StepCompleteEvent:
h.stopSpinner()
+43
View File
@@ -530,6 +530,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.state = stateInput
m.canceling = false
case app.ExtensionPrintEvent:
// Extension output — route through styled renderers when a level is set.
switch msg.Level {
case "info":
cmds = append(cmds, m.printSystemMessage(msg.Text))
case "error":
cmds = append(cmds, m.printErrorResponse(app.StepErrorEvent{
Err: fmt.Errorf("%s", msg.Text),
}))
case "block":
cmds = append(cmds, m.printExtensionBlock(msg))
default:
cmds = append(cmds, tea.Println(msg.Text))
}
default:
// Pass unrecognised messages to all children.
if m.input != nil {
@@ -791,6 +806,34 @@ func (m *AppModel) printSystemMessage(text string) tea.Cmd {
return tea.Println(rendered)
}
// printExtensionBlock renders a custom styled block from an extension with
// caller-chosen border color and optional subtitle, then emits it to scrollback.
func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) tea.Cmd {
theme := GetTheme()
// Resolve border color: use the extension's hex value, fall back to theme accent.
var borderClr = lipgloss.Color("#89b4fa") // default blue
if evt.BorderColor != "" {
borderClr = lipgloss.Color(evt.BorderColor)
}
// Build content: main text + optional subtitle line.
content := evt.Text
if evt.Subtitle != "" {
sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + evt.Subtitle)
content = strings.TrimSuffix(content, "\n") + "\n" + sub
}
rendered := renderContentBlock(
content,
m.width,
WithAlign(lipgloss.Left),
WithBorderColor(borderClr),
WithMarginBottom(1),
)
return tea.Println(rendered)
}
// printHelpMessage renders the help text listing all available slash commands.
func (m *AppModel) printHelpMessage() tea.Cmd {
help := "## Available Commands\n\n" +