mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
8f5efee837
Add three new extension events that allow extensions to gate destructive session operations and compaction:
- OnBeforeFork: fires before branching in the tree selector; the handler can cancel with a reason (e.g. a dirty-repo guard).
- OnBeforeSessionSwitch: fires before /new resets the session branch; the handler can cancel with a reason.
- OnBeforeCompact: fires before context compaction (automatic or manual); the handler receives token stats and an IsAutomatic flag, and can cancel.
Includes the SDK hook registry (beforeCompact), the extension bridge, UI callbacks threaded through AppModelOptions, and two example extensions:
- confirm-destructive.go: git dirty check + fork confirmation
- compact-notify.go: compaction notification + auto-compact gating
539 lines
16 KiB
Go
539 lines
16 KiB
Go
package extensions
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// Runner manages loaded extensions and dispatches events to their handlers
|
|
// sequentially. Handlers execute in extension
|
|
// load order; for cancellable events the first blocking result wins.
|
|
type Runner struct {
|
|
extensions []LoadedExtension
|
|
ctx Context
|
|
widgets map[string]WidgetConfig // keyed by widget ID
|
|
statusEntries map[string]StatusBarEntry // keyed by status key
|
|
header *HeaderFooterConfig // nil = no custom header
|
|
footer *HeaderFooterConfig // nil = no custom footer
|
|
customEditor *EditorConfig // nil = no custom editor interceptor
|
|
uiVisibility *UIVisibility // nil = show everything (default)
|
|
disabledTools map[string]bool // nil = all tools enabled
|
|
customEventSubs map[string][]func(string) // inter-extension event bus
|
|
optionOverrides map[string]string // runtime option overrides
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// LoadedExtension represents a single extension that has been discovered,
|
|
// loaded, and initialised. It holds the registered handlers and any custom
|
|
// tools, commands, or tool renderers the extension provided.
|
|
type LoadedExtension struct {
|
|
Path string
|
|
Handlers map[EventType][]HandlerFunc
|
|
Tools []ToolDef
|
|
Commands []CommandDef
|
|
ToolRenderers []ToolRenderConfig
|
|
CustomEventHandlers map[string][]func(string) // inter-extension event bus
|
|
Options []OptionDef // registered configuration options
|
|
}
|
|
|
|
// NewRunner creates a Runner from a set of loaded extensions.
|
|
func NewRunner(exts []LoadedExtension) *Runner {
|
|
return &Runner{extensions: exts}
|
|
}
|
|
|
|
// SetContext updates the runtime context (session ID, model, etc.) that is
|
|
// passed to every handler invocation. Thread-safe.
|
|
func (r *Runner) SetContext(ctx Context) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.ctx = ctx
|
|
}
|
|
|
|
// HasHandlers returns true if any loaded extension has at least one handler
|
|
// registered for the given event type.
|
|
func (r *Runner) HasHandlers(event EventType) bool {
|
|
for i := range r.extensions {
|
|
if len(r.extensions[i].Handlers[event]) > 0 {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Emit dispatches an event to all matching handlers sequentially. It returns
|
|
// the accumulated result from all handlers, or nil if no handler responded.
|
|
//
|
|
// For blocking events (ToolCall, Input), the first blocking result short-circuits:
|
|
// - ToolCallResult{Block: true} stops iteration and returns immediately.
|
|
// - InputResult{Action: "handled"} stops iteration and returns immediately.
|
|
//
|
|
// For chainable events (ToolResult), each handler sees the accumulated result
|
|
// from previous handlers. The final merged result is returned.
|
|
//
|
|
// Panics in handlers are recovered and logged; they do not crash the process.
|
|
func (r *Runner) Emit(event Event) (Result, error) {
|
|
r.mu.RLock()
|
|
ctx := r.ctx
|
|
r.mu.RUnlock()
|
|
|
|
var accumulated Result
|
|
|
|
for i := range r.extensions {
|
|
ext := &r.extensions[i]
|
|
handlers := ext.Handlers[event.Type()]
|
|
for _, handler := range handlers {
|
|
result, err := safeCall(handler, event, ctx)
|
|
if err != nil {
|
|
log.Warn("extension handler error",
|
|
"path", ext.Path,
|
|
"event", event.Type(),
|
|
"err", err)
|
|
continue
|
|
}
|
|
if result == nil {
|
|
continue
|
|
}
|
|
|
|
// Check for blocking/short-circuit results.
|
|
if isBlocking(result) {
|
|
return result, nil
|
|
}
|
|
|
|
// Chain: keep the latest non-nil result. For ToolResultResult
|
|
// the caller is responsible for applying the modifications.
|
|
accumulated = result
|
|
}
|
|
}
|
|
return accumulated, nil
|
|
}
|
|
|
|
// RegisteredTools returns all custom tools registered by loaded extensions.
|
|
func (r *Runner) RegisteredTools() []ToolDef {
|
|
var tools []ToolDef
|
|
for i := range r.extensions {
|
|
tools = append(tools, r.extensions[i].Tools...)
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// RegisteredCommands returns all slash commands registered by loaded extensions.
|
|
func (r *Runner) RegisteredCommands() []CommandDef {
|
|
var cmds []CommandDef
|
|
for i := range r.extensions {
|
|
cmds = append(cmds, r.extensions[i].Commands...)
|
|
}
|
|
return cmds
|
|
}
|
|
|
|
// GetContext returns the current runtime context. Thread-safe.
|
|
func (r *Runner) GetContext() Context {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
return r.ctx
|
|
}
|
|
|
|
// Extensions returns the loaded extensions for inspection (e.g. CLI list).
|
|
func (r *Runner) Extensions() []LoadedExtension {
|
|
return r.extensions
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Widget management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetWidget places or updates a persistent widget. The widget is identified
|
|
// by config.ID; calling SetWidget with the same ID replaces the previous
|
|
// content. Thread-safe.
|
|
func (r *Runner) SetWidget(config WidgetConfig) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.widgets == nil {
|
|
r.widgets = make(map[string]WidgetConfig)
|
|
}
|
|
r.widgets[config.ID] = config
|
|
}
|
|
|
|
// RemoveWidget removes a widget by ID. No-op if the ID does not exist.
|
|
// Thread-safe.
|
|
func (r *Runner) RemoveWidget(id string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
delete(r.widgets, id)
|
|
}
|
|
|
|
// GetWidgets returns all widgets matching the given placement, sorted by
|
|
// priority (ascending). Thread-safe.
|
|
func (r *Runner) GetWidgets(placement WidgetPlacement) []WidgetConfig {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
var result []WidgetConfig
|
|
for _, w := range r.widgets {
|
|
if w.Placement == placement {
|
|
result = append(result, w)
|
|
}
|
|
}
|
|
sort.Slice(result, func(i, j int) bool {
|
|
if result[i].Priority != result[j].Priority {
|
|
return result[i].Priority < result[j].Priority
|
|
}
|
|
return result[i].ID < result[j].ID // stable tie-break
|
|
})
|
|
return result
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Status bar management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetStatusEntry places or updates a keyed status bar entry. Thread-safe.
|
|
func (r *Runner) SetStatusEntry(entry StatusBarEntry) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.statusEntries == nil {
|
|
r.statusEntries = make(map[string]StatusBarEntry)
|
|
}
|
|
r.statusEntries[entry.Key] = entry
|
|
}
|
|
|
|
// RemoveStatusEntry removes a status bar entry by key. Thread-safe.
|
|
func (r *Runner) RemoveStatusEntry(key string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
delete(r.statusEntries, key)
|
|
}
|
|
|
|
// GetStatusEntries returns all status bar entries, sorted by priority
|
|
// (ascending). Thread-safe.
|
|
func (r *Runner) GetStatusEntries() []StatusBarEntry {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
result := make([]StatusBarEntry, 0, len(r.statusEntries))
|
|
for _, e := range r.statusEntries {
|
|
result = append(result, e)
|
|
}
|
|
sort.Slice(result, func(i, j int) bool {
|
|
if result[i].Priority != result[j].Priority {
|
|
return result[i].Priority < result[j].Priority
|
|
}
|
|
return result[i].Key < result[j].Key
|
|
})
|
|
return result
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Header/Footer management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetHeader places or replaces the custom header. Thread-safe.
|
|
func (r *Runner) SetHeader(config HeaderFooterConfig) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.header = &config
|
|
}
|
|
|
|
// RemoveHeader removes the custom header. No-op if none is set. Thread-safe.
|
|
func (r *Runner) RemoveHeader() {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.header = nil
|
|
}
|
|
|
|
// GetHeader returns the current custom header, or nil if none is set.
|
|
// Thread-safe.
|
|
func (r *Runner) GetHeader() *HeaderFooterConfig {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
if r.header == nil {
|
|
return nil
|
|
}
|
|
// Return a copy to avoid races on the caller side.
|
|
h := *r.header
|
|
return &h
|
|
}
|
|
|
|
// SetFooter places or replaces the custom footer. Thread-safe.
|
|
func (r *Runner) SetFooter(config HeaderFooterConfig) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.footer = &config
|
|
}
|
|
|
|
// RemoveFooter removes the custom footer. No-op if none is set. Thread-safe.
|
|
func (r *Runner) RemoveFooter() {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.footer = nil
|
|
}
|
|
|
|
// GetFooter returns the current custom footer, or nil if none is set.
|
|
// Thread-safe.
|
|
func (r *Runner) GetFooter() *HeaderFooterConfig {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
if r.footer == nil {
|
|
return nil
|
|
}
|
|
// Return a copy to avoid races on the caller side.
|
|
f := *r.footer
|
|
return &f
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Editor interceptor management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetEditor installs an editor interceptor that wraps the built-in input
|
|
// editor. Only one interceptor is active at a time; calling SetEditor replaces
|
|
// any previous interceptor. Thread-safe.
|
|
func (r *Runner) SetEditor(config EditorConfig) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.customEditor = &config
|
|
}
|
|
|
|
// ResetEditor removes the active editor interceptor and restores the default
|
|
// built-in editor behavior. No-op if no interceptor is set. Thread-safe.
|
|
func (r *Runner) ResetEditor() {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.customEditor = nil
|
|
}
|
|
|
|
// GetEditor returns the current editor interceptor, or nil if none is set.
|
|
// Thread-safe. Returns a shallow copy — function fields are reference types
|
|
// so the copy is safe.
|
|
func (r *Runner) GetEditor() *EditorConfig {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
if r.customEditor == nil {
|
|
return nil
|
|
}
|
|
e := *r.customEditor
|
|
return &e
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// UI visibility management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetUIVisibility updates the UI visibility overrides. Thread-safe.
|
|
func (r *Runner) SetUIVisibility(v UIVisibility) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
r.uiVisibility = &v
|
|
}
|
|
|
|
// GetUIVisibility returns the current UI visibility overrides, or nil if
|
|
// none have been set (meaning show everything). Thread-safe.
|
|
func (r *Runner) GetUIVisibility() *UIVisibility {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
if r.uiVisibility == nil {
|
|
return nil
|
|
}
|
|
v := *r.uiVisibility
|
|
return &v
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tool renderer management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GetToolRenderer returns the custom renderer for the named tool, or nil if
|
|
// no extension registered a renderer for it. If multiple extensions register
|
|
// renderers for the same tool, the last one (by load order) wins. Thread-safe
|
|
// (extensions are immutable after loading).
|
|
func (r *Runner) GetToolRenderer(toolName string) *ToolRenderConfig {
|
|
// Walk extensions in reverse so last-registered wins.
|
|
for i := len(r.extensions) - 1; i >= 0; i-- {
|
|
for j := len(r.extensions[i].ToolRenderers) - 1; j >= 0; j-- {
|
|
if r.extensions[i].ToolRenderers[j].ToolName == toolName {
|
|
config := r.extensions[i].ToolRenderers[j]
|
|
return &config
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Inter-extension event bus
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SubscribeCustomEvent registers a handler for a named custom event. Handlers
|
|
// execute in registration order when EmitCustomEvent is called. Thread-safe.
|
|
func (r *Runner) SubscribeCustomEvent(name string, handler func(string)) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.customEventSubs == nil {
|
|
r.customEventSubs = make(map[string][]func(string))
|
|
}
|
|
r.customEventSubs[name] = append(r.customEventSubs[name], handler)
|
|
}
|
|
|
|
// EmitCustomEvent dispatches a named event to all subscribed handlers.
|
|
// Handlers run synchronously in extension load order. Panics are recovered
|
|
// and logged. Thread-safe.
|
|
func (r *Runner) EmitCustomEvent(name, data string) {
|
|
// Collect handlers: extension-registered (Init-time) + dynamic subs.
|
|
r.mu.RLock()
|
|
dynamicHandlers := r.customEventSubs[name]
|
|
r.mu.RUnlock()
|
|
|
|
safeInvoke := func(h func(string)) {
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
log.Warn("custom event handler panicked",
|
|
"event", name,
|
|
"err", fmt.Sprintf("%v", rec))
|
|
}
|
|
}()
|
|
h(data)
|
|
}
|
|
|
|
// Extension-registered handlers first (in load order).
|
|
for i := range r.extensions {
|
|
for _, h := range r.extensions[i].CustomEventHandlers[name] {
|
|
safeInvoke(h)
|
|
}
|
|
}
|
|
// Then dynamic subscriptions.
|
|
for _, h := range dynamicHandlers {
|
|
safeInvoke(h)
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Tool management
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SetActiveTools restricts the tool set to the named tools. All tools not in
|
|
// the list are disabled. Passing nil or an empty slice re-enables all tools.
|
|
// Thread-safe.
|
|
func (r *Runner) SetActiveTools(names []string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if len(names) == 0 {
|
|
r.disabledTools = nil
|
|
return
|
|
}
|
|
active := make(map[string]bool, len(names))
|
|
for _, n := range names {
|
|
active[n] = true
|
|
}
|
|
r.disabledTools = active // non-nil = only these tools are allowed
|
|
}
|
|
|
|
// IsToolDisabled returns true if the tool has been disabled via SetActiveTools.
|
|
// Thread-safe.
|
|
func (r *Runner) IsToolDisabled(toolName string) bool {
|
|
r.mu.RLock()
|
|
defer r.mu.RUnlock()
|
|
if r.disabledTools == nil {
|
|
return false // no filter = all enabled
|
|
}
|
|
return !r.disabledTools[toolName]
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Extension options
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// GetOption resolves a named option value in priority order:
|
|
// 1. Runtime override (via SetOption)
|
|
// 2. Environment variable: KIT_OPT_<NAME> (uppercased, dashes → underscores)
|
|
// 3. Viper config: options.<name>
|
|
// 4. Default value from RegisterOption
|
|
//
|
|
// Returns empty string if the option was never registered.
|
|
// Thread-safe.
|
|
func (r *Runner) GetOption(name string) string {
|
|
// 1. Runtime override.
|
|
r.mu.RLock()
|
|
if v, ok := r.optionOverrides[name]; ok {
|
|
r.mu.RUnlock()
|
|
return v
|
|
}
|
|
r.mu.RUnlock()
|
|
|
|
// 2. Environment variable: KIT_OPT_<NAME>
|
|
envKey := "KIT_OPT_" + strings.ToUpper(strings.ReplaceAll(name, "-", "_"))
|
|
if v := os.Getenv(envKey); v != "" {
|
|
return v
|
|
}
|
|
|
|
// 3. Viper config: options.<name>
|
|
configKey := "options." + name
|
|
if v := viper.GetString(configKey); v != "" {
|
|
return v
|
|
}
|
|
|
|
// 4. Default from registered option defs.
|
|
for i := range r.extensions {
|
|
for _, opt := range r.extensions[i].Options {
|
|
if opt.Name == name {
|
|
return opt.Default
|
|
}
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// SetOption stores a runtime override for a named option. This takes highest
|
|
// priority over env vars, config, and defaults. Thread-safe.
|
|
func (r *Runner) SetOption(name, value string) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if r.optionOverrides == nil {
|
|
r.optionOverrides = make(map[string]string)
|
|
}
|
|
r.optionOverrides[name] = value
|
|
}
|
|
|
|
// RegisteredOptions returns all option definitions from all loaded extensions.
|
|
func (r *Runner) RegisteredOptions() []OptionDef {
|
|
var opts []OptionDef
|
|
for i := range r.extensions {
|
|
opts = append(opts, r.extensions[i].Options...)
|
|
}
|
|
return opts
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// safeCall invokes a handler, recovering from panics.
|
|
func safeCall(handler HandlerFunc, event Event, ctx Context) (result Result, err error) {
|
|
defer func() {
|
|
if rec := recover(); rec != nil {
|
|
err = fmt.Errorf("extension panicked: %v", rec)
|
|
}
|
|
}()
|
|
return handler(event, ctx), nil
|
|
}
|
|
|
|
// isBlocking returns true if the result should short-circuit further handlers.
|
|
func isBlocking(result Result) bool {
|
|
switch r := result.(type) {
|
|
case ToolCallResult:
|
|
return r.Block
|
|
case InputResult:
|
|
return r.Action == "handled"
|
|
case BeforeForkResult:
|
|
return r.Cancel
|
|
case BeforeSessionSwitchResult:
|
|
return r.Cancel
|
|
case BeforeCompactResult:
|
|
return r.Cancel
|
|
}
|
|
return false
|
|
}
|