mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
aede76d807
- ctx.SuspendTUI(callback): releases terminal for interactive subprocesses (vim, shell, htop), automatically restores TUI when callback returns. Uses BubbleTea v2 ReleaseTerminal/RestoreTerminal. - api.RegisterMessageRenderer(config) + ctx.RenderMessage(name, content): named render functions for branded/styled extension output. Renderers receive content and terminal width, return ANSI-styled strings. - ctx.ReloadExtensions(): hot-reloads all extensions from disk. Emits SessionShutdown to old extensions, reloads source, emits SessionStart to new. Event handlers, commands, renderers, shortcuts update immediately. TUI command list refreshes via WidgetUpdateEvent. Extension tools are NOT updated (baked into agent at creation, documented limitation). New example extensions: interactive-shell.go, branded-output.go, dev-reload.go
368 lines
10 KiB
Go
368 lines
10 KiB
Go
package extensions
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/traefik/yaegi/interp"
|
|
"github.com/traefik/yaegi/stdlib"
|
|
"github.com/traefik/yaegi/stdlib/unrestricted"
|
|
)
|
|
|
|
// Discovery paths searched in order (lowest to highest precedence):
|
|
//
|
|
// ~/.config/kit/extensions/*.go global single files
|
|
// ~/.config/kit/extensions/*/main.go global subdirectories
|
|
// .kit/extensions/*.go project-local single files
|
|
// .kit/extensions/*/main.go project-local subdirectories
|
|
//
|
|
// Explicit paths passed via --extension / -e flags are appended last.
|
|
|
|
// LoadExtensions discovers and loads extensions from standard locations and
|
|
// any extra paths. Each extension is loaded into its own Yaegi interpreter
|
|
// for isolation. Extensions that fail to load are logged and skipped.
|
|
func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
|
paths := discoverExtensionPaths(extraPaths)
|
|
if len(paths) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var loaded []LoadedExtension
|
|
for _, p := range paths {
|
|
ext, err := loadSingleExtension(p)
|
|
if err != nil {
|
|
log.Warn("skipping extension", "path", p, "err", err)
|
|
continue
|
|
}
|
|
loaded = append(loaded, *ext)
|
|
log.Debug("loaded extension", "path", p,
|
|
"handlers", countHandlers(ext),
|
|
"tools", len(ext.Tools),
|
|
"commands", len(ext.Commands),
|
|
"tool_renderers", len(ext.ToolRenderers))
|
|
}
|
|
return loaded, nil
|
|
}
|
|
|
|
// discoverExtensionPaths returns deduplicated paths to extension files in
|
|
// load-order (global first, then project-local, then explicit).
|
|
func discoverExtensionPaths(extraPaths []string) []string {
|
|
seen := make(map[string]bool)
|
|
var paths []string
|
|
|
|
add := func(p string) {
|
|
abs, err := filepath.Abs(p)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if seen[abs] {
|
|
return
|
|
}
|
|
seen[abs] = true
|
|
paths = append(paths, abs)
|
|
}
|
|
|
|
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
|
globalDir := globalExtensionsDir()
|
|
for _, p := range findExtensionsInDir(globalDir) {
|
|
add(p)
|
|
}
|
|
|
|
// Project-local extensions: .kit/extensions/
|
|
localDir := filepath.Join(".kit", "extensions")
|
|
for _, p := range findExtensionsInDir(localDir) {
|
|
add(p)
|
|
}
|
|
|
|
// Explicit paths (highest precedence)
|
|
for _, p := range extraPaths {
|
|
info, err := os.Stat(p)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if info.IsDir() {
|
|
for _, found := range findExtensionsInDir(p) {
|
|
add(found)
|
|
}
|
|
} else if strings.HasSuffix(p, ".go") {
|
|
add(p)
|
|
}
|
|
}
|
|
|
|
return paths
|
|
}
|
|
|
|
// findExtensionsInDir lists the extension entry points directly inside dir:
// every non-directory entry whose name ends in ".go", and "<sub>/main.go" for
// every immediate subdirectory containing a main.go that is not itself a
// directory. Results follow the order of os.ReadDir (sorted by filename) and
// are built by joining onto dir, so they are relative when dir is relative.
//
// It returns nil when dir is missing, is not a directory, cannot be read, or
// contains no matching entries.
func findExtensionsInDir(dir string) []string {
	// os.ReadDir already fails for a missing path or a non-directory, so a
	// separate os.Stat pre-check is unnecessary.
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil
	}

	var results []string
	for _, entry := range entries {
		full := filepath.Join(dir, entry.Name())

		if !entry.IsDir() {
			if strings.HasSuffix(entry.Name(), ".go") {
				results = append(results, full)
			}
			continue
		}

		// Subdirectory: only its main.go counts, and only when that path is
		// not a directory; a directory could never be read as source.
		entryPoint := filepath.Join(full, "main.go")
		info, err := os.Stat(entryPoint)
		if err != nil || info.IsDir() {
			continue
		}
		results = append(results, entryPoint)
	}
	return results
}
|
|
|
|
// globalExtensionsDir returns the directory that holds global extensions:
// $XDG_CONFIG_HOME/kit/extensions when XDG_CONFIG_HOME is set to an absolute
// path, otherwise ~/.config/kit/extensions.
//
// A relative XDG_CONFIG_HOME is ignored, as the XDG Base Directory
// Specification requires, so the result never depends on the current working
// directory. It returns "" when the ~/.config fallback is needed but the home
// directory cannot be determined.
func globalExtensionsDir() string {
	base := os.Getenv("XDG_CONFIG_HOME")
	if !filepath.IsAbs(base) {
		// Unset, empty, or relative (invalid per the XDG spec): use ~/.config.
		home, err := os.UserHomeDir()
		if err != nil {
			return ""
		}
		base = filepath.Join(home, ".config")
	}
	return filepath.Join(base, "kit", "extensions")
}
|
|
|
|
// loadSingleExtension loads one .go file into a fresh Yaegi interpreter,
|
|
// calls the Init(ext.API) function, and returns the registered handlers.
|
|
func loadSingleExtension(path string) (*LoadedExtension, error) {
|
|
ext := &LoadedExtension{
|
|
Path: path,
|
|
Handlers: make(map[EventType][]HandlerFunc),
|
|
}
|
|
|
|
// Create a fresh interpreter.
|
|
i := interp.New(interp.Options{})
|
|
|
|
// Expose the Go stdlib. The base set covers most packages; the
|
|
// unrestricted set adds os/exec so extensions can spawn processes.
|
|
if err := i.Use(stdlib.Symbols); err != nil {
|
|
return nil, fmt.Errorf("loading stdlib symbols: %w", err)
|
|
}
|
|
if err := i.Use(unrestricted.Symbols); err != nil {
|
|
return nil, fmt.Errorf("loading unrestricted symbols: %w", err)
|
|
}
|
|
|
|
// Expose KIT's extension API types so the extension can
|
|
// import "kit/ext" and use ext.ToolCall, ext.API, etc.
|
|
if err := i.Use(Symbols()); err != nil {
|
|
return nil, fmt.Errorf("loading extension symbols: %w", err)
|
|
}
|
|
|
|
// Read and evaluate the extension source file.
|
|
src, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading file: %w", err)
|
|
}
|
|
|
|
if _, err := i.Eval(string(src)); err != nil {
|
|
return nil, fmt.Errorf("evaluating source: %w", err)
|
|
}
|
|
|
|
// Extract the Init function. Extensions must export:
|
|
// func Init(api ext.API)
|
|
initVal, err := i.Eval("Init")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no Init function: %w", err)
|
|
}
|
|
|
|
initFn, ok := initVal.Interface().(func(API))
|
|
if !ok {
|
|
return nil, fmt.Errorf("init has wrong signature (want func(ext.API), got %T)", initVal.Interface())
|
|
}
|
|
|
|
// Build the API object that wires typed registration methods back to
|
|
// the extension's internal handler map. Each method wraps the concrete
|
|
// handler into the internal HandlerFunc type.
|
|
reg := func(event EventType, fn HandlerFunc) {
|
|
ext.Handlers[event] = append(ext.Handlers[event], fn)
|
|
}
|
|
|
|
api := API{
|
|
onToolCall: func(h func(ToolCallEvent, Context) *ToolCallResult) {
|
|
reg(ToolCall, func(e Event, c Context) Result {
|
|
r := h(e.(ToolCallEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
|
|
reg(ToolExecutionStart, func(e Event, c Context) Result {
|
|
h(e.(ToolExecutionStartEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onToolExecEnd: func(h func(ToolExecutionEndEvent, Context)) {
|
|
reg(ToolExecutionEnd, func(e Event, c Context) Result {
|
|
h(e.(ToolExecutionEndEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onToolResult: func(h func(ToolResultEvent, Context) *ToolResultResult) {
|
|
reg(ToolResult, func(e Event, c Context) Result {
|
|
r := h(e.(ToolResultEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onInput: func(h func(InputEvent, Context) *InputResult) {
|
|
reg(Input, func(e Event, c Context) Result {
|
|
r := h(e.(InputEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onBeforeAgentStart: func(h func(BeforeAgentStartEvent, Context) *BeforeAgentStartResult) {
|
|
reg(BeforeAgentStart, func(e Event, c Context) Result {
|
|
r := h(e.(BeforeAgentStartEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onAgentStart: func(h func(AgentStartEvent, Context)) {
|
|
reg(AgentStart, func(e Event, c Context) Result {
|
|
h(e.(AgentStartEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onAgentEnd: func(h func(AgentEndEvent, Context)) {
|
|
reg(AgentEnd, func(e Event, c Context) Result {
|
|
h(e.(AgentEndEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onMessageStart: func(h func(MessageStartEvent, Context)) {
|
|
reg(MessageStart, func(e Event, c Context) Result {
|
|
h(e.(MessageStartEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onMessageUpdate: func(h func(MessageUpdateEvent, Context)) {
|
|
reg(MessageUpdate, func(e Event, c Context) Result {
|
|
h(e.(MessageUpdateEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onMessageEnd: func(h func(MessageEndEvent, Context)) {
|
|
reg(MessageEnd, func(e Event, c Context) Result {
|
|
h(e.(MessageEndEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onSessionStart: func(h func(SessionStartEvent, Context)) {
|
|
reg(SessionStart, func(e Event, c Context) Result {
|
|
h(e.(SessionStartEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onSessionShutdown: func(h func(SessionShutdownEvent, Context)) {
|
|
reg(SessionShutdown, func(e Event, c Context) Result {
|
|
h(e.(SessionShutdownEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onModelChange: func(h func(ModelChangeEvent, Context)) {
|
|
reg(ModelChange, func(e Event, c Context) Result {
|
|
h(e.(ModelChangeEvent), c)
|
|
return nil
|
|
})
|
|
},
|
|
onContextPrepare: func(h func(ContextPrepareEvent, Context) *ContextPrepareResult) {
|
|
reg(ContextPrepare, func(e Event, c Context) Result {
|
|
r := h(e.(ContextPrepareEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onBeforeFork: func(h func(BeforeForkEvent, Context) *BeforeForkResult) {
|
|
reg(BeforeFork, func(e Event, c Context) Result {
|
|
r := h(e.(BeforeForkEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onBeforeSessionSwitch: func(h func(BeforeSessionSwitchEvent, Context) *BeforeSessionSwitchResult) {
|
|
reg(BeforeSessionSwitch, func(e Event, c Context) Result {
|
|
r := h(e.(BeforeSessionSwitchEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
onBeforeCompact: func(h func(BeforeCompactEvent, Context) *BeforeCompactResult) {
|
|
reg(BeforeCompact, func(e Event, c Context) Result {
|
|
r := h(e.(BeforeCompactEvent), c)
|
|
if r == nil {
|
|
return nil
|
|
}
|
|
return *r
|
|
})
|
|
},
|
|
registerToolFn: func(tool ToolDef) {
|
|
ext.Tools = append(ext.Tools, tool)
|
|
},
|
|
registerCmdFn: func(cmd CommandDef) {
|
|
ext.Commands = append(ext.Commands, cmd)
|
|
},
|
|
registerToolRendererFn: func(config ToolRenderConfig) {
|
|
ext.ToolRenderers = append(ext.ToolRenderers, config)
|
|
},
|
|
registerMessageRendererFn: func(config MessageRendererConfig) {
|
|
ext.MessageRenderers = append(ext.MessageRenderers, config)
|
|
},
|
|
onCustomEvent: func(name string, handler func(string)) {
|
|
if ext.CustomEventHandlers == nil {
|
|
ext.CustomEventHandlers = make(map[string][]func(string))
|
|
}
|
|
ext.CustomEventHandlers[name] = append(ext.CustomEventHandlers[name], handler)
|
|
},
|
|
registerOption: func(opt OptionDef) {
|
|
ext.Options = append(ext.Options, opt)
|
|
},
|
|
registerShortcutFn: func(def ShortcutDef, handler func(Context)) {
|
|
ext.Shortcuts = append(ext.Shortcuts, ShortcutEntry{Def: def, Handler: handler})
|
|
},
|
|
}
|
|
|
|
// Call Init — the extension registers its handlers, tools, commands.
|
|
initFn(api)
|
|
|
|
return ext, nil
|
|
}
|
|
|
|
// countHandlers returns the total number of registered handlers across all events.
|
|
func countHandlers(ext *LoadedExtension) int {
|
|
n := 0
|
|
for _, handlers := range ext.Handlers {
|
|
n += len(handlers)
|
|
}
|
|
return n
|
|
}
|