diff --git a/cmd/extensions.go b/cmd/extensions.go new file mode 100644 index 00000000..363efcd1 --- /dev/null +++ b/cmd/extensions.go @@ -0,0 +1,175 @@ +package cmd + +import ( + "fmt" + "os" + "text/tabwriter" + + "github.com/mark3labs/kit/internal/extensions" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var extensionsCmd = &cobra.Command{ + Use: "extensions", + Short: "Manage KIT extensions", + Long: "Commands for listing, validating, and scaffolding KIT extensions", +} + +var extensionsListCmd = &cobra.Command{ + Use: "list", + Short: "List discovered extensions and their handlers", + RunE: func(cmd *cobra.Command, args []string) error { + loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension")) + if err != nil { + return fmt.Errorf("loading extensions: %w", err) + } + + if len(loaded) == 0 { + fmt.Println("No extensions found.") + fmt.Println() + fmt.Println("Extension search paths:") + fmt.Println(" ~/.config/kit/extensions/*.go (global)") + fmt.Println(" .kit/extensions/*.go (project)") + fmt.Println() + fmt.Println("Run 'kit extensions init' to create an example extension.") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + _, _ = fmt.Fprintln(w, "EXTENSION\tEVENT\tHANDLERS\tTOOLS\tCOMMANDS") + + for _, ext := range loaded { + totalHandlers := 0 + for _, handlers := range ext.Handlers { + totalHandlers += len(handlers) + } + first := true + for event, handlers := range ext.Handlers { + if first { + _, _ = fmt.Fprintf(w, "%s\t%s\t%d\t%d\t%d\n", + ext.Path, event, len(handlers), len(ext.Tools), len(ext.Commands)) + first = false + } else { + _, _ = fmt.Fprintf(w, "\t%s\t%d\t\t\n", + event, len(handlers)) + } + } + if first { + // Extension loaded but registered no handlers + _, _ = fmt.Fprintf(w, "%s\t(none)\t0\t%d\t%d\n", + ext.Path, len(ext.Tools), len(ext.Commands)) + } + } + + return w.Flush() + }, +} + +var extensionsValidateCmd = &cobra.Command{ + Use: "validate", + Short: "Validate all extension files can be loaded", + RunE: func(cmd *cobra.Command, args []string) error { + loaded, err := extensions.LoadExtensions(viper.GetStringSlice("extension")) + if err != nil { + return fmt.Errorf("validation failed: %w", err) + } + + fmt.Printf("Loaded %d extension(s) successfully\n", len(loaded)) + for _, ext := range loaded { + total := 0 + for _, h := range ext.Handlers { + total += len(h) + } + fmt.Printf(" %s (%d handlers, %d tools, %d commands)\n", + ext.Path, total, len(ext.Tools), len(ext.Commands)) + } + return nil + }, +} + +var extensionsInitCmd = &cobra.Command{ + Use: "init", + Short: "Generate an example extension file", + RunE: func(cmd *cobra.Command, args []string) error { + dir := ".kit/extensions" + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("creating extensions directory: %w", err) + } + + example := `package main + +import ( + "fmt" + "os" + "strings" + "time" + + "kit/ext" +) + +// Init is called when the extension is loaded. Register handlers here. +func Init(api ext.API) { + // Log every tool call to a file. + api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult { + f, err := os.OpenFile("/tmp/kit-tool-log.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + defer f.Close() + fmt.Fprintf(f, "[%s] tool=%s\n", time.Now().Format(time.RFC3339), tc.ToolName) + } + return nil // don't block + }) + + // Block dangerous bash commands. + api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult { + if tc.ToolName == "bash" && strings.Contains(tc.Input, "rm -rf /") { + return &ext.ToolCallResult{Block: true, Reason: "Blocked: dangerous command"} + } + return nil + }) + + // Handle custom ! commands. Use ctx.Print/PrintInfo/PrintError/PrintBlock + // instead of fmt.Println — BubbleTea captures stdout in interactive mode. + // + // ctx.Print("text") — plain text + // ctx.PrintInfo("text") — styled system message block + // ctx.PrintError("text") — styled error block + // ctx.PrintBlock(opts) — custom block with border color and subtitle + api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult { + switch ie.Text { + case "!time": + ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339)) + return &ext.InputResult{Action: "handled"} + + case "!status": + ctx.PrintBlock(ext.PrintBlockOpts{ + Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD, + BorderColor: "#a6e3a1", + Subtitle: "my-extension", + }) + return &ext.InputResult{Action: "handled"} + } + return nil + }) +} +` + + path := dir + "/example.go" + if err := os.WriteFile(path, []byte(example), 0644); err != nil { + return fmt.Errorf("writing example: %w", err) + } + + fmt.Printf("Created %s with example extension\n", path) + fmt.Println() + fmt.Println("The extension will be auto-loaded on the next kit run.") + fmt.Println("Use --no-extensions to disable all extensions.") + return nil + }, +} + +func init() { + rootCmd.AddCommand(extensionsCmd) + extensionsCmd.AddCommand(extensionsListCmd) + extensionsCmd.AddCommand(extensionsValidateCmd) + extensionsCmd.AddCommand(extensionsInitCmd) +} diff --git a/cmd/root.go b/cmd/root.go index 5f2e80ee..f0a5d6d4 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,6 +14,7 @@ import ( "github.com/mark3labs/kit/internal/agent" "github.com/mark3labs/kit/internal/app" "github.com/mark3labs/kit/internal/config" + "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/session" "github.com/mark3labs/kit/internal/ui" "github.com/spf13/cobra" @@ -57,7 +58,9 @@ var ( numGPU int32 mainGPU int32 - // Hooks control + // Extensions control + noExtensionsFlag bool + extensionPaths []string // TLS configuration tlsSkipVerify bool @@ -301,6 +304,10 @@ func init() { BoolVarP(&resumeFlag, "resume", "r", false, "interactive session picker") rootCmd.PersistentFlags(). BoolVar(&noSessionFlag, "no-session", false, "ephemeral mode — no session persistence") + rootCmd.PersistentFlags(). + BoolVar(&noExtensionsFlag, "no-extensions", false, "disable all extensions and hooks") + rootCmd.PersistentFlags(). + StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)") flags := rootCmd.PersistentFlags() flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)") @@ -338,6 +345,8 @@ func init() { _ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers")) _ = viper.BindPFlag("main-gpu", rootCmd.PersistentFlags().Lookup("main-gpu")) _ = viper.BindPFlag("tls-skip-verify", rootCmd.PersistentFlags().Lookup("tls-skip-verify")) + _ = viper.BindPFlag("no-extensions", rootCmd.PersistentFlags().Lookup("no-extensions")) + _ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension")) // Defaults are already set in flag definitions, no need to duplicate in viper @@ -542,7 +551,7 @@ func runNormalMode(ctx context.Context) error { } // Create the app.App instance now that session messages are loaded. - appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames) + appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner) appOpts.SessionManager = sessionManager appOpts.TreeSession = treeSession @@ -564,6 +573,22 @@ func runNormalMode(ctx context.Context) error { appInstance := app.New(appOpts, messages) defer appInstance.Close() + // Emit SessionStart event to extensions. + if agentResult.ExtRunner != nil { + agentResult.ExtRunner.SetContext(extensions.Context{ + CWD: cwd, + Model: modelName, + Interactive: promptFlag == "", + Print: func(text string) { appInstance.PrintFromExtension("", text) }, + PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) }, + PrintError: func(text string) { appInstance.PrintFromExtension("error", text) }, + PrintBlock: appInstance.PrintBlockFromExtension, + }) + if agentResult.ExtRunner.HasHandlers(extensions.SessionStart) { + _, _ = agentResult.ExtRunner.Emit(extensions.SessionStartEvent{}) + } + } + // Check if running in non-interactive mode if promptFlag != "" { return runNonInteractiveModeApp(ctx, appInstance, cli, promptFlag, quietFlag, noExitFlag, modelName, parsedProvider, mcpAgent.GetLoadingMessage(), serverNames, toolNames, usageTracker) diff --git a/cmd/script.go b/cmd/script.go index a7d4ecc1..36f507db 100644 --- a/cmd/script.go +++ b/cmd/script.go @@ -534,7 +534,7 @@ func runScriptMode(ctx context.Context, mcpConfig *config.Config, prompt string, DisplayDebugConfig(cli, mcpAgent, mcpConfig, parsedProvider) // Build app options. - appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames) + appOpts := BuildAppOptions(mcpAgent, mcpConfig, modelName, serverNames, toolNames, agentResult.ExtRunner) if cli != nil { if tracker := cli.GetUsageTracker(); tracker != nil { appOpts.UsageTracker = tracker diff --git a/cmd/setup.go b/cmd/setup.go index bc50b844..c5b43109 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -5,9 +5,13 @@ import ( "fmt" "strings" + "charm.land/fantasy" + "github.com/mark3labs/kit/internal/agent" "github.com/mark3labs/kit/internal/app" "github.com/mark3labs/kit/internal/config" + "github.com/mark3labs/kit/internal/extensions" + "github.com/mark3labs/kit/internal/hooks" "github.com/mark3labs/kit/internal/models" "github.com/mark3labs/kit/internal/tools" "github.com/mark3labs/kit/internal/ui" @@ -70,6 +74,9 @@ type AgentSetupOptions struct { type AgentSetupResult struct { Agent *agent.Agent BufferedLogger *tools.BufferedDebugLogger + // ExtRunner is the extension runner (nil when --no-extensions or no + // extensions were discovered). + ExtRunner *extensions.Runner } // SetupAgent creates an agent from the current viper state + the provided @@ -93,6 +100,20 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, } } + // Load extensions unless --no-extensions is set. Extensions must be loaded + // BEFORE agent creation so their tool wrapper and custom tools are included + // in the Fantasy agent's tool list. + var extRunner *extensions.Runner + var extCreationOpts extensionCreationOpts + if !viper.GetBool("no-extensions") { + var extErr error + extRunner, extCreationOpts, extErr = setupExtensions() + if extErr != nil { + // Extension loading failures are non-fatal. + fmt.Printf("Warning: Failed to load extensions: %v\n", extErr) + } + } + a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{ ModelConfig: modelConfig, MCPConfig: opts.MCPConfig, @@ -103,6 +124,8 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, Quiet: quietFlag, SpinnerFunc: opts.SpinnerFunc, DebugLogger: debugLogger, + ToolWrapper: extCreationOpts.toolWrapper, + ExtraTools: extCreationOpts.extraTools, }) if err != nil { return nil, fmt.Errorf("failed to create agent: %w", err) @@ -110,10 +133,57 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, return &AgentSetupResult{ Agent: a, + ExtRunner: extRunner, BufferedLogger: bufferedLogger, }, nil } +// extensionCreationOpts holds the tool wrapper and extra tools that need to be +// passed into agent creation, extracted from loaded extensions. +type extensionCreationOpts struct { + toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool + extraTools []fantasy.AgentTool +} + +// setupExtensions discovers and loads Yaegi extensions plus legacy hooks.yml, +// builds the runner, and returns the tool wrapper/extra tools needed by the +// agent factory. +func setupExtensions() (*extensions.Runner, extensionCreationOpts, error) { + extraPaths := viper.GetStringSlice("extension") + loaded, err := extensions.LoadExtensions(extraPaths) + if err != nil { + return nil, extensionCreationOpts{}, err + } + + // Also load legacy hooks.yml as a compat extension. + hooksCfg, _ := hooks.LoadHooksConfig() + if hooksCfg != nil && len(hooksCfg.Hooks) > 0 { + compat := extensions.HooksAsExtension(hooksCfg) + if compat != nil { + loaded = append([]extensions.LoadedExtension{*compat}, loaded...) + } + } + + if len(loaded) == 0 { + return nil, extensionCreationOpts{}, nil + } + + runner := extensions.NewRunner(loaded) + + // Build the tool wrapper that intercepts tool calls through the runner. + wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool { + return extensions.WrapToolsWithExtensions(tools, runner) + } + + // Collect custom tools registered by extensions. + extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools()) + + return runner, extensionCreationOpts{ + toolWrapper: wrapper, + extraTools: extTools, + }, nil +} + // CollectAgentMetadata extracts model display info and tool/server name lists // from the agent. This is used by both root.go and script.go to populate // app.Options and UI setup. @@ -138,7 +208,7 @@ func CollectAgentMetadata(mcpAgent *agent.Agent, mcpConfig *config.Config) (prov // BuildAppOptions constructs the app.Options struct from the current state. // Both root.go and script.go converge here after agent creation. -func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string) app.Options { +func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName string, serverNames, toolNames []string, extRunner *extensions.Runner) app.Options { return app.Options{ Agent: mcpAgent, MCPConfig: mcpConfig, @@ -149,6 +219,7 @@ func BuildAppOptions(mcpAgent *agent.Agent, mcpConfig *config.Config, modelName Quiet: quietFlag, Debug: viper.GetBool("debug"), CompactMode: viper.GetBool("compact"), + Extensions: extRunner, } } diff --git a/examples/extensions/tool-logger.go b/examples/extensions/tool-logger.go new file mode 100644 index 00000000..392e333f --- /dev/null +++ b/examples/extensions/tool-logger.go @@ -0,0 +1,81 @@ +//go:build ignore + +package main + +import ( + "fmt" + "os" + "time" + + "kit/ext" +) + +// Init registers handlers that log all tool calls and session lifecycle +// events to /tmp/kit-tool-log.txt. +func Init(api ext.API) { + logFile := "/tmp/kit-tool-log.txt" + + // Log every tool call before execution. + api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult { + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + defer f.Close() + fmt.Fprintf(f, "[%s] CALL tool=%s model=%s\n", + time.Now().Format(time.RFC3339), tc.ToolName, ctx.Model) + } + return nil + }) + + // Log tool results after execution. + api.OnToolResult(func(tr ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult { + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + defer f.Close() + status := "ok" + if tr.IsError { + status = "error" + } + fmt.Fprintf(f, "[%s] RESULT tool=%s status=%s bytes=%d\n", + time.Now().Format(time.RFC3339), tr.ToolName, status, len(tr.Content)) + } + return nil // don't modify the result + }) + + // Log session start/shutdown. + api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) { + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + defer f.Close() + fmt.Fprintf(f, "[%s] SESSION_START cwd=%s\n", + time.Now().Format(time.RFC3339), ctx.CWD) + } + }) + + api.OnSessionShutdown(func(_ ext.SessionShutdownEvent, ctx ext.Context) { + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err == nil { + defer f.Close() + fmt.Fprintf(f, "[%s] SESSION_SHUTDOWN\n", + time.Now().Format(time.RFC3339)) + } + }) + + // "!time" — prints the current time as a styled info block. + // "!status" — prints a custom block with green border and subtitle. + api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult { + switch ie.Text { + case "!time": + ctx.PrintInfo("Current time: " + time.Now().Format(time.RFC3339)) + return &ext.InputResult{Action: "handled"} + + case "!status": + ctx.PrintBlock(ext.PrintBlockOpts{ + Text: "Session active\nModel: " + ctx.Model + "\nCWD: " + ctx.CWD, + BorderColor: "#a6e3a1", + Subtitle: "tool-logger extension", + }) + return &ext.InputResult{Action: "handled"} + } + return nil + }) +} diff --git a/go.mod b/go.mod index 861f57da..e84f9825 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/charmbracelet/colorprofile v0.4.2 // indirect github.com/charmbracelet/harmonica v0.2.0 // indirect github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect + github.com/charmbracelet/log v0.4.2 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 // indirect github.com/charmbracelet/x/cellbuf v0.0.15 // indirect github.com/charmbracelet/x/exp/charmtone v0.0.0-20260223200540-d6a276319c45 // indirect @@ -61,6 +62,7 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.5.0 // indirect @@ -95,6 +97,7 @@ require ( github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/traefik/yaegi v0.16.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect @@ -108,6 +111,7 @@ require ( go.opentelemetry.io/otel/trace v1.40.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto v0.48.0 // indirect + golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/time v0.14.0 // indirect diff --git a/go.sum b/go.sum index a263ed22..ec2c42a3 100644 --- a/go.sum +++ b/go.sum @@ -86,6 +86,8 @@ github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao= github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE= github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834/go.mod h1:aKC/t2arECF6rNOnaKaVU6y4t4ZeHQzqfxedE/VkVhA= +github.com/charmbracelet/log v0.4.2 h1:hYt8Qj6a8yLnvR+h7MwsJv/XvmBJXiueUcI3cIxsyig= +github.com/charmbracelet/log v0.4.2/go.mod h1:qifHGX/tc7eluv2R6pWIpyHDDrrb/AG71Pf2ysQu5nw= github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73 h1:Af/L28Xh+pddhouT/6lJ7IAIYfu5tWJOB0iqt+mXsYM= github.com/charmbracelet/ultraviolet v0.0.0-20260223171050-89c142e4aa73/go.mod h1:E6/0abq9uG2SnM8IbLB9Y5SW09uIgfaFETk8aRzgXUQ= github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8= @@ -132,6 +134,8 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao= github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -251,6 +255,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E= +github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= diff --git a/internal/agent/agent.go b/internal/agent/agent.go index b4ce0c6a..523bcd45 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -24,6 +24,15 @@ type AgentConfig struct { MaxSteps int StreamingEnabled bool DebugLogger tools.DebugLogger + + // ToolWrapper is an optional function that wraps the combined tool list + // before it is passed to the Fantasy agent. Used by the extensions system + // to intercept tool calls/results. + ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool + + // ExtraTools are additional tools to include alongside core and MCP tools. + // Used by extensions to register custom tools. + ExtraTools []fantasy.AgentTool } // ToolCallHandler is a function type for handling tool calls as they happen. @@ -109,6 +118,16 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) { } } + // Append any extra tools provided by extensions. + if len(agentConfig.ExtraTools) > 0 { + allTools = append(allTools, agentConfig.ExtraTools...) + } + + // Apply tool wrapper (extension interception layer) if configured. + if agentConfig.ToolWrapper != nil { + allTools = agentConfig.ToolWrapper(allTools) + } + // Build fantasy agent options var agentOpts []fantasy.AgentOption diff --git a/internal/agent/factory.go b/internal/agent/factory.go index a35b0c02..0929d6bd 100644 --- a/internal/agent/factory.go +++ b/internal/agent/factory.go @@ -4,6 +4,8 @@ import ( "context" "fmt" + "charm.land/fantasy" + "github.com/mark3labs/kit/internal/config" "github.com/mark3labs/kit/internal/models" "github.com/mark3labs/kit/internal/tools" @@ -34,6 +36,10 @@ type AgentCreationOptions struct { SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller) // DebugLogger is an optional logger for debugging MCP communications DebugLogger tools.DebugLogger // Optional debug logger + // ToolWrapper wraps the combined tool list before Fantasy agent creation. + ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool + // ExtraTools are additional tools to include (e.g. from extensions). + ExtraTools []fantasy.AgentTool } // CreateAgent creates an agent with optional spinner for Ollama models. @@ -47,6 +53,8 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error MaxSteps: opts.MaxSteps, StreamingEnabled: opts.StreamingEnabled, DebugLogger: opts.DebugLogger, + ToolWrapper: opts.ToolWrapper, + ExtraTools: opts.ExtraTools, } var agent *Agent diff --git a/internal/app/app.go b/internal/app/app.go index 09f493c0..1af909d5 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,6 +9,7 @@ import ( "charm.land/fantasy" "github.com/mark3labs/kit/internal/agent" + "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/session" ) @@ -254,6 +255,11 @@ func (a *App) Close() { cancel := a.cancelStep a.mu.Unlock() + // --- Extension: SessionShutdown --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.SessionShutdown) { + _, _ = a.opts.Extensions.Emit(extensions.SessionShutdownEvent{}) + } + // Cancel any in-flight step and the root context. cancel() a.rootCancel() @@ -362,6 +368,23 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M } } + // --- Extension: Input event (can transform or handle the prompt) --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.Input) { + result, _ := a.opts.Extensions.Emit(extensions.InputEvent{ + Text: prompt, + Source: a.inputSource(), + }) + if r, ok := result.(extensions.InputResult); ok { + switch r.Action { + case "transform": + prompt = r.Text + case "handled": + // Extension handled the input; skip the agent entirely. + return &agent.GenerateWithLoopResult{}, nil + } + } + } + // Add user message to the store immediately so history is consistent // even if the step is later cancelled. userMsg := fantasy.NewUserMessage(prompt) @@ -385,9 +408,39 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M // Track message count before agent runs so we can diff new messages. sentCount := len(msgs) + // --- Extension: BeforeAgentStart --- + // Extensions can inject a system message or prepend context text into the + // conversation before the agent runs. + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.BeforeAgentStart) { + result, _ := a.opts.Extensions.Emit(extensions.BeforeAgentStartEvent{Prompt: prompt}) + if r, ok := result.(extensions.BeforeAgentStartResult); ok { + if r.SystemPrompt != nil && *r.SystemPrompt != "" { + // Prepend a system message so the LLM sees extension-provided + // instructions. This supplements (not replaces) the agent's + // configured system prompt. + msgs = append([]fantasy.Message{fantasy.NewSystemMessage(*r.SystemPrompt)}, msgs...) + } + if r.InjectText != nil && *r.InjectText != "" { + // Prepend a user message with the injected context so it + // appears early in the conversation window. + msgs = append([]fantasy.Message{fantasy.NewUserMessage(*r.InjectText)}, msgs...) + } + } + } + + // --- Extension: AgentStart --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentStart) { + _, _ = a.opts.Extensions.Emit(extensions.AgentStartEvent{Prompt: prompt}) + } + // Signal spinner start. sendFn(SpinnerEvent{Show: true}) + // --- Extension: MessageStart --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageStart) { + _, _ = a.opts.Extensions.Emit(extensions.MessageStartEvent{}) + } + result, err := a.opts.Agent.GenerateWithLoopAndStreaming(ctx, msgs, // onToolCall func(toolName, toolArgs string) { @@ -416,14 +469,42 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M }, // onStreamingResponse — spinner keeps running alongside streaming text func(chunk string) { + // Extension: MessageUpdate (observe streaming chunks) + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageUpdate) { + _, _ = a.opts.Extensions.Emit(extensions.MessageUpdateEvent{Chunk: chunk}) + } sendFn(StreamChunkEvent{Content: chunk}) }, ) if err != nil { + // --- Extension: AgentEnd with error --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) { + _, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{ + Response: "", + StopReason: "error", + }) + } return nil, err } + // --- Extension: MessageEnd --- + responseText := "" + if result.FinalResponse != nil { + responseText = result.FinalResponse.Content.Text() + } + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.MessageEnd) { + _, _ = a.opts.Extensions.Emit(extensions.MessageEndEvent{Content: responseText}) + } + + // --- Extension: AgentEnd with success --- + if a.opts.Extensions != nil && a.opts.Extensions.HasHandlers(extensions.AgentEnd) { + _, _ = a.opts.Extensions.Emit(extensions.AgentEndEvent{ + Response: responseText, + StopReason: "completed", + }) + } + // Replace the store with the full updated conversation returned by the agent // (includes tool call/result messages added during the step). a.store.Replace(result.ConversationMessages) @@ -439,6 +520,17 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M return result, nil } +// inputSource returns a string identifying how the current session receives +// input — used by the Input extension event. +func (a *App) inputSource() string { + a.mu.Lock() + defer a.mu.Unlock() + if a.program != nil { + return "interactive" + } + return "cli" +} + // -------------------------------------------------------------------------- // Internal: event helpers // -------------------------------------------------------------------------- @@ -454,6 +546,45 @@ func (a *App) sendEvent(msg tea.Msg) { } } +// PrintFromExtension outputs text from an extension to the user. The level +// controls styling: "" for plain text, "info" for a system message block, +// "error" for an error block. In interactive mode it sends an +// ExtensionPrintEvent through the program so the TUI can render it with the +// appropriate renderer. In non-interactive mode it falls back to stdout. +func (a *App) PrintFromExtension(level, text string) { + a.mu.Lock() + prog := a.program + a.mu.Unlock() + if prog != nil { + prog.Send(ExtensionPrintEvent{Text: text, Level: level}) + return + } + // Non-interactive fallback: write directly to stdout. + fmt.Println(text) +} + +// PrintBlockFromExtension outputs a custom styled block from an extension. +func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) { + a.mu.Lock() + prog := a.program + a.mu.Unlock() + if prog != nil { + prog.Send(ExtensionPrintEvent{ + Text: opts.Text, + Level: "block", + BorderColor: opts.BorderColor, + Subtitle: opts.Subtitle, + }) + return + } + // Non-interactive fallback. + if opts.Subtitle != "" { + fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle) + } else { + fmt.Println(opts.Text) + } +} + // updateUsage records token usage from a completed agent step into the configured // UsageTracker (if any). It uses the actual token counts from the agent result's // TotalUsage field when available; otherwise it falls back to text-based estimation. diff --git a/internal/app/events.go b/internal/app/events.go index be6c8f73..b294d90f 100644 --- a/internal/app/events.go +++ b/internal/app/events.go @@ -96,3 +96,23 @@ type MessageCreatedEvent struct { // Message is the fantasy message that was added to the store. Message fantasy.Message } + +// ExtensionPrintEvent is sent when an extension calls ctx.Print, ctx.PrintInfo, +// ctx.PrintError, or ctx.PrintBlock. The TUI renders it via the appropriate +// renderer and tea.Println (scrollback); the CLI handler uses +// DisplayInfo/DisplayError or plain fmt.Println. This exists because BubbleTea +// captures stdout, so plain fmt.Println inside extensions would be swallowed. +type ExtensionPrintEvent struct { + // Text is the content the extension wants to display to the user. + Text string + // Level controls the rendering style: + // "" — plain text (no styling) + // "info" — system message block (bordered, themed) + // "error" — error block (red border, bold text) + // "block" — custom block with BorderColor and Subtitle + Level string + // BorderColor is a hex color (e.g. "#a6e3a1") for Level="block". + BorderColor string + // Subtitle is optional muted text below the content for Level="block". + Subtitle string +} diff --git a/internal/app/options.go b/internal/app/options.go index d326dfc7..4f67bcf1 100644 --- a/internal/app/options.go +++ b/internal/app/options.go @@ -7,6 +7,7 @@ import ( "github.com/mark3labs/kit/internal/agent" "github.com/mark3labs/kit/internal/config" + "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/session" ) @@ -94,4 +95,10 @@ type Options struct { // EstimateAndUpdateUsage as a fallback) using the usage data returned by the // agent. Satisfied by *ui.UsageTracker; wired in cmd/root.go. UsageTracker UsageUpdater + + // Extensions is the optional extension runner. When non-nil, lifecycle + // events (Input, BeforeAgentStart, AgentEnd, etc.) are emitted through + // it. Tool-level events (ToolCall, ToolResult) are handled by wrapper.go + // at the tool layer, not here. + Extensions *extensions.Runner } diff --git a/internal/extensions/api.go b/internal/extensions/api.go new file mode 100644 index 00000000..fb6683cf --- /dev/null +++ b/internal/extensions/api.go @@ -0,0 +1,325 @@ +package extensions + +// --------------------------------------------------------------------------- +// Internal types (used by runner, NOT exposed to Yaegi) +// --------------------------------------------------------------------------- + +// Event is the interface satisfied by all event types internally. +type Event interface { + Type() EventType +} + +// Result is the interface satisfied by all result types internally. +type Result interface { + isResult() +} + +// HandlerFunc is the internal handler signature used by the runner. +type HandlerFunc func(event Event, ctx Context) Result + +// --------------------------------------------------------------------------- +// Context (exposed to Yaegi — concrete struct, no interfaces) +// --------------------------------------------------------------------------- + +// Context provides runtime information to handlers about the current session. +type Context struct { + SessionID string + CWD string + Model string + Interactive bool + + // Print outputs plain text to the user. In interactive mode this + // routes through BubbleTea's scrollback (tea.Println); in + // non-interactive mode it writes to stdout. Extensions must use + // this instead of fmt.Println, which is swallowed by BubbleTea. + Print func(string) + + // PrintInfo outputs text as a styled system message block (bordered, + // themed). Use this for informational notices the user should see. + PrintInfo func(string) + + // PrintError outputs text as a styled error block (red border, bold). + // Use this for error messages or warnings. + PrintError func(string) + + // PrintBlock outputs text as a custom styled block with caller-chosen + // border color and optional subtitle. Example: + // + // ctx.PrintBlock(ext.PrintBlockOpts{ + // Text: "Deployment complete!", + // BorderColor: "#a6e3a1", + // Subtitle: "my-extension", + // }) + PrintBlock func(PrintBlockOpts) +} + +// PrintBlockOpts configures a custom styled block for PrintBlock. +type PrintBlockOpts struct { + // Text is the main content to display. + Text string + // BorderColor is a hex color string (e.g. "#a6e3a1") for the left border. + // Defaults to the theme's system color if empty. + BorderColor string + // Subtitle is optional text shown below the content in muted style + // (e.g. extension name, timestamp). Empty means no subtitle line. + Subtitle string +} + +// --------------------------------------------------------------------------- +// API — the object passed to each extension's Init function. +// +// Instead of a generic On(EventType, HandlerFunc) that uses interfaces, +// we expose event-specific methods with concrete function signatures. +// This avoids Yaegi's genInterfaceWrapper crash entirely — no interfaces +// cross the Yaegi boundary. +// --------------------------------------------------------------------------- + +// API is passed to each extension's Init function. Extensions use it to +// register typed event handlers, custom tools, and slash commands. +type API struct { + // Event-specific registration functions (wired by the loader). + onToolCall func(func(ToolCallEvent, Context) *ToolCallResult) + onToolExecStart func(func(ToolExecutionStartEvent, Context)) + onToolExecEnd func(func(ToolExecutionEndEvent, Context)) + onToolResult func(func(ToolResultEvent, Context) *ToolResultResult) + onInput func(func(InputEvent, Context) *InputResult) + onBeforeAgentStart func(func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) + onAgentStart func(func(AgentStartEvent, Context)) + onAgentEnd func(func(AgentEndEvent, Context)) + onMessageStart func(func(MessageStartEvent, Context)) + onMessageUpdate func(func(MessageUpdateEvent, Context)) + onMessageEnd func(func(MessageEndEvent, Context)) + onSessionStart func(func(SessionStartEvent, Context)) + onSessionShutdown func(func(SessionShutdownEvent, Context)) + registerToolFn func(ToolDef) + registerCmdFn func(CommandDef) +} + +// OnToolCall registers a handler that fires before a tool executes. +// Return a non-nil ToolCallResult with Block=true to prevent execution. +func (a *API) OnToolCall(handler func(ToolCallEvent, Context) *ToolCallResult) { + a.onToolCall(handler) +} + +// OnToolExecutionStart registers a handler for tool execution start. +func (a *API) OnToolExecutionStart(handler func(ToolExecutionStartEvent, Context)) { + a.onToolExecStart(handler) +} + +// OnToolExecutionEnd registers a handler for tool execution end. +func (a *API) OnToolExecutionEnd(handler func(ToolExecutionEndEvent, Context)) { + a.onToolExecEnd(handler) +} + +// OnToolResult registers a handler that fires after tool execution. +// Return a non-nil ToolResultResult to modify the output. +func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultResult) { + a.onToolResult(handler) +} + +// OnInput registers a handler that fires when user input is received. +// Return a non-nil InputResult to transform or handle the input. +func (a *API) OnInput(handler func(InputEvent, Context) *InputResult) { + a.onInput(handler) +} + +// OnBeforeAgentStart registers a handler that fires before the agent loop. +func (a *API) OnBeforeAgentStart(handler func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) { + a.onBeforeAgentStart(handler) +} + +// OnAgentStart registers a handler for when the agent loop begins. +func (a *API) OnAgentStart(handler func(AgentStartEvent, Context)) { + a.onAgentStart(handler) +} + +// OnAgentEnd registers a handler for when the agent finishes responding. +func (a *API) OnAgentEnd(handler func(AgentEndEvent, Context)) { + a.onAgentEnd(handler) +} + +// OnMessageStart registers a handler for when an assistant message begins. +func (a *API) OnMessageStart(handler func(MessageStartEvent, Context)) { + a.onMessageStart(handler) +} + +// OnMessageUpdate registers a handler for streaming text chunks. +func (a *API) OnMessageUpdate(handler func(MessageUpdateEvent, Context)) { + a.onMessageUpdate(handler) +} + +// OnMessageEnd registers a handler for when the assistant message is complete. +func (a *API) OnMessageEnd(handler func(MessageEndEvent, Context)) { + a.onMessageEnd(handler) +} + +// OnSessionStart registers a handler for when a session is loaded or created. +func (a *API) OnSessionStart(handler func(SessionStartEvent, Context)) { + a.onSessionStart(handler) +} + +// OnSessionShutdown registers a handler for when the application is closing. +func (a *API) OnSessionShutdown(handler func(SessionShutdownEvent, Context)) { + a.onSessionShutdown(handler) +} + +// RegisterTool adds a custom tool that the LLM can invoke. +func (a *API) RegisterTool(tool ToolDef) { + a.registerToolFn(tool) +} + +// RegisterCommand adds a slash command available in interactive mode. +func (a *API) RegisterCommand(cmd CommandDef) { + a.registerCmdFn(cmd) +} + +// --------------------------------------------------------------------------- +// ToolDef / CommandDef +// --------------------------------------------------------------------------- + +// ToolDef describes a custom tool registered by an extension. +type ToolDef struct { + Name string + Description string + Parameters string // JSON Schema string + Execute func(input string) (string, error) +} + +// CommandDef describes a slash command registered by an extension. +type CommandDef struct { + Name string + Description string + Execute func(args string) (string, error) +} + +// --------------------------------------------------------------------------- +// Typed events (all concrete structs — safe for Yaegi) +// --------------------------------------------------------------------------- + +// ToolCallEvent fires before a tool executes. +type ToolCallEvent struct { + ToolName string + ToolCallID string + Input string // JSON-encoded tool parameters +} + +func (e ToolCallEvent) Type() EventType { return ToolCall } + +// ToolCallResult controls whether the tool call proceeds. +type ToolCallResult struct { + Block bool + Reason string +} + +func (ToolCallResult) isResult() {} + +// ToolExecutionStartEvent fires when a tool begins executing. +type ToolExecutionStartEvent struct { + ToolName string +} + +func (e ToolExecutionStartEvent) Type() EventType { return ToolExecutionStart } + +// ToolExecutionEndEvent fires when a tool finishes executing. +type ToolExecutionEndEvent struct { + ToolName string +} + +func (e ToolExecutionEndEvent) Type() EventType { return ToolExecutionEnd } + +// ToolResultEvent fires after tool execution with the output. +type ToolResultEvent struct { + ToolName string + Input string + Content string + IsError bool +} + +func (e ToolResultEvent) Type() EventType { return ToolResult } + +// ToolResultResult can modify the tool's output before it reaches the LLM. +type ToolResultResult struct { + Content *string // nil = unchanged + IsError *bool // nil = unchanged +} + +func (ToolResultResult) isResult() {} + +// InputEvent fires when user input is received. +type InputEvent struct { + Text string + Source string // "interactive", "cli", "script", "queue" +} + +func (e InputEvent) Type() EventType { return Input } + +// InputResult controls what happens with user input. +// +// Action: "continue" (default), "transform", "handled" +type InputResult struct { + Action string + Text string // replacement text when Action="transform" +} + +func (InputResult) isResult() {} + +// BeforeAgentStartEvent fires before the agent loop begins. +type BeforeAgentStartEvent struct { + Prompt string +} + +func (e BeforeAgentStartEvent) Type() EventType { return BeforeAgentStart } + +// BeforeAgentStartResult can inject context before the agent runs. +type BeforeAgentStartResult struct { + InjectText *string + SystemPrompt *string +} + +func (BeforeAgentStartResult) isResult() {} + +// AgentStartEvent fires when the agent loop begins. +type AgentStartEvent struct { + Prompt string +} + +func (e AgentStartEvent) Type() EventType { return AgentStart } + +// AgentEndEvent fires when the agent finishes responding. +type AgentEndEvent struct { + Response string + StopReason string // "completed", "cancelled", "error" +} + +func (e AgentEndEvent) Type() EventType { return AgentEnd } + +// MessageStartEvent fires when a new assistant message begins. +type MessageStartEvent struct{} + +func (e MessageStartEvent) Type() EventType { return MessageStart } + +// MessageUpdateEvent fires for each streaming text chunk. +type MessageUpdateEvent struct { + Chunk string +} + +func (e MessageUpdateEvent) Type() EventType { return MessageUpdate } + +// MessageEndEvent fires when the assistant message is complete. +type MessageEndEvent struct { + Content string +} + +func (e MessageEndEvent) Type() EventType { return MessageEnd } + +// SessionStartEvent fires when a session is loaded or created. +type SessionStartEvent struct { + SessionID string +} + +func (e SessionStartEvent) Type() EventType { return SessionStart } + +// SessionShutdownEvent fires when the application is closing. +type SessionShutdownEvent struct{} + +func (e SessionShutdownEvent) Type() EventType { return SessionShutdown } diff --git a/internal/extensions/compat.go b/internal/extensions/compat.go new file mode 100644 index 00000000..40e3be4a --- /dev/null +++ b/internal/extensions/compat.go @@ -0,0 +1,111 @@ +package extensions + +import ( + "context" + "encoding/json" + + "github.com/mark3labs/kit/internal/hooks" +) + +// HooksAsExtension wraps an existing hooks.HookConfig as a LoadedExtension +// so that legacy .kit/hooks.yml configurations continue to work alongside +// the new Yaegi extension system. The adapter translates the old event names +// and shell-command execution model into extension HandlerFunc handlers. +func HooksAsExtension(config *hooks.HookConfig) *LoadedExtension { + if config == nil || len(config.Hooks) == 0 { + return nil + } + + ext := &LoadedExtension{ + Path: "hooks.yml (compat)", + Handlers: make(map[EventType][]HandlerFunc), + } + + executor := hooks.NewExecutor(config, "", "") + + // Map PreToolUse → ToolCall + if matchers, ok := config.Hooks[hooks.PreToolUse]; ok && len(matchers) > 0 { + ext.Handlers[ToolCall] = []HandlerFunc{ + func(event Event, _ Context) Result { + tc, ok := event.(ToolCallEvent) + if !ok { + return nil + } + input := &hooks.PreToolUseInput{ + ToolName: tc.ToolName, + ToolInput: json.RawMessage(tc.Input), + } + output, err := executor.ExecuteHooks(context.Background(), hooks.PreToolUse, input) + if err != nil || output == nil { + return nil + } + if output.Decision == "block" { + return ToolCallResult{Block: true, Reason: output.Reason} + } + return nil + }, + } + } + + // Map PostToolUse → ToolResult + if matchers, ok := config.Hooks[hooks.PostToolUse]; ok && len(matchers) > 0 { + ext.Handlers[ToolResult] = []HandlerFunc{ + func(event Event, _ Context) Result { + tr, ok := event.(ToolResultEvent) + if !ok { + return nil + } + input := &hooks.PostToolUseInput{ + ToolName: tr.ToolName, + ToolInput: json.RawMessage(tr.Input), + ToolResponse: json.RawMessage(tr.Content), + } + _, _ = executor.ExecuteHooks(context.Background(), hooks.PostToolUse, input) + return nil // legacy hooks don't modify results + }, + } + } + + // Map UserPromptSubmit → Input + if matchers, ok := config.Hooks[hooks.UserPromptSubmit]; ok && len(matchers) > 0 { + ext.Handlers[Input] = []HandlerFunc{ + func(event Event, _ Context) Result { + ie, ok := event.(InputEvent) + if !ok { + return nil + } + input := &hooks.UserPromptSubmitInput{ + Prompt: ie.Text, + } + output, err := executor.ExecuteHooks(context.Background(), hooks.UserPromptSubmit, input) + if err != nil || output == nil { + return nil + } + if output.Decision == "block" { + return InputResult{Action: "handled"} + } + return nil + }, + } + } + + // Map Stop → AgentEnd + if matchers, ok := config.Hooks[hooks.Stop]; ok && len(matchers) > 0 { + ext.Handlers[AgentEnd] = []HandlerFunc{ + func(event Event, _ Context) Result { + ae, ok := event.(AgentEndEvent) + if !ok { + return nil + } + input := &hooks.StopInput{ + Response: ae.Response, + StopReason: ae.StopReason, + } + _, _ = executor.ExecuteHooks(context.Background(), hooks.Stop, input) + return nil + }, + } + } + + return ext +} diff --git a/internal/extensions/events.go b/internal/extensions/events.go new file mode 100644 index 00000000..5eac0b02 --- /dev/null +++ b/internal/extensions/events.go @@ -0,0 +1,69 @@ +// Package extensions implements a Pi-style in-process extension system for KIT. +// Extensions are plain Go files loaded at runtime via Yaegi (a Go interpreter). +// They register event handlers using an API object, enabling tool interception, +// input transformation, and lifecycle observation — all without recompilation. +package extensions + +// EventType identifies a point in KIT's lifecycle where extensions can hook in. +type EventType string + +const ( + // ToolCall fires before a tool executes. Handlers can block execution. + ToolCall EventType = "tool_call" + + // ToolExecutionStart fires when a tool begins executing. + ToolExecutionStart EventType = "tool_execution_start" + + // ToolExecutionEnd fires when a tool finishes executing. + ToolExecutionEnd EventType = "tool_execution_end" + + // ToolResult fires after a tool executes. Handlers can modify the result. + ToolResult EventType = "tool_result" + + // Input fires when user input is received. Handlers can transform or handle it. + Input EventType = "input" + + // BeforeAgentStart fires before the agent loop begins for a prompt. + BeforeAgentStart EventType = "before_agent_start" + + // AgentStart fires when the agent loop begins processing. + AgentStart EventType = "agent_start" + + // AgentEnd fires when the agent finishes responding. + AgentEnd EventType = "agent_end" + + // MessageStart fires when a new assistant message begins. + MessageStart EventType = "message_start" + + // MessageUpdate fires for each streaming text chunk. + MessageUpdate EventType = "message_update" + + // MessageEnd fires when the assistant message is complete. + MessageEnd EventType = "message_end" + + // SessionStart fires when a session is loaded or created. + SessionStart EventType = "session_start" + + // SessionShutdown fires when the application is closing. + SessionShutdown EventType = "session_shutdown" +) + +// AllEventTypes returns every supported event type. +func AllEventTypes() []EventType { + return []EventType{ + ToolCall, ToolExecutionStart, ToolExecutionEnd, ToolResult, + Input, BeforeAgentStart, AgentStart, AgentEnd, + MessageStart, MessageUpdate, MessageEnd, + SessionStart, SessionShutdown, + } +} + +// IsValid returns true if the event type is a recognised lifecycle event. +func (e EventType) IsValid() bool { + for _, valid := range AllEventTypes() { + if e == valid { + return true + } + } + return false +} diff --git a/internal/extensions/events_test.go b/internal/extensions/events_test.go new file mode 100644 index 00000000..d8c8973e --- /dev/null +++ b/internal/extensions/events_test.go @@ -0,0 +1,60 @@ +package extensions + +import "testing" + +func TestAllEventTypes_Count(t *testing.T) { + all := AllEventTypes() + if len(all) != 13 { + t.Fatalf("expected 13 event types, got %d", len(all)) + } +} + +func TestAllEventTypes_NoDuplicates(t *testing.T) { + seen := make(map[EventType]bool) + for _, et := range AllEventTypes() { + if seen[et] { + t.Fatalf("duplicate event type: %s", et) + } + seen[et] = true + } +} + +func TestEventType_IsValid(t *testing.T) { + for _, et := range AllEventTypes() { + if !et.IsValid() { + t.Errorf("expected %s to be valid", et) + } + } + + invalid := EventType("nonexistent_event") + if invalid.IsValid() { + t.Error("expected 'nonexistent_event' to be invalid") + } +} + +func TestEventType_TypeMethod(t *testing.T) { + tests := []struct { + event Event + want EventType + }{ + {ToolCallEvent{ToolName: "test"}, ToolCall}, + {ToolExecutionStartEvent{ToolName: "test"}, ToolExecutionStart}, + {ToolExecutionEndEvent{ToolName: "test"}, ToolExecutionEnd}, + {ToolResultEvent{ToolName: "test"}, ToolResult}, + {InputEvent{Text: "hello"}, Input}, + {BeforeAgentStartEvent{Prompt: "test"}, BeforeAgentStart}, + {AgentStartEvent{Prompt: "test"}, AgentStart}, + {AgentEndEvent{Response: "done"}, AgentEnd}, + {MessageStartEvent{}, MessageStart}, + {MessageUpdateEvent{Chunk: "hi"}, MessageUpdate}, + {MessageEndEvent{Content: "done"}, MessageEnd}, + {SessionStartEvent{SessionID: "abc"}, SessionStart}, + {SessionShutdownEvent{}, SessionShutdown}, + } + + for _, tt := range tests { + if got := tt.event.Type(); got != tt.want { + t.Errorf("event %T.Type() = %s, want %s", tt.event, got, tt.want) + } + } +} diff --git a/internal/extensions/loader.go b/internal/extensions/loader.go new file mode 100644 index 00000000..46ade83f --- /dev/null +++ b/internal/extensions/loader.go @@ -0,0 +1,301 @@ +package extensions + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/charmbracelet/log" + "github.com/traefik/yaegi/interp" + "github.com/traefik/yaegi/stdlib" +) + +// Discovery paths searched in order (lowest to highest precedence): +// +// ~/.config/kit/extensions/*.go global single files +// ~/.config/kit/extensions/*/main.go global subdirectories +// .kit/extensions/*.go project-local single files +// .kit/extensions/*/main.go project-local subdirectories +// +// Explicit paths passed via --extension / -e flags are appended last. + +// LoadExtensions discovers and loads extensions from standard locations and +// any extra paths. Each extension is loaded into its own Yaegi interpreter +// for isolation. Extensions that fail to load are logged and skipped. +func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) { + paths := discoverExtensionPaths(extraPaths) + if len(paths) == 0 { + return nil, nil + } + + var loaded []LoadedExtension + for _, p := range paths { + ext, err := loadSingleExtension(p) + if err != nil { + log.Warn("skipping extension", "path", p, "err", err) + continue + } + loaded = append(loaded, *ext) + log.Debug("loaded extension", "path", p, + "handlers", countHandlers(ext), + "tools", len(ext.Tools), + "commands", len(ext.Commands)) + } + return loaded, nil +} + +// discoverExtensionPaths returns deduplicated paths to extension files in +// load-order (global first, then project-local, then explicit). +func discoverExtensionPaths(extraPaths []string) []string { + seen := make(map[string]bool) + var paths []string + + add := func(p string) { + abs, err := filepath.Abs(p) + if err != nil { + return + } + if seen[abs] { + return + } + seen[abs] = true + paths = append(paths, abs) + } + + // Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/) + globalDir := globalExtensionsDir() + for _, p := range findExtensionsInDir(globalDir) { + add(p) + } + + // Project-local extensions: .kit/extensions/ + localDir := filepath.Join(".kit", "extensions") + for _, p := range findExtensionsInDir(localDir) { + add(p) + } + + // Explicit paths (highest precedence) + for _, p := range extraPaths { + info, err := os.Stat(p) + if err != nil { + continue + } + if info.IsDir() { + for _, found := range findExtensionsInDir(p) { + add(found) + } + } else if strings.HasSuffix(p, ".go") { + add(p) + } + } + + return paths +} + +// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs. +func findExtensionsInDir(dir string) []string { + info, err := os.Stat(dir) + if err != nil || !info.IsDir() { + return nil + } + + var results []string + + entries, err := os.ReadDir(dir) + if err != nil { + return nil + } + + for _, entry := range entries { + full := filepath.Join(dir, entry.Name()) + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") { + results = append(results, full) + } else if entry.IsDir() { + main := filepath.Join(full, "main.go") + if _, err := os.Stat(main); err == nil { + results = append(results, main) + } + } + } + return results +} + +// globalExtensionsDir returns the global extensions directory, respecting +// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/extensions. +func globalExtensionsDir() string { + base := os.Getenv("XDG_CONFIG_HOME") + if base == "" { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + base = filepath.Join(home, ".config") + } + return filepath.Join(base, "kit", "extensions") +} + +// loadSingleExtension loads one .go file into a fresh Yaegi interpreter, +// calls the Init(ext.API) function, and returns the registered handlers. +func loadSingleExtension(path string) (*LoadedExtension, error) { + ext := &LoadedExtension{ + Path: path, + Handlers: make(map[EventType][]HandlerFunc), + } + + // Create a fresh interpreter. + i := interp.New(interp.Options{}) + + // Expose a safe subset of the Go stdlib. + if err := i.Use(stdlib.Symbols); err != nil { + return nil, fmt.Errorf("loading stdlib symbols: %w", err) + } + + // Expose KIT's extension API types so the extension can + // import "kit/ext" and use ext.ToolCall, ext.API, etc. + if err := i.Use(Symbols()); err != nil { + return nil, fmt.Errorf("loading extension symbols: %w", err) + } + + // Read and evaluate the extension source file. + src, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading file: %w", err) + } + + if _, err := i.Eval(string(src)); err != nil { + return nil, fmt.Errorf("evaluating source: %w", err) + } + + // Extract the Init function. Extensions must export: + // func Init(api ext.API) + initVal, err := i.Eval("Init") + if err != nil { + return nil, fmt.Errorf("no Init function: %w", err) + } + + initFn, ok := initVal.Interface().(func(API)) + if !ok { + return nil, fmt.Errorf("Init has wrong signature (want func(ext.API), got %T)", initVal.Interface()) + } + + // Build the API object that wires typed registration methods back to + // the extension's internal handler map. Each method wraps the concrete + // handler into the internal HandlerFunc type. + reg := func(event EventType, fn HandlerFunc) { + ext.Handlers[event] = append(ext.Handlers[event], fn) + } + + api := API{ + onToolCall: func(h func(ToolCallEvent, Context) *ToolCallResult) { + reg(ToolCall, func(e Event, c Context) Result { + r := h(e.(ToolCallEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) { + reg(ToolExecutionStart, func(e Event, c Context) Result { + h(e.(ToolExecutionStartEvent), c) + return nil + }) + }, + onToolExecEnd: func(h func(ToolExecutionEndEvent, Context)) { + reg(ToolExecutionEnd, func(e Event, c Context) Result { + h(e.(ToolExecutionEndEvent), c) + return nil + }) + }, + onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) { + reg(ToolResult, func(e Event, c Context) Result { + r := h(e.(ToolResultEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onInput: func(h func(InputEvent, Context) *InputResult) { + reg(Input, func(e Event, c Context) Result { + r := h(e.(InputEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onBeforeAgentStart: func(h func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) { + reg(BeforeAgentStart, func(e Event, c Context) Result { + r := h(e.(BeforeAgentStartEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onAgentStart: func(h func(AgentStartEvent, Context)) { + reg(AgentStart, func(e Event, c Context) Result { + h(e.(AgentStartEvent), c) + return nil + }) + }, + onAgentEnd: func(h func(AgentEndEvent, Context)) { + reg(AgentEnd, func(e Event, c Context) Result { + h(e.(AgentEndEvent), c) + return nil + }) + }, + onMessageStart: func(h func(MessageStartEvent, Context)) { + reg(MessageStart, func(e Event, c Context) Result { + h(e.(MessageStartEvent), c) + return nil + }) + }, + onMessageUpdate: func(h func(MessageUpdateEvent, Context)) { + reg(MessageUpdate, func(e Event, c Context) Result { + h(e.(MessageUpdateEvent), c) + return nil + }) + }, + onMessageEnd: func(h func(MessageEndEvent, Context)) { + reg(MessageEnd, func(e Event, c Context) Result { + h(e.(MessageEndEvent), c) + return nil + }) + }, + onSessionStart: func(h func(SessionStartEvent, Context)) { + reg(SessionStart, func(e Event, c Context) Result { + h(e.(SessionStartEvent), c) + return nil + }) + }, + onSessionShutdown: func(h func(SessionShutdownEvent, Context)) { + reg(SessionShutdown, func(e Event, c Context) Result { + h(e.(SessionShutdownEvent), c) + return nil + }) + }, + registerToolFn: func(tool ToolDef) { + ext.Tools = append(ext.Tools, tool) + }, + registerCmdFn: func(cmd CommandDef) { + ext.Commands = append(ext.Commands, cmd) + }, + } + + // Call Init — the extension registers its handlers, tools, commands. + initFn(api) + + return ext, nil +} + +// countHandlers returns the total number of registered handlers across all events. +func countHandlers(ext *LoadedExtension) int { + n := 0 + for _, handlers := range ext.Handlers { + n += len(handlers) + } + return n +} diff --git a/internal/extensions/loader_test.go b/internal/extensions/loader_test.go new file mode 100644 index 00000000..e20dacfd --- /dev/null +++ b/internal/extensions/loader_test.go @@ -0,0 +1,604 @@ +package extensions + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiscoverExtensionPaths_ExplicitFile(t *testing.T) { + // Create a temp dir with a .go file. + dir := t.TempDir() + f := filepath.Join(dir, "my-ext.go") + if err := os.WriteFile(f, []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + + paths := discoverExtensionPaths([]string{f}) + if len(paths) == 0 { + t.Fatal("expected at least 1 path") + } + + abs, _ := filepath.Abs(f) + found := false + for _, p := range paths { + if p == abs { + found = true + break + } + } + if !found { + t.Errorf("expected %q in discovered paths %v", abs, paths) + } +} + +func TestDiscoverExtensionPaths_ExplicitDir(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "ext.go") + if err := os.WriteFile(f, []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + + paths := discoverExtensionPaths([]string{dir}) + abs, _ := filepath.Abs(f) + found := false + for _, p := range paths { + if p == abs { + found = true + break + } + } + if !found { + t.Errorf("expected %q in discovered paths %v", abs, paths) + } +} + +func TestDiscoverExtensionPaths_SubdirMainGo(t *testing.T) { + dir := t.TempDir() + subdir := filepath.Join(dir, "my-plugin") + if err := os.MkdirAll(subdir, 0755); err != nil { + t.Fatal(err) + } + main := filepath.Join(subdir, "main.go") + if err := os.WriteFile(main, []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + + paths := discoverExtensionPaths([]string{dir}) + abs, _ := filepath.Abs(main) + found := false + for _, p := range paths { + if p == abs { + found = true + break + } + } + if !found { + t.Errorf("expected %q in discovered paths %v", abs, paths) + } +} + +func TestDiscoverExtensionPaths_Dedup(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "ext.go") + if err := os.WriteFile(f, []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + + // Pass the same file twice. + paths := discoverExtensionPaths([]string{f, f}) + count := 0 + abs, _ := filepath.Abs(f) + for _, p := range paths { + if p == abs { + count++ + } + } + if count != 1 { + t.Errorf("expected dedup to 1, got %d", count) + } +} + +func TestDiscoverExtensionPaths_NonGoFileIgnored(t *testing.T) { + dir := t.TempDir() + f := filepath.Join(dir, "readme.txt") + if err := os.WriteFile(f, []byte("hello"), 0644); err != nil { + t.Fatal(err) + } + + paths := discoverExtensionPaths([]string{f}) + for _, p := range paths { + abs, _ := filepath.Abs(f) + if p == abs { + t.Error("non-.go file should not be discovered") + } + } +} + +func TestDiscoverExtensionPaths_NonexistentIgnored(t *testing.T) { + paths := discoverExtensionPaths([]string{"/nonexistent/path/ext.go"}) + for _, p := range paths { + if p == "/nonexistent/path/ext.go" { + t.Error("nonexistent path should not be discovered") + } + } +} + +func TestFindExtensionsInDir_EmptyDir(t *testing.T) { + dir := t.TempDir() + results := findExtensionsInDir(dir) + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} + +func TestFindExtensionsInDir_NonexistentDir(t *testing.T) { + results := findExtensionsInDir("/nonexistent/dir") + if len(results) != 0 { + t.Errorf("expected 0 results, got %d", len(results)) + } +} + +func TestFindExtensionsInDir_MixedContent(t *testing.T) { + dir := t.TempDir() + + // .go file at top level + if err := os.WriteFile(filepath.Join(dir, "ext.go"), []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + // non-.go file (should be ignored) + if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("hi"), 0644); err != nil { + t.Fatal(err) + } + // subdir with main.go + sub := filepath.Join(dir, "plugin") + if err := os.MkdirAll(sub, 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sub, "main.go"), []byte("package main"), 0644); err != nil { + t.Fatal(err) + } + // subdir without main.go (should be ignored) + empty := filepath.Join(dir, "empty") + if err := os.MkdirAll(empty, 0755); err != nil { + t.Fatal(err) + } + + results := findExtensionsInDir(dir) + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d: %v", len(results), results) + } +} + +func TestLoadSingleExtension_ValidExtension(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult { + return nil + }) + api.OnSessionStart(func(se ext.SessionStartEvent, ctx ext.Context) { + }) +} +` + f := filepath.Join(dir, "valid.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + if ext.Path != f { + t.Errorf("expected path %q, got %q", f, ext.Path) + } + if len(ext.Handlers[ToolCall]) != 1 { + t.Errorf("expected 1 ToolCall handler, got %d", len(ext.Handlers[ToolCall])) + } + if len(ext.Handlers[SessionStart]) != 1 { + t.Errorf("expected 1 SessionStart handler, got %d", len(ext.Handlers[SessionStart])) + } +} + +func TestLoadSingleExtension_NoInitFunction(t *testing.T) { + dir := t.TempDir() + src := `package main + +func Hello() string { return "hi" } +` + f := filepath.Join(dir, "noinit.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadSingleExtension(f) + if err == nil { + t.Fatal("expected error for missing Init function") + } +} + +func TestLoadSingleExtension_SyntaxError(t *testing.T) { + dir := t.TempDir() + src := `package main +func Init( { broken } +` + f := filepath.Join(dir, "broken.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadSingleExtension(f) + if err == nil { + t.Fatal("expected error for syntax error") + } +} + +func TestLoadSingleExtension_WrongSignature(t *testing.T) { + dir := t.TempDir() + src := `package main + +func Init(s string) {} +` + f := filepath.Join(dir, "wrongsig.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadSingleExtension(f) + if err == nil { + t.Fatal("expected error for wrong Init signature") + } +} + +func TestLoadSingleExtension_RegistersTool(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.RegisterTool(ext.ToolDef{ + Name: "my_tool", + Description: "does stuff", + Parameters: "{\"type\":\"object\"}", + Execute: func(input string) (string, error) { + return "result: " + input, nil + }, + }) +} +` + f := filepath.Join(dir, "toolreg.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + if len(ext.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(ext.Tools)) + } + if ext.Tools[0].Name != "my_tool" { + t.Errorf("expected tool name 'my_tool', got %q", ext.Tools[0].Name) + } +} + +func TestLoadSingleExtension_RegistersCommand(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.RegisterCommand(ext.CommandDef{ + Name: "hello", + Description: "says hello", + Execute: func(args string) (string, error) { + return "hello " + args, nil + }, + }) +} +` + f := filepath.Join(dir, "cmdreg.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + if len(ext.Commands) != 1 { + t.Fatalf("expected 1 command, got %d", len(ext.Commands)) + } + if ext.Commands[0].Name != "hello" { + t.Errorf("expected command name 'hello', got %q", ext.Commands[0].Name) + } +} + +func TestLoadExtensions_SkipsBadFiles(t *testing.T) { + dir := t.TempDir() + + // Good extension + good := `package main +import "kit/ext" +func Init(api ext.API) { + api.OnSessionStart(func(_ ext.SessionStartEvent, _ ext.Context) {}) +} +` + if err := os.WriteFile(filepath.Join(dir, "good.go"), []byte(good), 0644); err != nil { + t.Fatal(err) + } + + // Bad extension (syntax error) + bad := `package main +func Init( { broken } +` + if err := os.WriteFile(filepath.Join(dir, "bad.go"), []byte(bad), 0644); err != nil { + t.Fatal(err) + } + + loaded, err := LoadExtensions([]string{dir}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Should have loaded the good one and skipped the bad one. + if len(loaded) != 1 { + t.Fatalf("expected 1 loaded extension, got %d", len(loaded)) + } +} + +func TestLoadSingleExtension_HandlerExecution(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult { + if tc.ToolName == "banned" { + return &ext.ToolCallResult{Block: true, Reason: "tool is banned"} + } + return nil + }) +} +` + f := filepath.Join(dir, "blocker.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + + // Build a runner and test the handler actually works. + r := NewRunner([]LoadedExtension{*ext}) + result, err := r.Emit(ToolCallEvent{ToolName: "banned", Input: "{}"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tcr, ok := result.(ToolCallResult) + if !ok { + t.Fatalf("expected ToolCallResult, got %T", result) + } + if !tcr.Block { + t.Error("expected Block=true for banned tool") + } + if tcr.Reason != "tool is banned" { + t.Errorf("expected reason 'tool is banned', got %q", tcr.Reason) + } + + // Non-banned tool should pass through. + result2, _ := r.Emit(ToolCallEvent{ToolName: "allowed", Input: "{}"}) + if result2 != nil { + t.Errorf("expected nil result for allowed tool, got %v", result2) + } +} + +func TestGlobalExtensionsDir_XDG(t *testing.T) { + // Save and restore XDG_CONFIG_HOME. + orig := os.Getenv("XDG_CONFIG_HOME") + defer os.Setenv("XDG_CONFIG_HOME", orig) + + os.Setenv("XDG_CONFIG_HOME", "/custom/config") + dir := globalExtensionsDir() + expected := "/custom/config/kit/extensions" + if dir != expected { + t.Errorf("expected %q, got %q", expected, dir) + } +} + +func TestGlobalExtensionsDir_Default(t *testing.T) { + orig := os.Getenv("XDG_CONFIG_HOME") + defer os.Setenv("XDG_CONFIG_HOME", orig) + + os.Setenv("XDG_CONFIG_HOME", "") + dir := globalExtensionsDir() + home, _ := os.UserHomeDir() + expected := filepath.Join(home, ".config", "kit", "extensions") + if dir != expected { + t.Errorf("expected %q, got %q", expected, dir) + } +} + +func TestLoadSingleExtension_ContextPrint(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult { + if ie.Text == "!hello" && ctx.Print != nil { + ctx.Print("Hello from extension!") + return &ext.InputResult{Action: "handled"} + } + return nil + }) +} +` + f := filepath.Join(dir, "printer.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + + // Wire up a Print function and verify it's called. + var printed []string + r := NewRunner([]LoadedExtension{*ext}) + r.SetContext(Context{ + Print: func(text string) { + printed = append(printed, text) + }, + }) + + result, err := r.Emit(InputEvent{Text: "!hello", Source: "interactive"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ir, ok := result.(InputResult) + if !ok { + t.Fatalf("expected InputResult, got %T", result) + } + if ir.Action != "handled" { + t.Errorf("expected Action 'handled', got %q", ir.Action) + } + if len(printed) != 1 || printed[0] != "Hello from extension!" { + t.Errorf("expected Print to capture 'Hello from extension!', got %v", printed) + } +} + +func TestLoadSingleExtension_ContextPrintInfo(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult { + if ie.Text == "!info" && ctx.PrintInfo != nil { + ctx.PrintInfo("Styled info from extension") + return &ext.InputResult{Action: "handled"} + } + if ie.Text == "!error" && ctx.PrintError != nil { + ctx.PrintError("Styled error from extension") + return &ext.InputResult{Action: "handled"} + } + return nil + }) +} +` + f := filepath.Join(dir, "styled.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + + var infos, errors []string + r := NewRunner([]LoadedExtension{*ext}) + r.SetContext(Context{ + PrintInfo: func(text string) { infos = append(infos, text) }, + PrintError: func(text string) { errors = append(errors, text) }, + }) + + result, _ := r.Emit(InputEvent{Text: "!info"}) + if ir, ok := result.(InputResult); !ok || ir.Action != "handled" { + t.Fatal("expected handled result for !info") + } + if len(infos) != 1 || infos[0] != "Styled info from extension" { + t.Errorf("expected PrintInfo capture, got %v", infos) + } + + result, _ = r.Emit(InputEvent{Text: "!error"}) + if ir, ok := result.(InputResult); !ok || ir.Action != "handled" { + t.Fatal("expected handled result for !error") + } + if len(errors) != 1 || errors[0] != "Styled error from extension" { + t.Errorf("expected PrintError capture, got %v", errors) + } +} + +func TestLoadSingleExtension_ContextPrintBlock(t *testing.T) { + dir := t.TempDir() + src := `package main + +import "kit/ext" + +func Init(api ext.API) { + api.OnInput(func(ie ext.InputEvent, ctx ext.Context) *ext.InputResult { + if ie.Text == "!status" && ctx.PrintBlock != nil { + ctx.PrintBlock(ext.PrintBlockOpts{ + Text: "All systems go\nModel: " + ctx.Model, + BorderColor: "#a6e3a1", + Subtitle: "test-ext", + }) + return &ext.InputResult{Action: "handled"} + } + return nil + }) +} +` + f := filepath.Join(dir, "block.go") + if err := os.WriteFile(f, []byte(src), 0644); err != nil { + t.Fatal(err) + } + + ext, err := loadSingleExtension(f) + if err != nil { + t.Fatalf("failed to load extension: %v", err) + } + + var captured []PrintBlockOpts + r := NewRunner([]LoadedExtension{*ext}) + r.SetContext(Context{ + Model: "claude-4", + PrintBlock: func(opts PrintBlockOpts) { + captured = append(captured, opts) + }, + }) + + result, _ := r.Emit(InputEvent{Text: "!status", Source: "interactive"}) + if ir, ok := result.(InputResult); !ok || ir.Action != "handled" { + t.Fatal("expected handled result for !status") + } + if len(captured) != 1 { + t.Fatalf("expected 1 PrintBlock call, got %d", len(captured)) + } + if captured[0].BorderColor != "#a6e3a1" { + t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor) + } + if captured[0].Subtitle != "test-ext" { + t.Errorf("expected subtitle 'test-ext', got %q", captured[0].Subtitle) + } + // Verify the text includes the model from context. + if captured[0].Text != "All systems go\nModel: claude-4" { + t.Errorf("unexpected text: %q", captured[0].Text) + } +} + +func TestCountHandlers(t *testing.T) { + ext := &LoadedExtension{ + Handlers: map[EventType][]HandlerFunc{ + ToolCall: {func(Event, Context) Result { return nil }, func(Event, Context) Result { return nil }}, + SessionStart: {func(Event, Context) Result { return nil }}, + }, + } + if n := countHandlers(ext); n != 3 { + t.Errorf("expected 3 handlers, got %d", n) + } +} diff --git a/internal/extensions/runner.go b/internal/extensions/runner.go new file mode 100644 index 00000000..8336d356 --- /dev/null +++ b/internal/extensions/runner.go @@ -0,0 +1,146 @@ +package extensions + +import ( + "fmt" + "sync" + + "github.com/charmbracelet/log" +) + +// Runner manages loaded extensions and dispatches events to their handlers +// sequentially, mirroring Pi's ExtensionRunner. Handlers execute in extension +// load order; for cancellable events the first blocking result wins. +type Runner struct { + extensions []LoadedExtension + ctx Context + mu sync.RWMutex +} + +// LoadedExtension represents a single extension that has been discovered, +// loaded, and initialised. It holds the registered handlers and any custom +// tools or commands the extension provided. +type LoadedExtension struct { + Path string + Handlers map[EventType][]HandlerFunc + Tools []ToolDef + Commands []CommandDef +} + +// NewRunner creates a Runner from a set of loaded extensions. +func NewRunner(exts []LoadedExtension) *Runner { + return &Runner{extensions: exts} +} + +// SetContext updates the runtime context (session ID, model, etc.) that is +// passed to every handler invocation. Thread-safe. +func (r *Runner) SetContext(ctx Context) { + r.mu.Lock() + defer r.mu.Unlock() + r.ctx = ctx +} + +// HasHandlers returns true if any loaded extension has at least one handler +// registered for the given event type. +func (r *Runner) HasHandlers(event EventType) bool { + for i := range r.extensions { + if len(r.extensions[i].Handlers[event]) > 0 { + return true + } + } + return false +} + +// Emit dispatches an event to all matching handlers sequentially. It returns +// the accumulated result from all handlers, or nil if no handler responded. +// +// For blocking events (ToolCall, Input), the first blocking result short-circuits: +// - ToolCallResult{Block: true} stops iteration and returns immediately. +// - InputResult{Action: "handled"} stops iteration and returns immediately. +// +// For chainable events (ToolResult), each handler sees the accumulated result +// from previous handlers. The final merged result is returned. +// +// Panics in handlers are recovered and logged; they do not crash the process. +func (r *Runner) Emit(event Event) (Result, error) { + r.mu.RLock() + ctx := r.ctx + r.mu.RUnlock() + + var accumulated Result + + for i := range r.extensions { + ext := &r.extensions[i] + handlers := ext.Handlers[event.Type()] + for _, handler := range handlers { + result, err := safeCall(handler, event, ctx) + if err != nil { + log.Warn("extension handler error", + "path", ext.Path, + "event", event.Type(), + "err", err) + continue + } + if result == nil { + continue + } + + // Check for blocking/short-circuit results. + if isBlocking(result) { + return result, nil + } + + // Chain: keep the latest non-nil result. For ToolResultResult + // the caller is responsible for applying the modifications. + accumulated = result + } + } + return accumulated, nil +} + +// RegisteredTools returns all custom tools registered by loaded extensions. +func (r *Runner) RegisteredTools() []ToolDef { + var tools []ToolDef + for i := range r.extensions { + tools = append(tools, r.extensions[i].Tools...) + } + return tools +} + +// RegisteredCommands returns all slash commands registered by loaded extensions. +func (r *Runner) RegisteredCommands() []CommandDef { + var cmds []CommandDef + for i := range r.extensions { + cmds = append(cmds, r.extensions[i].Commands...) + } + return cmds +} + +// Extensions returns the loaded extensions for inspection (e.g. CLI list). +func (r *Runner) Extensions() []LoadedExtension { + return r.extensions +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// safeCall invokes a handler, recovering from panics. +func safeCall(handler HandlerFunc, event Event, ctx Context) (result Result, err error) { + defer func() { + if rec := recover(); rec != nil { + err = fmt.Errorf("extension panicked: %v", rec) + } + }() + return handler(event, ctx), nil +} + +// isBlocking returns true if the result should short-circuit further handlers. +func isBlocking(result Result) bool { + switch r := result.(type) { + case ToolCallResult: + return r.Block + case InputResult: + return r.Action == "handled" + } + return false +} diff --git a/internal/extensions/runner_test.go b/internal/extensions/runner_test.go new file mode 100644 index 00000000..ec67be80 --- /dev/null +++ b/internal/extensions/runner_test.go @@ -0,0 +1,573 @@ +package extensions + +import ( + "testing" +) + +// makeRunner builds a Runner with the given extensions for testing. +func makeRunner(exts ...LoadedExtension) *Runner { + return NewRunner(exts) +} + +// makeHandlerExt creates a LoadedExtension with handlers registered for the given events. +func makeHandlerExt(path string, handlers map[EventType][]HandlerFunc) LoadedExtension { + return LoadedExtension{ + Path: path, + Handlers: handlers, + } +} + +func TestRunner_EmitNoHandlers(t *testing.T) { + r := makeRunner() + result, err := r.Emit(ToolCallEvent{ToolName: "test"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != nil { + t.Fatalf("expected nil result, got %v", result) + } +} + +func TestRunner_EmitSequentialOrder(t *testing.T) { + var order []int + ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { order = append(order, 1); return nil }, + }, + }) + ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { order = append(order, 2); return nil }, + }, + }) + ext3 := makeHandlerExt("ext3.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { order = append(order, 3); return nil }, + }, + }) + + r := makeRunner(ext1, ext2, ext3) + _, err := r.Emit(SessionStartEvent{SessionID: "test"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 { + t.Fatalf("expected sequential order [1,2,3], got %v", order) + } +} + +func TestRunner_EmitMultipleHandlersPerExtension(t *testing.T) { + var calls int + ext := makeHandlerExt("multi.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { calls++; return nil }, + func(e Event, c Context) Result { calls++; return nil }, + }, + }) + + r := makeRunner(ext) + _, err := r.Emit(SessionStartEvent{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if calls != 2 { + t.Fatalf("expected 2 calls, got %d", calls) + } +} + +func TestRunner_EmitToolCallBlocking(t *testing.T) { + var secondCalled bool + ext1 := makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { + return ToolCallResult{Block: true, Reason: "denied"} + }, + }, + }) + ext2 := makeHandlerExt("second.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { + secondCalled = true + return nil + }, + }, + }) + + r := makeRunner(ext1, ext2) + result, err := r.Emit(ToolCallEvent{ToolName: "bash", Input: "{}"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if secondCalled { + t.Error("second handler should not have been called after block") + } + tcr, ok := result.(ToolCallResult) + if !ok { + t.Fatalf("expected ToolCallResult, got %T", result) + } + if !tcr.Block { + t.Error("expected Block=true") + } + if tcr.Reason != "denied" { + t.Errorf("expected reason 'denied', got %q", tcr.Reason) + } +} + +func TestRunner_EmitToolCallNonBlocking(t *testing.T) { + ext := makeHandlerExt("allow.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { + return ToolCallResult{Block: false} + }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(ToolCallEvent{ToolName: "bash"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tcr, ok := result.(ToolCallResult) + if !ok { + t.Fatalf("expected ToolCallResult, got %T", result) + } + if tcr.Block { + t.Error("expected Block=false for non-blocking result") + } +} + +func TestRunner_EmitInputBlocking(t *testing.T) { + ext := makeHandlerExt("input-handler.go", map[EventType][]HandlerFunc{ + Input: { + func(e Event, c Context) Result { + return InputResult{Action: "handled"} + }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(InputEvent{Text: "secret", Source: "interactive"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ir, ok := result.(InputResult) + if !ok { + t.Fatalf("expected InputResult, got %T", result) + } + if ir.Action != "handled" { + t.Errorf("expected Action 'handled', got %q", ir.Action) + } +} + +func TestRunner_EmitInputTransform(t *testing.T) { + ext := makeHandlerExt("transform.go", map[EventType][]HandlerFunc{ + Input: { + func(e Event, c Context) Result { + ie := e.(InputEvent) + return InputResult{Action: "transform", Text: ie.Text + " transformed"} + }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(InputEvent{Text: "hello", Source: "cli"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + ir, ok := result.(InputResult) + if !ok { + t.Fatalf("expected InputResult, got %T", result) + } + if ir.Action != "transform" { + t.Errorf("expected Action 'transform', got %q", ir.Action) + } + if ir.Text != "hello transformed" { + t.Errorf("expected transformed text, got %q", ir.Text) + } +} + +func TestRunner_EmitToolResultChaining(t *testing.T) { + modified := "modified content" + ext := makeHandlerExt("modifier.go", map[EventType][]HandlerFunc{ + ToolResult: { + func(e Event, c Context) Result { + return ToolResultResult{Content: &modified} + }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(ToolResultEvent{ToolName: "read", Content: "original"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + trr, ok := result.(ToolResultResult) + if !ok { + t.Fatalf("expected ToolResultResult, got %T", result) + } + if trr.Content == nil || *trr.Content != "modified content" { + t.Error("expected content to be modified") + } +} + +func TestRunner_EmitPanicRecovery(t *testing.T) { + var secondCalled bool + ext := makeHandlerExt("panicker.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { panic("boom") }, + func(e Event, c Context) Result { secondCalled = true; return nil }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(SessionStartEvent{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // After a panic, the runner should continue to the next handler. + if !secondCalled { + t.Error("second handler should still be called after panic in first") + } + if result != nil { + t.Errorf("expected nil result, got %v", result) + } +} + +func TestRunner_EmitEventPassedCorrectly(t *testing.T) { + var receivedName string + var receivedInput string + ext := makeHandlerExt("inspect.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { + tc := e.(ToolCallEvent) + receivedName = tc.ToolName + receivedInput = tc.Input + return nil + }, + }, + }) + + r := makeRunner(ext) + _, _ = r.Emit(ToolCallEvent{ToolName: "bash", ToolCallID: "123", Input: `{"cmd":"ls"}`}) + if receivedName != "bash" { + t.Errorf("expected tool name 'bash', got %q", receivedName) + } + if receivedInput != `{"cmd":"ls"}` { + t.Errorf("expected input '{\"cmd\":\"ls\"}', got %q", receivedInput) + } +} + +func TestRunner_SetContext(t *testing.T) { + var receivedCtx Context + ext := makeHandlerExt("ctx.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { + receivedCtx = c + return nil + }, + }, + }) + + r := makeRunner(ext) + r.SetContext(Context{ + SessionID: "sess-123", + CWD: "/tmp", + Model: "claude-4", + Interactive: true, + }) + + _, _ = r.Emit(SessionStartEvent{}) + if receivedCtx.SessionID != "sess-123" { + t.Errorf("expected SessionID 'sess-123', got %q", receivedCtx.SessionID) + } + if receivedCtx.CWD != "/tmp" { + t.Errorf("expected CWD '/tmp', got %q", receivedCtx.CWD) + } + if receivedCtx.Model != "claude-4" { + t.Errorf("expected Model 'claude-4', got %q", receivedCtx.Model) + } + if !receivedCtx.Interactive { + t.Error("expected Interactive=true") + } +} + +func TestRunner_HasHandlers(t *testing.T) { + ext := makeHandlerExt("test.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { return nil }, + }, + }) + + r := makeRunner(ext) + if !r.HasHandlers(ToolCall) { + t.Error("expected HasHandlers(ToolCall) = true") + } + if r.HasHandlers(SessionStart) { + t.Error("expected HasHandlers(SessionStart) = false") + } +} + +func TestRunner_RegisteredTools(t *testing.T) { + ext := LoadedExtension{ + Path: "tools.go", + Handlers: make(map[EventType][]HandlerFunc), + Tools: []ToolDef{ + {Name: "tool1", Description: "first"}, + {Name: "tool2", Description: "second"}, + }, + } + + r := makeRunner(ext) + tools := r.RegisteredTools() + if len(tools) != 2 { + t.Fatalf("expected 2 tools, got %d", len(tools)) + } + if tools[0].Name != "tool1" || tools[1].Name != "tool2" { + t.Error("tools not returned in expected order") + } +} + +func TestRunner_RegisteredCommands(t *testing.T) { + ext := LoadedExtension{ + Path: "cmds.go", + Handlers: make(map[EventType][]HandlerFunc), + Commands: []CommandDef{ + {Name: "cmd1", Description: "first"}, + }, + } + + r := makeRunner(ext) + cmds := r.RegisteredCommands() + if len(cmds) != 1 { + t.Fatalf("expected 1 command, got %d", len(cmds)) + } + if cmds[0].Name != "cmd1" { + t.Errorf("expected command name 'cmd1', got %q", cmds[0].Name) + } +} + +func TestRunner_Extensions(t *testing.T) { + ext1 := makeHandlerExt("a.go", map[EventType][]HandlerFunc{}) + ext2 := makeHandlerExt("b.go", map[EventType][]HandlerFunc{}) + r := makeRunner(ext1, ext2) + if len(r.Extensions()) != 2 { + t.Fatalf("expected 2 extensions, got %d", len(r.Extensions())) + } +} + +func TestRunner_EmitOnlyMatchingEvent(t *testing.T) { + var called bool + ext := makeHandlerExt("mismatch.go", map[EventType][]HandlerFunc{ + ToolCall: { + func(e Event, c Context) Result { called = true; return nil }, + }, + }) + + r := makeRunner(ext) + _, _ = r.Emit(SessionStartEvent{}) // different event type + if called { + t.Error("ToolCall handler should not be called for SessionStart event") + } +} + +func TestRunner_EmitBeforeAgentStartResult(t *testing.T) { + injected := "extra context" + ext := makeHandlerExt("inject.go", map[EventType][]HandlerFunc{ + BeforeAgentStart: { + func(e Event, c Context) Result { + return BeforeAgentStartResult{InjectText: &injected} + }, + }, + }) + + r := makeRunner(ext) + result, err := r.Emit(BeforeAgentStartEvent{Prompt: "hello"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + bar, ok := result.(BeforeAgentStartResult) + if !ok { + t.Fatalf("expected BeforeAgentStartResult, got %T", result) + } + if bar.InjectText == nil || *bar.InjectText != "extra context" { + t.Error("expected InjectText to be set") + } +} + +func TestRunner_LastResultWins(t *testing.T) { + // When multiple handlers return non-nil, non-blocking results, + // the last one should be returned (accumulated). + first := "first" + second := "second" + ext := makeHandlerExt("chain.go", map[EventType][]HandlerFunc{ + ToolResult: { + func(e Event, c Context) Result { + return ToolResultResult{Content: &first} + }, + func(e Event, c Context) Result { + return ToolResultResult{Content: &second} + }, + }, + }) + + r := makeRunner(ext) + result, _ := r.Emit(ToolResultEvent{ToolName: "test", Content: "orig"}) + trr := result.(ToolResultResult) + if trr.Content == nil || *trr.Content != "second" { + t.Errorf("expected last result to win, got %v", trr.Content) + } +} + +func TestRunner_ContextPrint(t *testing.T) { + var printed []string + var receivedCtx Context + ext := makeHandlerExt("print.go", map[EventType][]HandlerFunc{ + Input: { + func(e Event, c Context) Result { + receivedCtx = c + if c.Print != nil { + c.Print("hello from extension") + } + return nil + }, + }, + }) + + r := makeRunner(ext) + r.SetContext(Context{ + Print: func(text string) { + printed = append(printed, text) + }, + }) + + _, _ = r.Emit(InputEvent{Text: "test"}) + if receivedCtx.Print == nil { + t.Fatal("expected Print to be non-nil in context") + } + if len(printed) != 1 || printed[0] != "hello from extension" { + t.Errorf("expected Print to capture 'hello from extension', got %v", printed) + } +} + +func TestRunner_ContextPrintInfo(t *testing.T) { + var infos []string + ext := makeHandlerExt("info.go", map[EventType][]HandlerFunc{ + SessionStart: { + func(e Event, c Context) Result { + if c.PrintInfo != nil { + c.PrintInfo("extension loaded successfully") + } + return nil + }, + }, + }) + + r := makeRunner(ext) + r.SetContext(Context{ + PrintInfo: func(text string) { + infos = append(infos, text) + }, + }) + + _, _ = r.Emit(SessionStartEvent{}) + if len(infos) != 1 || infos[0] != "extension loaded successfully" { + t.Errorf("expected PrintInfo to capture message, got %v", infos) + } +} + +func TestRunner_ContextPrintError(t *testing.T) { + var errors []string + ext := makeHandlerExt("err.go", map[EventType][]HandlerFunc{ + ToolResult: { + func(e Event, c Context) Result { + tr := e.(ToolResultEvent) + if tr.IsError && c.PrintError != nil { + c.PrintError("tool failed: " + tr.ToolName) + } + return nil + }, + }, + }) + + r := makeRunner(ext) + r.SetContext(Context{ + PrintError: func(text string) { + errors = append(errors, text) + }, + }) + + _, _ = r.Emit(ToolResultEvent{ToolName: "bash", IsError: true, Content: "exit 1"}) + if len(errors) != 1 || errors[0] != "tool failed: bash" { + t.Errorf("expected PrintError to capture message, got %v", errors) + } +} + +func TestRunner_ContextPrintBlock(t *testing.T) { + var captured []PrintBlockOpts + ext := makeHandlerExt("block.go", map[EventType][]HandlerFunc{ + Input: { + func(e Event, c Context) Result { + if c.PrintBlock != nil { + c.PrintBlock(PrintBlockOpts{ + Text: "deploy complete", + BorderColor: "#a6e3a1", + Subtitle: "deploy-ext", + }) + } + return InputResult{Action: "handled"} + }, + }, + }) + + r := makeRunner(ext) + r.SetContext(Context{ + PrintBlock: func(opts PrintBlockOpts) { + captured = append(captured, opts) + }, + }) + + _, _ = r.Emit(InputEvent{Text: "!deploy"}) + if len(captured) != 1 { + t.Fatalf("expected 1 PrintBlock call, got %d", len(captured)) + } + if captured[0].Text != "deploy complete" { + t.Errorf("expected text 'deploy complete', got %q", captured[0].Text) + } + if captured[0].BorderColor != "#a6e3a1" { + t.Errorf("expected border '#a6e3a1', got %q", captured[0].BorderColor) + } + if captured[0].Subtitle != "deploy-ext" { + t.Errorf("expected subtitle 'deploy-ext', got %q", captured[0].Subtitle) + } +} + +func TestRunner_ContextPrintNilSafe(t *testing.T) { + // When Print/PrintInfo/PrintError/PrintBlock are not set (nil), guarded calls should not panic. + ext := makeHandlerExt("nilprint.go", map[EventType][]HandlerFunc{ + Input: { + func(e Event, c Context) Result { + if c.Print != nil { + c.Print("should not happen") + } + if c.PrintInfo != nil { + c.PrintInfo("should not happen") + } + if c.PrintError != nil { + c.PrintError("should not happen") + } + if c.PrintBlock != nil { + c.PrintBlock(PrintBlockOpts{Text: "nope"}) + } + return nil + }, + }, + }) + + r := makeRunner(ext) + // Context without any Print functions set. + r.SetContext(Context{Model: "test"}) + _, err := r.Emit(InputEvent{Text: "test"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/internal/extensions/symbols.go b/internal/extensions/symbols.go new file mode 100644 index 00000000..39ec1c2f --- /dev/null +++ b/internal/extensions/symbols.go @@ -0,0 +1,49 @@ +package extensions + +import ( + "reflect" + + "github.com/traefik/yaegi/interp" +) + +// Symbols returns the Yaegi export table that makes KIT's extension API +// available to interpreted Go code. Extensions import these types as: +// +// import "kit/ext" +// +// IMPORTANT: Only concrete types (structs, constants) are exported. Interfaces +// (Event, Result) and the HandlerFunc type are NOT exported because Yaegi +// cannot generate interface wrappers for them. Instead, extensions use +// event-specific methods like api.OnToolCall() which accept concrete function +// signatures. +func Symbols() interp.Exports { + return interp.Exports{ + "kit/ext/ext": map[string]reflect.Value{ + // Struct types (nil pointer trick for type registration) + "API": reflect.ValueOf((*API)(nil)), + "Context": reflect.ValueOf((*Context)(nil)), + "ToolDef": reflect.ValueOf((*ToolDef)(nil)), + "CommandDef": reflect.ValueOf((*CommandDef)(nil)), + "PrintBlockOpts": reflect.ValueOf((*PrintBlockOpts)(nil)), + + // Event structs + "ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)), + "ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)), + "ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)), + "ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)), + "ToolResultEvent": reflect.ValueOf((*ToolResultEvent)(nil)), + "ToolResultResult": reflect.ValueOf((*ToolResultResult)(nil)), + "InputEvent": reflect.ValueOf((*InputEvent)(nil)), + "InputResult": reflect.ValueOf((*InputResult)(nil)), + "BeforeAgentStartEvent": reflect.ValueOf((*BeforeAgentStartEvent)(nil)), + "BeforeAgentStartResult": reflect.ValueOf((*BeforeAgentStartResult)(nil)), + "AgentStartEvent": reflect.ValueOf((*AgentStartEvent)(nil)), + "AgentEndEvent": reflect.ValueOf((*AgentEndEvent)(nil)), + "MessageStartEvent": reflect.ValueOf((*MessageStartEvent)(nil)), + "MessageUpdateEvent": reflect.ValueOf((*MessageUpdateEvent)(nil)), + "MessageEndEvent": reflect.ValueOf((*MessageEndEvent)(nil)), + "SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)), + "SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)), + }, + } +} diff --git a/internal/extensions/wrapper.go b/internal/extensions/wrapper.go new file mode 100644 index 00000000..fd106311 --- /dev/null +++ b/internal/extensions/wrapper.go @@ -0,0 +1,134 @@ +package extensions + +import ( + "context" + "fmt" + + "charm.land/fantasy" +) + +// WrapToolsWithExtensions wraps each tool so that ToolCall and ToolResult +// events are emitted through the extension runner before and after execution. +// This is the Go equivalent of Pi's wrapper.ts pattern. +// +// If the runner has no relevant handlers the original tools are returned +// unchanged (zero overhead). +func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantasy.AgentTool { + if runner == nil { + return tools + } + if !runner.HasHandlers(ToolCall) && !runner.HasHandlers(ToolResult) && + !runner.HasHandlers(ToolExecutionStart) && !runner.HasHandlers(ToolExecutionEnd) { + return tools + } + + wrapped := make([]fantasy.AgentTool, len(tools)) + for i, tool := range tools { + wrapped[i] = &wrappedTool{inner: tool, runner: runner} + } + return wrapped +} + +// ExtensionToolsAsFantasy converts ToolDef values registered by extensions +// into fantasy.AgentTool implementations so the LLM can invoke them. +func ExtensionToolsAsFantasy(defs []ToolDef) []fantasy.AgentTool { + tools := make([]fantasy.AgentTool, 0, len(defs)) + for _, def := range defs { + tools = append(tools, &extensionTool{def: def}) + } + return tools +} + +// --------------------------------------------------------------------------- +// wrappedTool — intercepts tool calls through the extension runner +// --------------------------------------------------------------------------- + +type wrappedTool struct { + inner fantasy.AgentTool + runner *Runner +} + +func (w *wrappedTool) Info() fantasy.ToolInfo { return w.inner.Info() } +func (w *wrappedTool) ProviderOptions() fantasy.ProviderOptions { return w.inner.ProviderOptions() } +func (w *wrappedTool) SetProviderOptions(o fantasy.ProviderOptions) { w.inner.SetProviderOptions(o) } + +func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + toolName := w.inner.Info().Name + + // 1. Emit ToolCall — extensions can block execution. + if w.runner.HasHandlers(ToolCall) { + result, _ := w.runner.Emit(ToolCallEvent{ + ToolName: toolName, + ToolCallID: call.ID, + Input: call.Input, + }) + if r, ok := result.(ToolCallResult); ok && r.Block { + reason := r.Reason + if reason == "" { + reason = "blocked by extension" + } + return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)), + fmt.Errorf("tool blocked by extension: %s", reason) + } + } + + // 2. Emit ToolExecutionStart. + if w.runner.HasHandlers(ToolExecutionStart) { + _, _ = w.runner.Emit(ToolExecutionStartEvent{ToolName: toolName}) + } + + // 3. Execute the actual tool. + resp, err := w.inner.Run(ctx, call) + + // 4. Emit ToolExecutionEnd. + if w.runner.HasHandlers(ToolExecutionEnd) { + _, _ = w.runner.Emit(ToolExecutionEndEvent{ToolName: toolName}) + } + + // 5. Emit ToolResult — extensions can modify output. + if w.runner.HasHandlers(ToolResult) { + result, _ := w.runner.Emit(ToolResultEvent{ + ToolName: toolName, + Input: call.Input, + Content: resp.Content, + IsError: err != nil || resp.IsError, + }) + if r, ok := result.(ToolResultResult); ok { + if r.Content != nil { + resp.Content = *r.Content + } + if r.IsError != nil { + resp.IsError = *r.IsError + } + } + } + + return resp, err +} + +// --------------------------------------------------------------------------- +// extensionTool — wraps a ToolDef into a fantasy.AgentTool +// --------------------------------------------------------------------------- + +type extensionTool struct { + def ToolDef + providerOptions fantasy.ProviderOptions +} + +func (t *extensionTool) Info() fantasy.ToolInfo { + return fantasy.ToolInfo{ + Name: t.def.Name, + Description: t.def.Description, + } +} + +func (t *extensionTool) ProviderOptions() fantasy.ProviderOptions { return t.providerOptions } +func (t *extensionTool) SetProviderOptions(o fantasy.ProviderOptions) { t.providerOptions = o } + +func (t *extensionTool) Run(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + result, err := t.def.Execute(call.Input) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), err + } + return fantasy.NewTextResponse(result), nil +} diff --git a/internal/extensions/wrapper_test.go b/internal/extensions/wrapper_test.go new file mode 100644 index 00000000..e4359d88 --- /dev/null +++ b/internal/extensions/wrapper_test.go @@ -0,0 +1,241 @@ +package extensions + +import ( + "context" + "testing" + + "charm.land/fantasy" +) + +// mockTool implements fantasy.AgentTool for testing. +type mockTool struct { + name string + runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) + provOpt fantasy.ProviderOptions +} + +func (m *mockTool) Info() fantasy.ToolInfo { + return fantasy.ToolInfo{Name: m.name, Description: "mock tool"} +} +func (m *mockTool) ProviderOptions() fantasy.ProviderOptions { return m.provOpt } +func (m *mockTool) SetProviderOptions(o fantasy.ProviderOptions) { m.provOpt = o } +func (m *mockTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + if m.runFn != nil { + return m.runFn(ctx, call) + } + return fantasy.NewTextResponse("ok"), nil +} + +func newMockTool(name string) *mockTool { + return &mockTool{name: name} +} + +func TestWrapToolsWithExtensions_NilRunner(t *testing.T) { + tools := []fantasy.AgentTool{newMockTool("test")} + result := WrapToolsWithExtensions(tools, nil) + if len(result) != 1 { + t.Fatalf("expected 1 tool, got %d", len(result)) + } + // Should be the same pointer (unwrapped). + if result[0] != tools[0] { + t.Error("expected original tool when runner is nil") + } +} + +func TestWrapToolsWithExtensions_NoRelevantHandlers(t *testing.T) { + r := makeRunner(makeHandlerExt("other.go", map[EventType][]HandlerFunc{ + SessionStart: {func(e Event, c Context) Result { return nil }}, + })) + tools := []fantasy.AgentTool{newMockTool("test")} + result := WrapToolsWithExtensions(tools, r) + if result[0] != tools[0] { + t.Error("expected original tool when no tool handlers exist") + } +} + +func TestWrapToolsWithExtensions_WrapsWhenHandlersExist(t *testing.T) { + r := makeRunner(makeHandlerExt("tc.go", map[EventType][]HandlerFunc{ + ToolCall: {func(e Event, c Context) Result { return nil }}, + })) + tools := []fantasy.AgentTool{newMockTool("test")} + result := WrapToolsWithExtensions(tools, r) + if result[0] == tools[0] { + t.Error("expected wrapped tool when ToolCall handlers exist") + } + // Verify Info() is passed through. + if result[0].Info().Name != "test" { + t.Errorf("expected name 'test', got %q", result[0].Info().Name) + } +} + +func TestWrappedTool_NormalExecution(t *testing.T) { + var toolCallSeen, toolResultSeen bool + r := makeRunner(makeHandlerExt("observe.go", map[EventType][]HandlerFunc{ + ToolCall: {func(e Event, c Context) Result { + toolCallSeen = true + return nil + }}, + ToolResult: {func(e Event, c Context) Result { + toolResultSeen = true + return nil + }}, + })) + + mock := newMockTool("bash") + mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("output"), nil + } + + tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r) + resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1", Input: "{}"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "output" { + t.Errorf("expected 'output', got %q", resp.Content) + } + if !toolCallSeen { + t.Error("ToolCall handler was not invoked") + } + if !toolResultSeen { + t.Error("ToolResult handler was not invoked") + } +} + +func TestWrappedTool_BlockExecution(t *testing.T) { + var toolRan bool + r := makeRunner(makeHandlerExt("blocker.go", map[EventType][]HandlerFunc{ + ToolCall: {func(e Event, c Context) Result { + return ToolCallResult{Block: true, Reason: "forbidden"} + }}, + })) + + mock := newMockTool("danger") + mock.runFn = func(_ context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + toolRan = true + return fantasy.NewTextResponse("bad"), nil + } + + tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r) + resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"}) + if toolRan { + t.Error("tool should not have run after block") + } + if err == nil { + t.Error("expected error from blocked tool") + } + if resp.IsError != true { + t.Error("expected IsError=true from blocked response") + } +} + +func TestWrappedTool_ModifyResult(t *testing.T) { + modified := "redacted" + r := makeRunner(makeHandlerExt("redactor.go", map[EventType][]HandlerFunc{ + ToolCall: {func(e Event, c Context) Result { return nil }}, + ToolResult: {func(e Event, c Context) Result { + return ToolResultResult{Content: &modified} + }}, + })) + + mock := newMockTool("read") + mock.runFn = func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { + return fantasy.NewTextResponse("secret data"), nil + } + + tools := WrapToolsWithExtensions([]fantasy.AgentTool{mock}, r) + resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "redacted" { + t.Errorf("expected 'redacted', got %q", resp.Content) + } +} + +func TestWrappedTool_ExecutionStartEnd(t *testing.T) { + var startSeen, endSeen bool + r := makeRunner(makeHandlerExt("lifecycle.go", map[EventType][]HandlerFunc{ + ToolCall: {func(e Event, c Context) Result { return nil }}, + ToolExecutionStart: {func(e Event, c Context) Result { startSeen = true; return nil }}, + ToolExecutionEnd: {func(e Event, c Context) Result { endSeen = true; return nil }}, + })) + + tools := WrapToolsWithExtensions([]fantasy.AgentTool{newMockTool("test")}, r) + _, _ = tools[0].Run(context.Background(), fantasy.ToolCall{ID: "1"}) + if !startSeen { + t.Error("ToolExecutionStart not emitted") + } + if !endSeen { + t.Error("ToolExecutionEnd not emitted") + } +} + +func TestExtensionToolsAsFantasy(t *testing.T) { + defs := []ToolDef{ + { + Name: "greet", + Description: "greets someone", + Parameters: `{"type":"object"}`, + Execute: func(input string) (string, error) { return "hello " + input, nil }, + }, + } + + tools := ExtensionToolsAsFantasy(defs) + if len(tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(tools)) + } + + info := tools[0].Info() + if info.Name != "greet" { + t.Errorf("expected name 'greet', got %q", info.Name) + } + if info.Description != "greets someone" { + t.Errorf("expected description 'greets someone', got %q", info.Description) + } + + resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "world"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Content != "hello world" { + t.Errorf("expected 'hello world', got %q", resp.Content) + } +} + +func TestExtensionTool_Error(t *testing.T) { + defs := []ToolDef{ + { + Name: "fail", + Execute: func(input string) (string, error) { return "", context.DeadlineExceeded }, + }, + } + + tools := ExtensionToolsAsFantasy(defs) + resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"}) + if err == nil { + t.Error("expected error") + } + if !resp.IsError { + t.Error("expected IsError=true") + } +} + +func TestExtensionTool_ProviderOptions(t *testing.T) { + defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}} + tools := ExtensionToolsAsFantasy(defs) + + // Initially nil. + opts := tools[0].ProviderOptions() + if opts != nil { + t.Error("expected nil ProviderOptions initially") + } + + // SetProviderOptions round-trips. + po := fantasy.ProviderOptions{} + tools[0].SetProviderOptions(po) + got := tools[0].ProviderOptions() + if got == nil { + t.Error("expected non-nil ProviderOptions after set") + } +} diff --git a/internal/ui/cli.go b/internal/ui/cli.go index 7ab4b7a6..f18fbf65 100644 --- a/internal/ui/cli.go +++ b/internal/ui/cli.go @@ -177,6 +177,32 @@ func (c *CLI) DisplayInfo(message string) { c.displayContainer() } +// DisplayExtensionBlock renders a custom styled block with the given border +// color and optional subtitle. Used by extensions via ctx.PrintBlock. +func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) { + theme := GetTheme() + + var borderClr = lipgloss.Color("#89b4fa") + if borderColor != "" { + borderClr = lipgloss.Color(borderColor) + } + + content := text + if subtitle != "" { + sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + subtitle) + content = content + "\n" + sub + } + + rendered := renderContentBlock( + content, + c.messageRenderer.width, + WithAlign(lipgloss.Left), + WithBorderColor(borderClr), + WithMarginBottom(1), + ) + fmt.Println(rendered) +} + // DisplayCancellation displays a system message indicating that the current // AI generation has been cancelled by the user (typically via ESC key). func (c *CLI) DisplayCancellation() { diff --git a/internal/ui/event_handler.go b/internal/ui/event_handler.go index 808dbc44..69e68584 100644 --- a/internal/ui/event_handler.go +++ b/internal/ui/event_handler.go @@ -129,6 +129,19 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) { h.lastDisplayed = e.Content } + case app.ExtensionPrintEvent: + h.stopSpinner() + switch e.Level { + case "info": + h.cli.DisplayInfo(e.Text) + case "error": + h.cli.DisplayError(fmt.Errorf("%s", e.Text)) + case "block": + h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle) + default: + fmt.Println(e.Text) + } + case app.StepCompleteEvent: h.stopSpinner() diff --git a/internal/ui/model.go b/internal/ui/model.go index 3e9a6455..6211a37e 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -530,6 +530,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.state = stateInput m.canceling = false + case app.ExtensionPrintEvent: + // Extension output — route through styled renderers when a level is set. + switch msg.Level { + case "info": + cmds = append(cmds, m.printSystemMessage(msg.Text)) + case "error": + cmds = append(cmds, m.printErrorResponse(app.StepErrorEvent{ + Err: fmt.Errorf("%s", msg.Text), + })) + case "block": + cmds = append(cmds, m.printExtensionBlock(msg)) + default: + cmds = append(cmds, tea.Println(msg.Text)) + } + default: // Pass unrecognised messages to all children. if m.input != nil { @@ -791,6 +806,34 @@ func (m *AppModel) printSystemMessage(text string) tea.Cmd { return tea.Println(rendered) } +// printExtensionBlock renders a custom styled block from an extension with +// caller-chosen border color and optional subtitle, then emits it to scrollback. +func (m *AppModel) printExtensionBlock(evt app.ExtensionPrintEvent) tea.Cmd { + theme := GetTheme() + + // Resolve border color: use the extension's hex value, fall back to theme accent. + var borderClr = lipgloss.Color("#89b4fa") // default blue + if evt.BorderColor != "" { + borderClr = lipgloss.Color(evt.BorderColor) + } + + // Build content: main text + optional subtitle line. + content := evt.Text + if evt.Subtitle != "" { + sub := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" " + evt.Subtitle) + content = strings.TrimSuffix(content, "\n") + "\n" + sub + } + + rendered := renderContentBlock( + content, + m.width, + WithAlign(lipgloss.Left), + WithBorderColor(borderClr), + WithMarginBottom(1), + ) + return tea.Println(rendered) +} + // printHelpMessage renders the help text listing all available slash commands. func (m *AppModel) printHelpMessage() tea.Cmd { help := "## Available Commands\n\n" +