diff --git a/cmd/root.go b/cmd/root.go index e24cad3e..10189350 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -489,6 +489,26 @@ func statusBarProviderForUI(k *kit.Kit) func() []ui.StatusBarEntryData { } } +// beforeForkProviderForUI returns a callback that emits a BeforeFork event +// and returns (cancelled, reason). Returns nil if extensions are disabled — +// the UI treats nil as "no hook". +func beforeForkProviderForUI(k *kit.Kit) func(string, bool, string) (bool, string) { + if !k.HasExtensions() { + return nil + } + return k.EmitBeforeFork +} + +// beforeSessionSwitchProviderForUI returns a callback that emits a +// BeforeSessionSwitch event and returns (cancelled, reason). Returns nil +// if extensions are disabled — the UI treats nil as "no hook". +func beforeSessionSwitchProviderForUI(k *kit.Kit) func(string) (bool, string) { + if !k.HasExtensions() { + return nil + } + return k.EmitBeforeSessionSwitch +} + func runNormalMode(ctx context.Context) error { // Validate flag combinations if quietFlag && promptFlag == "" { @@ -840,10 +860,12 @@ func runNormalMode(ctx context.Context) error { getEditorInterceptor := editorInterceptorProviderForUI(kitInstance) getUIVisibility := uiVisibilityProviderForUI(kitInstance) getStatusBarEntries := statusBarProviderForUI(kitInstance) + emitBeforeFork := beforeForkProviderForUI(kitInstance) + emitBeforeSessionSwitch := beforeSessionSwitchProviderForUI(kitInstance) // Check if running in non-interactive mode if promptFlag != "" { - return runNonInteractiveModeApp(ctx, appInstance, cli, promptFlag, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries) + return runNonInteractiveModeApp(ctx, appInstance, cli, promptFlag, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch) } // Quiet mode is not allowed in interactive mode @@ -851,7 +873,7 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet flag can only be used with --prompt/-p") } - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries) + return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch) } // runNonInteractiveModeApp executes a single prompt via the app layer and exits, @@ -864,7 +886,7 @@ func runNormalMode(ctx context.Context) error { // // When --no-exit is set, after the prompt completes the interactive BubbleTea // TUI is started so the user can continue the conversation. -func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData) error { +func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string)) error { if jsonOutput { // JSON mode: no intermediate display, structured JSON output. result, err := appInstance.RunOnceResult(ctx, prompt) @@ -902,7 +924,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui // If --no-exit was requested, hand off to the interactive TUI. if noExit { - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries) + return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch) } return nil @@ -996,7 +1018,7 @@ func writeJSONError(err error) { // 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit). // // SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering. -func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData) error { +func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string)) error { // Determine terminal size; fall back gracefully. termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd())) if err != nil || termWidth == 0 { @@ -1005,27 +1027,29 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN } appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{ - CompactMode: viper.GetBool("compact"), - ModelName: modelName, - ProviderName: providerName, - LoadingMessage: loadingMessage, - Width: termWidth, - Height: termHeight, - ServerNames: serverNames, - ToolNames: toolNames, - MCPToolCount: mcpToolCount, - ExtensionToolCount: extensionToolCount, - UsageTracker: usageTracker, - ExtensionCommands: extCommands, - ContextPaths: contextPaths, - SkillItems: skillItems, - GetWidgets: getWidgets, - GetHeader: getHeader, - GetFooter: getFooter, - GetToolRenderer: getToolRenderer, - GetEditorInterceptor: getEditorInterceptor, - GetUIVisibility: getUIVisibility, - GetStatusBarEntries: getStatusBarEntries, + CompactMode: viper.GetBool("compact"), + ModelName: modelName, + ProviderName: providerName, + LoadingMessage: loadingMessage, + Width: termWidth, + Height: termHeight, + ServerNames: serverNames, + ToolNames: toolNames, + MCPToolCount: mcpToolCount, + ExtensionToolCount: extensionToolCount, + UsageTracker: usageTracker, + ExtensionCommands: extCommands, + ContextPaths: contextPaths, + SkillItems: skillItems, + GetWidgets: getWidgets, + GetHeader: getHeader, + GetFooter: getFooter, + GetToolRenderer: getToolRenderer, + GetEditorInterceptor: getEditorInterceptor, + GetUIVisibility: getUIVisibility, + GetStatusBarEntries: getStatusBarEntries, + EmitBeforeFork: emitBeforeFork, + EmitBeforeSessionSwitch: emitBeforeSessionSwitch, }) // Print startup info to stdout before Bubble Tea takes over the screen. diff --git a/examples/extensions/compact-notify.go b/examples/extensions/compact-notify.go new file mode 100644 index 00000000..fdedfa66 --- /dev/null +++ b/examples/extensions/compact-notify.go @@ -0,0 +1,56 @@ +//go:build ignore + +package main + +import ( + "fmt" + + "kit/ext" +) + +// Init registers a before-compact hook that notifies the user when +// compaction is about to happen and optionally blocks automatic compaction. +// +// When automatic compaction is triggered (via --auto-compact), the extension +// asks for user confirmation. Manual /compact commands are always allowed. +// +// This demonstrates the OnBeforeCompact event which allows extensions to +// inspect context usage stats and gate the compaction process. +// +// Usage: kit -e examples/extensions/compact-notify.go --auto-compact +func Init(api ext.API) { + api.OnBeforeCompact(func(e ext.BeforeCompactEvent, ctx ext.Context) *ext.BeforeCompactResult { + pct := int(e.UsagePercent * 100) + summary := fmt.Sprintf("Context: %dk/%dk tokens (%d%%), %d messages", + e.EstimatedTokens/1000, e.ContextLimit/1000, pct, e.MessageCount) + + if e.IsAutomatic { + // Auto-compaction: ask user first. + ctx.PrintBlock(ext.PrintBlockOpts{ + Text: "Auto-compaction triggered.\n" + summary, + BorderColor: "#f9e2af", + Subtitle: "compact-notify", + }) + + result := ctx.PromptConfirm(ext.PromptConfirmConfig{ + Message: "Allow automatic compaction?", + DefaultValue: true, + }) + if result.Cancelled || !result.Value { + return &ext.BeforeCompactResult{ + Cancel: true, + Reason: "Auto-compaction skipped by user.", + } + } + } else { + // Manual /compact: just notify. + ctx.PrintBlock(ext.PrintBlockOpts{ + Text: "Compacting conversation...\n" + summary, + BorderColor: "#89b4fa", + Subtitle: "compact-notify", + }) + } + + return nil // allow compaction + }) +} diff --git a/examples/extensions/confirm-destructive.go b/examples/extensions/confirm-destructive.go new file mode 100644 index 00000000..898deec7 --- /dev/null +++ b/examples/extensions/confirm-destructive.go @@ -0,0 +1,72 @@ +//go:build ignore + +package main + +import ( + "os/exec" + "strings" + + "kit/ext" +) + +// Init registers before-hooks for destructive session operations: +// - Forks: Asks for confirmation before branching to a different tree node. +// - New sessions: Checks for uncommitted git changes and warns before +// starting a new branch if the working tree is dirty. +// +// This demonstrates the OnBeforeFork and OnBeforeSessionSwitch events +// which allow extensions to cancel session lifecycle operations. +// +// Usage: kit -e examples/extensions/confirm-destructive.go --continue +func Init(api ext.API) { + // Gate /new command: warn if there are uncommitted git changes. + api.OnBeforeSessionSwitch(func(e ext.BeforeSessionSwitchEvent, ctx ext.Context) *ext.BeforeSessionSwitchResult { + if !isGitDirty() { + return nil // clean repo, allow switch + } + + result := ctx.PromptConfirm(ext.PromptConfirmConfig{ + Message: "Working tree has uncommitted changes. Start new session anyway?", + }) + if result.Cancelled || !result.Value { + return &ext.BeforeSessionSwitchResult{ + Cancel: true, + Reason: "Session switch cancelled: uncommitted git changes.", + } + } + return nil // user approved + }) + + // Gate fork: ask for confirmation before branching. + api.OnBeforeFork(func(e ext.BeforeForkEvent, ctx ext.Context) *ext.BeforeForkResult { + msg := "Branch to this point in the conversation?" + if e.IsUserMessage && e.UserText != "" { + // Show a preview of the user message being forked to. + preview := e.UserText + if len(preview) > 80 { + preview = preview[:77] + "..." + } + msg = "Fork and edit: " + preview + "\n\nContinue?" + } + + result := ctx.PromptConfirm(ext.PromptConfirmConfig{ + Message: msg, + }) + if result.Cancelled || !result.Value { + return &ext.BeforeForkResult{ + Cancel: true, + Reason: "Fork cancelled by user.", + } + } + return nil // user approved + }) +} + +// isGitDirty returns true if the git working tree has uncommitted changes. +func isGitDirty() bool { + out, err := exec.Command("git", "status", "--porcelain").Output() + if err != nil { + return false // not a git repo or git not available + } + return len(strings.TrimSpace(string(out))) > 0 +} diff --git a/internal/extensions/api.go b/internal/extensions/api.go index 41459920..9980c0b0 100644 --- a/internal/extensions/api.go +++ b/internal/extensions/api.go @@ -604,6 +604,9 @@ type API struct { registerToolRendererFn func(ToolRenderConfig) onModelChange func(func(ModelChangeEvent, Context)) onContextPrepare func(func(ContextPrepareEvent, Context) *ContextPrepareResult) + onBeforeFork func(func(BeforeForkEvent, Context) *BeforeForkResult) + onBeforeSessionSwitch func(func(BeforeSessionSwitchEvent, Context) *BeforeSessionSwitchResult) + onBeforeCompact func(func(BeforeCompactEvent, Context) *BeforeCompactResult) onCustomEvent func(name string, handler func(string)) registerOption func(OptionDef) } @@ -732,6 +735,27 @@ func (a *API) OnCustomEvent(name string, handler func(string)) { a.onCustomEvent(name, handler) } +// OnBeforeFork registers a handler that fires before the session tree is +// branched to a different entry point. Return a non-nil BeforeForkResult +// with Cancel=true to prevent the fork. +func (a *API) OnBeforeFork(handler func(BeforeForkEvent, Context) *BeforeForkResult) { + a.onBeforeFork(handler) +} + +// OnBeforeSessionSwitch registers a handler that fires before the session +// is switched to a new branch (e.g. /new command). Return a non-nil +// BeforeSessionSwitchResult with Cancel=true to prevent the switch. +func (a *API) OnBeforeSessionSwitch(handler func(BeforeSessionSwitchEvent, Context) *BeforeSessionSwitchResult) { + a.onBeforeSessionSwitch(handler) +} + +// OnBeforeCompact registers a handler that fires before context compaction +// runs. Return a non-nil BeforeCompactResult with Cancel=true to prevent +// compaction from proceeding. +func (a *API) OnBeforeCompact(handler func(BeforeCompactEvent, Context) *BeforeCompactResult) { + a.onBeforeCompact(handler) +} + // RegisterToolRenderer registers a custom renderer for a specific tool's // display in the TUI. The renderer controls the header (parameter summary) // and/or body (result display) of the tool's output block. If multiple @@ -1385,3 +1409,82 @@ type ContextPrepareResult struct { } func (ContextPrepareResult) isResult() {} + +// BeforeForkEvent fires before the session tree is branched to a different +// entry point (via the tree selector or /fork command). +type BeforeForkEvent struct { + // TargetID is the session entry ID being branched to. + TargetID string + // IsUserMessage is true if the selected entry is a user message + // (which causes the fork to target the parent entry). + IsUserMessage bool + // UserText is the user message text (non-empty only when IsUserMessage is true). + UserText string +} + +func (e BeforeForkEvent) Type() EventType { return BeforeFork } + +// BeforeForkResult controls whether the fork proceeds. Return Cancel=true +// with an optional Reason to block the fork. +type BeforeForkResult struct { + // Cancel, when true, prevents the fork from proceeding. + Cancel bool + // Reason is a human-readable explanation shown to the user when + // Cancel is true. Empty string uses a default message. + Reason string +} + +func (BeforeForkResult) isResult() {} + +// BeforeSessionSwitchEvent fires before the session is switched to a new +// branch (e.g. /new or /clear commands). +type BeforeSessionSwitchEvent struct { + // Reason describes why the switch is happening: "new" for /new command, + // "clear" for /clear command. + Reason string +} + +func (e BeforeSessionSwitchEvent) Type() EventType { return BeforeSessionSwitch } + +// BeforeSessionSwitchResult controls whether the session switch proceeds. +// Return Cancel=true with an optional Reason to block the switch. +type BeforeSessionSwitchResult struct { + // Cancel, when true, prevents the session switch from proceeding. + Cancel bool + // Reason is a human-readable explanation shown to the user when + // Cancel is true. Empty string uses a default message. + Reason string +} + +func (BeforeSessionSwitchResult) isResult() {} + +// BeforeCompactEvent fires before context compaction runs. Provides +// information about the current context state to help extensions decide +// whether to allow or block compaction. +type BeforeCompactEvent struct { + // EstimatedTokens is the estimated token count of the conversation. + EstimatedTokens int + // ContextLimit is the model's context window size in tokens. + ContextLimit int + // UsagePercent is the fraction of context used (0.0–1.0). + UsagePercent float64 + // MessageCount is the number of messages in the conversation. + MessageCount int + // IsAutomatic is true when compaction was triggered automatically + // (as opposed to manual /compact command). + IsAutomatic bool +} + +func (e BeforeCompactEvent) Type() EventType { return BeforeCompact } + +// BeforeCompactResult controls whether compaction proceeds. Return +// Cancel=true with an optional Reason to block compaction. +type BeforeCompactResult struct { + // Cancel, when true, prevents compaction from proceeding. + Cancel bool + // Reason is a human-readable explanation shown to the user when + // Cancel is true. Empty string uses a default message. + Reason string +} + +func (BeforeCompactResult) isResult() {} diff --git a/internal/extensions/events.go b/internal/extensions/events.go index 717cfde6..5d50dd64 100644 --- a/internal/extensions/events.go +++ b/internal/extensions/events.go @@ -56,6 +56,18 @@ const ( // before the messages are sent to the LLM. Handlers can filter, reorder, // or inject messages into the context window. ContextPrepare EventType = "context_prepare" + + // BeforeFork fires before the session tree is branched to a different + // entry point. Handlers can cancel the fork by returning Cancel=true. + BeforeFork EventType = "before_fork" + + // BeforeSessionSwitch fires before the session is switched to a new + // branch (e.g. /new command). Handlers can cancel by returning Cancel=true. + BeforeSessionSwitch EventType = "before_session_switch" + + // BeforeCompact fires before context compaction runs. Handlers can + // cancel compaction by returning Cancel=true. + BeforeCompact EventType = "before_compact" ) // AllEventTypes returns every supported event type. @@ -66,6 +78,7 @@ func AllEventTypes() []EventType { MessageStart, MessageUpdate, MessageEnd, SessionStart, SessionShutdown, ModelChange, ContextPrepare, + BeforeFork, BeforeSessionSwitch, BeforeCompact, } } diff --git a/internal/extensions/events_test.go b/internal/extensions/events_test.go index e61b9969..f13228ca 100644 --- a/internal/extensions/events_test.go +++ b/internal/extensions/events_test.go @@ -4,8 +4,8 @@ import "testing" func TestAllEventTypes_Count(t *testing.T) { all := AllEventTypes() - if len(all) != 15 { - t.Fatalf("expected 15 event types, got %d", len(all)) + if len(all) != 18 { + t.Fatalf("expected 18 event types, got %d", len(all)) } } @@ -52,6 +52,9 @@ func TestEventType_TypeMethod(t *testing.T) { {SessionShutdownEvent{}, SessionShutdown}, {ModelChangeEvent{NewModel: "a/b"}, ModelChange}, {ContextPrepareEvent{Messages: []ContextMessage{{Index: 0, Role: "user", Content: "hi"}}}, ContextPrepare}, + {BeforeForkEvent{TargetID: "abc"}, BeforeFork}, + {BeforeSessionSwitchEvent{Reason: "new"}, BeforeSessionSwitch}, + {BeforeCompactEvent{EstimatedTokens: 1000}, BeforeCompact}, } for _, tt := range tests { diff --git a/internal/extensions/loader.go b/internal/extensions/loader.go index c0b4d6bc..204cbeaf 100644 --- a/internal/extensions/loader.go +++ b/internal/extensions/loader.go @@ -298,6 +298,33 @@ func loadSingleExtension(path string) (*LoadedExtension, error) { return *r }) }, + onBeforeFork: func(h func(BeforeForkEvent, Context) *BeforeForkResult) { + reg(BeforeFork, func(e Event, c Context) Result { + r := h(e.(BeforeForkEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onBeforeSessionSwitch: func(h func(BeforeSessionSwitchEvent, Context) *BeforeSessionSwitchResult) { + reg(BeforeSessionSwitch, func(e Event, c Context) Result { + r := h(e.(BeforeSessionSwitchEvent), c) + if r == nil { + return nil + } + return *r + }) + }, + onBeforeCompact: func(h func(BeforeCompactEvent, Context) *BeforeCompactResult) { + reg(BeforeCompact, func(e Event, c Context) Result { + r := h(e.(BeforeCompactEvent), c) + if r == nil { + return nil + } + return *r + }) + }, registerToolFn: func(tool ToolDef) { ext.Tools = append(ext.Tools, tool) }, diff --git a/internal/extensions/runner.go b/internal/extensions/runner.go index 1149158e..577d7e1f 100644 --- a/internal/extensions/runner.go +++ b/internal/extensions/runner.go @@ -527,6 +527,12 @@ func isBlocking(result Result) bool { return r.Block case InputResult: return r.Action == "handled" + case BeforeForkResult: + return r.Cancel + case BeforeSessionSwitchResult: + return r.Cancel + case BeforeCompactResult: + return r.Cancel } return false } diff --git a/internal/extensions/symbols.go b/internal/extensions/symbols.go index 6df16427..ee061ff0 100644 --- a/internal/extensions/symbols.go +++ b/internal/extensions/symbols.go @@ -97,6 +97,14 @@ func Symbols() interp.Exports { "ContextPrepareEvent": reflect.ValueOf((*ContextPrepareEvent)(nil)), "ContextPrepareResult": reflect.ValueOf((*ContextPrepareResult)(nil)), + // Session lifecycle types + "BeforeForkEvent": reflect.ValueOf((*BeforeForkEvent)(nil)), + "BeforeForkResult": reflect.ValueOf((*BeforeForkResult)(nil)), + "BeforeSessionSwitchEvent": reflect.ValueOf((*BeforeSessionSwitchEvent)(nil)), + "BeforeSessionSwitchResult": reflect.ValueOf((*BeforeSessionSwitchResult)(nil)), + "BeforeCompactEvent": reflect.ValueOf((*BeforeCompactEvent)(nil)), + "BeforeCompactResult": reflect.ValueOf((*BeforeCompactResult)(nil)), + // Event structs "ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)), "ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)), diff --git a/internal/ui/model.go b/internal/ui/model.go index f03348a5..d27f9b24 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -272,6 +272,17 @@ type AppModelOptions struct { // extension entries alongside the built-in model/usage display. // May be nil if no extensions are loaded. GetStatusBarEntries func() []StatusBarEntryData + + // EmitBeforeFork, if non-nil, is called before branching to a + // different session tree entry. Returns (cancelled, reason) where + // cancelled=true means the fork should be aborted. May be nil if + // no extensions are loaded. + EmitBeforeFork func(targetID string, isUserMsg bool, userText string) (bool, string) + + // EmitBeforeSessionSwitch, if non-nil, is called before switching + // to a new session branch (e.g. /new, /clear). Returns (cancelled, + // reason). May be nil if no extensions are loaded. + EmitBeforeSessionSwitch func(reason string) (bool, string) } // AppModel is the root Bubble Tea model for the interactive TUI. It owns the @@ -385,6 +396,14 @@ type AppModel struct { // getStatusBarEntries returns extension-provided status bar entries. May be nil. getStatusBarEntries func() []StatusBarEntryData + // emitBeforeFork emits a before-fork event to extensions. Returns + // (cancelled, reason). May be nil if no extensions are loaded. + emitBeforeFork func(targetID string, isUserMsg bool, userText string) (bool, string) + + // emitBeforeSessionSwitch emits a before-session-switch event to extensions. + // Returns (cancelled, reason). May be nil if no extensions are loaded. + emitBeforeSessionSwitch func(reason string) (bool, string) + // prompt holds the state of an active interactive prompt overlay. Nil // when no prompt is active. Managed by updatePromptState(). prompt *promptOverlay @@ -500,6 +519,8 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel { m.getEditorInterceptor = opts.GetEditorInterceptor m.getUIVisibility = opts.GetUIVisibility m.getStatusBarEntries = opts.GetStatusBarEntries + m.emitBeforeFork = opts.EmitBeforeFork + m.emitBeforeSessionSwitch = opts.EmitBeforeSessionSwitch // Store context/skills metadata and tool counts for startup display. m.contextPaths = opts.ContextPaths @@ -662,6 +683,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } } } + + // Emit before-fork event — extensions can cancel the operation. + if m.emitBeforeFork != nil { + if cancelled, reason := m.emitBeforeFork(targetID, msg.IsUser, msg.UserText); cancelled { + m.treeSelector = nil + m.state = stateInput + return m, m.printSystemMessage(reason) + } + } + _ = ts.Branch(targetID) m.appCtrl.ClearMessages() @@ -1912,6 +1943,13 @@ func (m *AppModel) handleForkCommand() tea.Cmd { // handleNewCommand starts a fresh session by resetting the tree leaf. func (m *AppModel) handleNewCommand() tea.Cmd { + // Emit before-session-switch event — extensions can cancel. + if m.emitBeforeSessionSwitch != nil { + if cancelled, reason := m.emitBeforeSessionSwitch("new"); cancelled { + return m.printSystemMessage(reason) + } + } + ts := m.appCtrl.GetTreeSession() if ts == nil { // No tree session — just clear messages. diff --git a/pkg/kit/compaction.go b/pkg/kit/compaction.go index 08e26386..8f08ade7 100644 --- a/pkg/kit/compaction.go +++ b/pkg/kit/compaction.go @@ -86,6 +86,12 @@ func (m *Kit) GetContextStats() ContextStats { // After compaction, the tree session is cleared and replaced with the // compacted messages (summary + preserved recent messages). func (m *Kit) Compact(ctx context.Context, opts *CompactionOptions, customInstructions string) (*CompactionResult, error) { + return m.compactInternal(ctx, opts, customInstructions, false) +} + +// compactInternal is the shared compaction implementation. The isAutomatic +// flag distinguishes auto-triggered compaction from manual /compact. +func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, customInstructions string, isAutomatic bool) (*CompactionResult, error) { if opts == nil { if m.compactionOpts != nil { opts = m.compactionOpts @@ -106,6 +112,24 @@ func (m *Kit) Compact(ctx context.Context, opts *CompactionOptions, customInstru return nil, fmt.Errorf("cannot compact: need at least 2 messages") } + // Run before-compact hook — extensions can cancel compaction. + if m.beforeCompact.hasHooks() { + stats := m.GetContextStats() + if hookResult := m.beforeCompact.run(BeforeCompactHook{ + EstimatedTokens: stats.EstimatedTokens, + ContextLimit: stats.ContextLimit, + UsagePercent: stats.UsagePercent, + MessageCount: stats.MessageCount, + IsAutomatic: isAutomatic, + }); hookResult != nil && hookResult.Cancel { + reason := hookResult.Reason + if reason == "" { + reason = "compaction cancelled by extension" + } + return nil, fmt.Errorf("%s", reason) + } + } + model := m.agent.GetModel() result, newMessages, err := compaction.Compact(ctx, model, messages, *opts, customInstructions) if err != nil { diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go index b2534140..6301609f 100644 --- a/pkg/kit/extensions_bridge.go +++ b/pkg/kit/extensions_bridge.go @@ -157,4 +157,25 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) { return &ContextPrepareResult{Messages: rebuilt} }) } + + // --- Compaction hook --- + // Extension BeforeCompact → SDK BeforeCompact hook. + if runner.HasHandlers(extensions.BeforeCompact) { + m.OnBeforeCompact(HookPriorityNormal, func(h BeforeCompactHook) *BeforeCompactResult { + result, _ := runner.Emit(extensions.BeforeCompactEvent{ + EstimatedTokens: h.EstimatedTokens, + ContextLimit: h.ContextLimit, + UsagePercent: h.UsagePercent, + MessageCount: h.MessageCount, + IsAutomatic: h.IsAutomatic, + }) + if r, ok := result.(extensions.BeforeCompactResult); ok && r.Cancel { + return &BeforeCompactResult{ + Cancel: true, + Reason: r.Reason, + } + } + return nil + }) + } } diff --git a/pkg/kit/hooks.go b/pkg/kit/hooks.go index 8c5921fc..55a52c94 100644 --- a/pkg/kit/hooks.go +++ b/pkg/kit/hooks.go @@ -91,6 +91,28 @@ type ContextPrepareResult struct { Messages []fantasy.Message } +// BeforeCompactHook is the input for hooks that fire before compaction runs. +type BeforeCompactHook struct { + // EstimatedTokens is the estimated token count of the conversation. + EstimatedTokens int + // ContextLimit is the model's context window size in tokens. + ContextLimit int + // UsagePercent is the fraction of context used (0.0–1.0). + UsagePercent float64 + // MessageCount is the number of messages in the conversation. + MessageCount int + // IsAutomatic is true when compaction was triggered automatically. + IsAutomatic bool +} + +// BeforeCompactResult controls whether compaction proceeds. +type BeforeCompactResult struct { + // Cancel, when true, prevents compaction from proceeding. + Cancel bool + // Reason is a human-readable explanation when Cancel is true. + Reason string +} + // --------------------------------------------------------------------------- // Generic hook registry with priority ordering // --------------------------------------------------------------------------- @@ -205,6 +227,14 @@ func (m *Kit) OnContextPrepare(p HookPriority, h func(ContextPrepareHook) *Conte return m.contextPrepare.register(p, h) } +// OnBeforeCompact registers a hook that fires before context compaction runs. +// Return a non-nil BeforeCompactResult with Cancel=true to prevent compaction. +// Hooks execute in priority order; the first non-nil result wins. +// Returns an unregister function. +func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeCompactResult) func() { + return m.beforeCompact.register(p, h) +} + // --------------------------------------------------------------------------- // Tool wrapping via hooks // --------------------------------------------------------------------------- diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index 2a858f4c..4298f779 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -52,6 +52,7 @@ type Kit struct { beforeTurn *hookRegistry[BeforeTurnHook, BeforeTurnResult] afterTurn *hookRegistry[AfterTurnHook, AfterTurnResult] contextPrepare *hookRegistry[ContextPrepareHook, ContextPrepareResult] + beforeCompact *hookRegistry[BeforeCompactHook, BeforeCompactResult] // lastInputTokens stores the API-reported input token count from the // most recent turn. Used by GetContextStats() to return accurate usage @@ -643,6 +644,48 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ }, nil } +// EmitBeforeFork emits a BeforeFork event to extensions and returns +// whether the fork was cancelled and the reason. No-op if extensions are +// disabled (returns false, ""). +func (m *Kit) EmitBeforeFork(targetID string, isUserMsg bool, userText string) (cancelled bool, reason string) { + if m.extRunner == nil || !m.extRunner.HasHandlers(extensions.BeforeFork) { + return false, "" + } + result, _ := m.extRunner.Emit(extensions.BeforeForkEvent{ + TargetID: targetID, + IsUserMessage: isUserMsg, + UserText: userText, + }) + if r, ok := result.(extensions.BeforeForkResult); ok && r.Cancel { + reason := r.Reason + if reason == "" { + reason = "Fork cancelled by extension." + } + return true, reason + } + return false, "" +} + +// EmitBeforeSessionSwitch emits a BeforeSessionSwitch event to extensions +// and returns whether the switch was cancelled and the reason. No-op if +// extensions are disabled (returns false, ""). +func (m *Kit) EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string) { + if m.extRunner == nil || !m.extRunner.HasHandlers(extensions.BeforeSessionSwitch) { + return false, "" + } + result, _ := m.extRunner.Emit(extensions.BeforeSessionSwitchEvent{ + Reason: switchReason, + }) + if r, ok := result.(extensions.BeforeSessionSwitchResult); ok && r.Cancel { + reason := r.Reason + if reason == "" { + reason = "Session switch cancelled by extension." + } + return true, reason + } + return false, "" +} + // HasExtensions returns true if the extension runner is configured and active. func (m *Kit) HasExtensions() bool { return m.extRunner != nil @@ -825,6 +868,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { beforeTurn := newHookRegistry[BeforeTurnHook, BeforeTurnResult]() afterTurn := newHookRegistry[AfterTurnHook, AfterTurnResult]() contextPrepare := newHookRegistry[ContextPrepareHook, ContextPrepareResult]() + beforeCompact := newHookRegistry[BeforeCompactHook, BeforeCompactResult]() // Build agent setup options, pulling CLI-specific fields when available. setupOpts := kitsetup.AgentSetupOptions{ @@ -869,6 +913,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) { beforeTurn: beforeTurn, afterTurn: afterTurn, contextPrepare: contextPrepare, + beforeCompact: beforeCompact, } // Bridge extension events to SDK hooks. @@ -1144,7 +1189,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr // Auto-compact if enabled and conversation is near the context limit. if m.autoCompact && m.ShouldCompact() { - _, _ = m.Compact(ctx, m.compactionOpts, "") // best-effort + _, _ = m.compactInternal(ctx, m.compactionOpts, "", true) // best-effort, automatic } // Build context from the tree so only the current branch is sent.