diff --git a/cmd/root.go b/cmd/root.go index ea3cb414..3126aa7c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1262,9 +1262,57 @@ func runNormalMode(ctx context.Context) error { } } + // Bundle all the shared dependencies into a single struct that both + // run-mode entry points consume. This keeps the dispatch site and the + // function signatures readable. + deps := runModeDeps{ + appInstance: appInstance, + cli: cli, + modelName: modelName, + providerName: parsedProvider, + loadingMessage: kitInstance.GetLoadingMessage(), + serverNames: serverNames, + toolNames: toolNames, + mcpToolCount: mcpToolCount, + extensionToolCount: extensionToolCount, + usageTracker: usageTracker, + extCommands: extCommands, + promptTemplates: promptTemplates, + contextPaths: contextPaths, + skillItems: skillItems, + extensionItems: extensionItems, + getPromptTemplates: getPromptTemplates, + getSkillItems: getSkillItems, + getExtensionItems: getExtensionItems, + getToolNames: getToolNames, + getMCPToolCount: getMCPToolCount, + mcpPrompts: mcpPrompts, + getMCPPrompts: getMCPPrompts, + expandMCPPrompt: expandMCPPrompt, + getWidgets: getWidgets, + getHeader: getHeader, + getFooter: getFooter, + getToolRenderer: getToolRenderer, + getEditorInterceptor: getEditorInterceptor, + getUIVisibility: getUIVisibility, + getStatusBarEntries: getStatusBarEntries, + emitBeforeFork: emitBeforeFork, + emitBeforeSessionSwitch: emitBeforeSessionSwitch, + getGlobalShortcuts: getGlobalShortcuts, + getExtensionCommands: getExtensionCommands, + setModel: setModelForUI, + emitModelChange: emitModelChangeForUI, + isReasoningModel: kitInstance.IsReasoningModel(), + thinkingLevel: kitInstance.GetThinkingLevel(), + setThinkingLevel: setThinkingLevelForUI, + switchSession: switchSessionForUI, + reloadExtensions: reloadExtensionsForUI, + startupExtensionMessages: startupExtensionMessages, + } + // Check if running in non-interactive mode if positionalPrompt != "" { - return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI) + return runNonInteractiveModeApp(ctx, deps, positionalPrompt, quietFlag, jsonFlag, noExitFlag) } // Quiet mode is not allowed in interactive mode @@ -1272,7 +1320,7 @@ func runNormalMode(ctx context.Context) error { return fmt.Errorf("--quiet requires a prompt") } - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages) + return runInteractiveModeBubbleTea(ctx, deps) } // runNonInteractiveModeApp executes a single prompt via the app layer and exits, @@ -1285,7 +1333,10 @@ func runNormalMode(ctx context.Context) error { // // When --no-exit is set, after the prompt completes the interactive BubbleTea // TUI is started so the user can continue the conversation. -func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error { +func runNonInteractiveModeApp(ctx context.Context, deps runModeDeps, prompt string, quiet, jsonOutput, noExit bool) error { + appInstance := deps.appInstance + cli := deps.cli + modelName := deps.modelName // Expand @file references in the prompt before sending to the agent. // Text files are XML-inlined; binary files are extracted as multimodal parts. var fileParts []kit.LLMFilePart @@ -1346,12 +1397,67 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui // If --no-exit was requested, hand off to the interactive TUI. if noExit { - return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, extensionItems, getPromptTemplates, getSkillItems, getExtensionItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil) + // Drop the cli (interactive mode doesn't use it) and clear the + // interactive-only fields explicitly; deps carries everything else. + interactive := deps + interactive.cli = nil + interactive.startupExtensionMessages = nil + return runInteractiveModeBubbleTea(ctx, interactive) } return nil } +// runModeDeps bundles the shared dependencies that runNormalMode wires up +// once and threads to both runNonInteractiveModeApp and +// runInteractiveModeBubbleTea. Grouping them into a single struct keeps the +// call sites and signatures readable and makes it trivial to add a new +// provider callback without touching every call chain. +type runModeDeps struct { + appInstance *app.App + cli *ui.CLI // non-interactive only + modelName string + providerName string + loadingMessage string + serverNames []string + toolNames []string + mcpToolCount int + extensionToolCount int + usageTracker *ui.UsageTracker + extCommands []commands.ExtensionCommand + promptTemplates []*prompts.PromptTemplate + contextPaths []string + skillItems []ui.SkillItem + extensionItems []ui.ExtensionItem + getPromptTemplates func() []*prompts.PromptTemplate + getSkillItems func() []ui.SkillItem + getExtensionItems func() []ui.ExtensionItem + getToolNames func() []string + getMCPToolCount func() int + mcpPrompts []ui.MCPPromptInfo + getMCPPrompts func() []ui.MCPPromptInfo + expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error) + getWidgets func(string) []ui.WidgetData + getHeader func() *ui.WidgetData + getFooter func() *ui.WidgetData + getToolRenderer func(string) *ui.ToolRendererData + getEditorInterceptor func() *ui.EditorInterceptor + getUIVisibility func() *ui.UIVisibility + getStatusBarEntries func() []ui.StatusBarEntryData + emitBeforeFork func(string, bool, string) (bool, string) + emitBeforeSessionSwitch func(string) (bool, string) + getGlobalShortcuts func() map[string]func() + getExtensionCommands func() []commands.ExtensionCommand + setModel func(string) error + emitModelChange func(string, string, string) + isReasoningModel bool + thinkingLevel string + setThinkingLevel func(string) error + switchSession func(string) error + reloadExtensions func() error + startupExtensionMessages []string // interactive only +} + // --------------------------------------------------------------------------- // JSON output helpers (--json mode) // --------------------------------------------------------------------------- @@ -1444,7 +1550,8 @@ func writeJSONError(err error) { // 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit). // // SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering. -func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, extensionItems []ui.ExtensionItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getExtensionItems func() []ui.ExtensionItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error { +func runInteractiveModeBubbleTea(_ context.Context, deps runModeDeps) error { + appInstance := deps.appInstance // Redirect all log output (stdlib and charm) to a file so that log // messages don't write to stderr and corrupt the TUI. Bubble Tea // captures stdout for rendering; any stray stderr output from @@ -1467,49 +1574,49 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN cwd, _ := os.Getwd() appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{ - ModelName: modelName, - ProviderName: providerName, - LoadingMessage: loadingMessage, + ModelName: deps.modelName, + ProviderName: deps.providerName, + LoadingMessage: deps.loadingMessage, Cwd: cwd, Width: termWidth, Height: termHeight, - ServerNames: serverNames, - ToolNames: toolNames, - GetToolNames: getToolNames, - GetMCPToolCount: getMCPToolCount, - MCPToolCount: mcpToolCount, - ExtensionToolCount: extensionToolCount, - UsageTracker: usageTracker, - ExtensionCommands: extCommands, - PromptTemplates: promptTemplates, - GetPromptTemplates: getPromptTemplates, - MCPPrompts: mcpPrompts, - GetMCPPrompts: getMCPPrompts, - ExpandMCPPrompt: expandMCPPrompt, - ContextPaths: contextPaths, - SkillItems: skillItems, - GetSkillItems: getSkillItems, - ExtensionItems: extensionItems, - GetExtensionItems: getExtensionItems, - StartupExtensionMessages: startupExtensionMessages, - GetWidgets: getWidgets, - GetHeader: getHeader, - GetFooter: getFooter, - GetToolRenderer: getToolRenderer, - GetEditorInterceptor: getEditorInterceptor, - GetUIVisibility: getUIVisibility, - GetStatusBarEntries: getStatusBarEntries, - EmitBeforeFork: emitBeforeFork, - EmitBeforeSessionSwitch: emitBeforeSessionSwitch, - GetGlobalShortcuts: getGlobalShortcuts, - GetExtensionCommands: getExtensionCommands, - SetModel: setModel, - EmitModelChange: emitModelChange, - ThinkingLevel: thinkingLevel, - IsReasoningModel: isReasoningModel, - SetThinkingLevel: setThinkingLevel, - SwitchSession: switchSession, - ReloadExtensions: reloadExtensions, + ServerNames: deps.serverNames, + ToolNames: deps.toolNames, + GetToolNames: deps.getToolNames, + GetMCPToolCount: deps.getMCPToolCount, + MCPToolCount: deps.mcpToolCount, + ExtensionToolCount: deps.extensionToolCount, + UsageTracker: deps.usageTracker, + ExtensionCommands: deps.extCommands, + PromptTemplates: deps.promptTemplates, + GetPromptTemplates: deps.getPromptTemplates, + MCPPrompts: deps.mcpPrompts, + GetMCPPrompts: deps.getMCPPrompts, + ExpandMCPPrompt: deps.expandMCPPrompt, + ContextPaths: deps.contextPaths, + SkillItems: deps.skillItems, + GetSkillItems: deps.getSkillItems, + ExtensionItems: deps.extensionItems, + GetExtensionItems: deps.getExtensionItems, + StartupExtensionMessages: deps.startupExtensionMessages, + GetWidgets: deps.getWidgets, + GetHeader: deps.getHeader, + GetFooter: deps.getFooter, + GetToolRenderer: deps.getToolRenderer, + GetEditorInterceptor: deps.getEditorInterceptor, + GetUIVisibility: deps.getUIVisibility, + GetStatusBarEntries: deps.getStatusBarEntries, + EmitBeforeFork: deps.emitBeforeFork, + EmitBeforeSessionSwitch: deps.emitBeforeSessionSwitch, + GetGlobalShortcuts: deps.getGlobalShortcuts, + GetExtensionCommands: deps.getExtensionCommands, + SetModel: deps.setModel, + EmitModelChange: deps.emitModelChange, + ThinkingLevel: deps.thinkingLevel, + IsReasoningModel: deps.isReasoningModel, + SetThinkingLevel: deps.setThinkingLevel, + SwitchSession: deps.switchSession, + ReloadExtensions: deps.reloadExtensions, ShowSessionPicker: resumeFlag, GetMCPResources: mcpGetResources, MCPResourceReader: mcpResourceReader, diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 3d2d7eab..2c80d718 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -169,9 +169,9 @@ type RetryHandler func(attempt int, err error) type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message // GenerateCallbacks consolidates all callback functions for -// GenerateWithLoopAndStreaming into a single struct. This replaces the previous -// 16+ positional callback parameters, making it easier to add new callbacks -// without breaking existing callers (new fields default to nil). +// GenerateWithCallbacks into a single struct, replacing what was previously +// 16+ positional callback parameters. New fields default to nil, so adding +// new callbacks does not break existing callers. type GenerateCallbacks struct { OnToolCall ToolCallHandler OnToolExecution ToolExecutionHandler @@ -522,44 +522,6 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message }) } -// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks. -// The agent handles the tool call loop internally. -// -// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks -// struct and is easier to extend with new callbacks. -func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message, - onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler, - onResponse ResponseHandler, onToolCallContent ToolCallContentHandler, - onStreamingResponse StreamingResponseHandler, - onReasoningDelta ReasoningDeltaHandler, - onReasoningComplete ReasoningCompleteHandler, - onToolOutput ToolOutputHandler, - onStepMessages StepMessagesHandler, - onStepUsage StepUsageHandler, - onPasswordPrompt PasswordPromptHandler, - onToolCallStart ToolCallStartHandler, - onToolCallDelta ToolCallDeltaHandler, - onToolCallEnd ToolCallEndHandler, -) (*GenerateWithLoopResult, error) { - return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{ - OnToolCall: onToolCall, - OnToolExecution: onToolExecution, - OnToolResult: onToolResult, - OnResponse: onResponse, - OnToolCallContent: onToolCallContent, - OnStreamingResponse: onStreamingResponse, - OnReasoningDelta: onReasoningDelta, - OnReasoningComplete: onReasoningComplete, - OnToolOutput: onToolOutput, - OnStepMessages: onStepMessages, - OnStepUsage: onStepUsage, - OnPasswordPrompt: onPasswordPrompt, - OnToolCallStart: onToolCallStart, - OnToolCallDelta: onToolCallDelta, - OnToolCallEnd: onToolCallEnd, - }) -} - // GenerateWithCallbacks processes messages using the agent with streaming and callbacks. // The agent handles the tool call loop internally. We map the rich callback system // to kit's existing callback interface for UI integration. diff --git a/internal/core/bash.go b/internal/core/bash.go index 0e641f93..a493f06c 100644 --- a/internal/core/bash.go +++ b/internal/core/bash.go @@ -249,34 +249,37 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa return executeBashBuffered(cmdCtx, call, cmd, sudoPassword) } -// executeBashBuffered collects all output before returning (original behavior). -// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly -// close them when grandchild processes hold pipe handles open after the -// direct child exits. -func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) { +// setupBashPipes opens stdout/stderr pipes (plus an optional sudo stdin), +// starts the command, and asynchronously writes the sudo password if any. +// Returns the readers ready for the caller to consume. If setup fails, +// errResp is non-nil and the readers must not be used; the caller should +// return the response directly. +func setupBashPipes(cmd *exec.Cmd, sudoPassword string) (stdout, stderr io.Reader, errResp *fantasy.ToolResponse) { stdoutPipe, err := cmd.StdoutPipe() if err != nil { - return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil + r := fantasy.NewTextErrorResponse("failed to create stdout pipe") + return nil, nil, &r } stderrPipe, err := cmd.StderrPipe() if err != nil { - return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil + r := fantasy.NewTextErrorResponse("failed to create stderr pipe") + return nil, nil, &r } - // If we have a sudo password, create a stdin pipe and write the password var stdinPipe io.WriteCloser if sudoPassword != "" { stdinPipe, err = cmd.StdinPipe() if err != nil { - return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil + r := fantasy.NewTextErrorResponse("failed to create stdin pipe") + return nil, nil, &r } } if err := cmd.Start(); err != nil { - return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil + r := fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)) + return nil, nil, &r } - // Write password to stdin if needed, then close stdin if sudoPassword != "" && stdinPipe != nil { go func() { defer func() { _ = stdinPipe.Close() }() @@ -284,19 +287,49 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe }() } + return stdoutPipe, stderrPipe, nil +} + +// interpretBashExit decodes cmd.Wait()'s error into an exit code, mapping +// context-deadline-exceeded to a friendly "command timed out" response. +// errResp is non-nil only when the caller should short-circuit and return +// it directly (e.g. timeout). +func interpretBashExit(waitErr error, cmdCtx context.Context) (exitCode int, errResp *fantasy.ToolResponse) { + if waitErr == nil { + return 0, nil + } + if exitErr, ok := waitErr.(*exec.ExitError); ok { + return exitErr.ExitCode(), nil + } + if cmdCtx.Err() == context.DeadlineExceeded { + r := fantasy.NewTextErrorResponse("command timed out") + return 0, &r + } + return 0, nil +} + +// executeBashBuffered collects all output before returning (original behavior). +// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly +// close them when grandchild processes hold pipe handles open after the +// direct child exits. +func executeBashBuffered(cmdCtx context.Context, _ fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) { + stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword) + if errResp != nil { + return *errResp, nil + } + // Read pipes concurrently var wg sync.WaitGroup var stdout, stderr strings.Builder - var stdoutErr, stderrErr error wg.Add(2) go func() { defer wg.Done() - _, stdoutErr = io.Copy(&stdout, stdoutPipe) + _, _ = io.Copy(&stdout, stdoutPipe) }() go func() { defer wg.Done() - _, stderrErr = io.Copy(&stderr, stderrPipe) + _, _ = io.Copy(&stderr, stderrPipe) }() // Wait for the process to exit first. cmd.WaitDelay ensures that if @@ -307,18 +340,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe // Wait for pipe readers to finish draining. wg.Wait() - // Ignore pipe read errors caused by WaitDelay force-closing — - // we still have whatever was read before the close. - _ = stdoutErr - _ = stderrErr - - exitCode := 0 - if waitErr != nil { - if exitErr, ok := waitErr.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else if cmdCtx.Err() == context.DeadlineExceeded { - return fantasy.NewTextErrorResponse("command timed out"), nil - } + exitCode, errResp := interpretBashExit(waitErr, cmdCtx) + if errResp != nil { + return *errResp, nil } return buildBashResponse(stdout.String(), stderr.String(), exitCode) @@ -326,35 +350,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe // executeBashStreaming streams output as it arrives via the callback. func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) { - stdoutPipe, err := cmd.StdoutPipe() - if err != nil { - return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil - } - stderrPipe, err := cmd.StderrPipe() - if err != nil { - return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil - } - - // If we have a sudo password, create a stdin pipe - var stdinPipe io.WriteCloser - if sudoPassword != "" { - stdinPipe, err = cmd.StdinPipe() - if err != nil { - return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil - } - } - - // Start command execution - if err := cmd.Start(); err != nil { - return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil - } - - // Write password to stdin if needed, then close stdin - if sudoPassword != "" && stdinPipe != nil { - go func() { - defer func() { _ = stdinPipe.Close() }() - _, _ = io.WriteString(stdinPipe, sudoPassword+"\n") - }() + stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword) + if errResp != nil { + return *errResp, nil } // Stream stdout and stderr concurrently @@ -391,20 +389,16 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex // Wait for the process to exit. cmd.WaitDelay ensures that if pipes // remain open (held by grandchild processes), they'll be forcibly closed // after the grace period, which unblocks the scanners above. - err = cmd.Wait() + waitErr := cmd.Wait() // Wait for the pipe readers to finish draining. This will complete // quickly since cmd.Wait() (with WaitDelay) has already ensured // the pipes are closed. wg.Wait() - exitCode := 0 - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } else if cmdCtx.Err() == context.DeadlineExceeded { - return fantasy.NewTextErrorResponse("command timed out"), nil - } + exitCode, errResp := interpretBashExit(waitErr, cmdCtx) + if errResp != nil { + return *errResp, nil } return buildBashResponse(strings.Join(stdoutChunks, "\n"), strings.Join(stderrChunks, "\n"), exitCode) diff --git a/internal/core/edit.go b/internal/core/edit.go index ec3f854b..dd624543 100644 --- a/internal/core/edit.go +++ b/internal/core/edit.go @@ -83,6 +83,9 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool { } func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) { + if err := ctx.Err(); err != nil { + return fantasy.ToolResponse{}, err + } var args editArgs if err := parseArgs(call.Input, &args); err != nil { return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil diff --git a/internal/core/ls.go b/internal/core/ls.go index 196c62d6..31c1fa2f 100644 --- a/internal/core/ls.go +++ b/internal/core/ls.go @@ -42,6 +42,9 @@ func NewLsTool(opts ...ToolOption) fantasy.AgentTool { } func executeLs(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) { + if err := ctx.Err(); err != nil { + return fantasy.ToolResponse{}, err + } var args lsArgs _ = parseArgs(call.Input, &args) // optional args diff --git a/internal/core/read.go b/internal/core/read.go index a2e9665c..27c5b23c 100644 --- a/internal/core/read.go +++ b/internal/core/read.go @@ -47,6 +47,9 @@ func NewReadTool(opts ...ToolOption) fantasy.AgentTool { } func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) { + if err := ctx.Err(); err != nil { + return fantasy.ToolResponse{}, err + } var args readArgs if err := parseArgs(call.Input, &args); err != nil { return fantasy.NewTextErrorResponse("path parameter is required"), nil diff --git a/internal/core/write.go b/internal/core/write.go index 9b20e4a7..684679bc 100644 --- a/internal/core/write.go +++ b/internal/core/write.go @@ -41,6 +41,9 @@ func NewWriteTool(opts ...ToolOption) fantasy.AgentTool { } func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) { + if err := ctx.Err(); err != nil { + return fantasy.ToolResponse{}, err + } var args writeArgs if err := parseArgs(call.Input, &args); err != nil { return fantasy.NewTextErrorResponse("path and content parameters are required"), nil diff --git a/internal/extensions/watcher.go b/internal/extensions/watcher.go index 7943e27d..9d8664a8 100644 --- a/internal/extensions/watcher.go +++ b/internal/extensions/watcher.go @@ -1,143 +1,32 @@ package extensions import ( - "context" - "fmt" - "log" "os" "path/filepath" - "strings" - "sync" - "time" - "github.com/fsnotify/fsnotify" + "github.com/mark3labs/kit/internal/watcher" ) -// Watcher monitors extension directories for file changes and triggers -// a reload callback when .go files are created, modified, or removed. -// It uses fsnotify for kernel-level file notifications (inotify on Linux, -// kqueue on macOS) with debouncing to coalesce rapid editor writes. -type Watcher struct { - watcher *fsnotify.Watcher - onReload func() - debounce time.Duration - cancel context.CancelFunc - done chan struct{} - mu sync.Mutex -} +// Watcher monitors extension directories for .go file changes and triggers +// a reload callback when changes are detected. It is implemented in terms +// of the general-purpose internal/watcher.ContentWatcher. +// +// Type-aliasing here lets existing call sites (cmd/root.go and the +// watcher_test.go suite) keep using `extensions.NewWatcher` / `*Watcher` +// without knowing about the underlying implementation. +type Watcher = watcher.ContentWatcher // NewWatcher creates a file watcher that monitors the given directories // for .go file changes. When a change is detected (after debouncing), // onReload is called. The watcher must be started with Start() and // stopped with Close(). func NewWatcher(dirs []string, onReload func()) (*Watcher, error) { - fsw, err := fsnotify.NewWatcher() - if err != nil { - return nil, fmt.Errorf("creating file watcher: %w", err) - } - - for _, dir := range dirs { - // Watch the directory itself. - if err := fsw.Add(dir); err != nil { - log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err) - continue - } - - // Also watch immediate subdirectories (for */main.go pattern). - entries, err := os.ReadDir(dir) - if err != nil { - continue - } - for _, entry := range entries { - if entry.IsDir() { - subdir := filepath.Join(dir, entry.Name()) - if err := fsw.Add(subdir); err != nil { - log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err) - } - } - } - } - - return &Watcher{ - watcher: fsw, - onReload: onReload, - debounce: 300 * time.Millisecond, - done: make(chan struct{}), - }, nil -} - -// Start begins watching for file changes. It blocks until the context -// is cancelled or Close() is called. Typically called in a goroutine. -func (w *Watcher) Start(ctx context.Context) { - w.mu.Lock() - ctx, w.cancel = context.WithCancel(ctx) - w.mu.Unlock() - - defer close(w.done) - - var timer *time.Timer - var timerC <-chan time.Time - - for { - select { - case <-ctx.Done(): - if timer != nil { - timer.Stop() - } - return - - case event, ok := <-w.watcher.Events: - if !ok { - return - } - - // Only care about .go files. - if !strings.HasSuffix(event.Name, ".go") { - continue - } - - // React to write, create, remove, rename events. - if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 { - continue - } - - log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op) - - // Debounce: reset timer on each event. - if timer != nil { - timer.Stop() - } - timer = time.NewTimer(w.debounce) - timerC = timer.C - - case <-timerC: - timerC = nil - timer = nil - log.Printf("DEBUG watcher: reloading extensions") - w.onReload() - - case err, ok := <-w.watcher.Errors: - if !ok { - return - } - log.Printf("WARN watcher: error: %v", err) - } - } -} - -// Close stops the watcher and releases resources. -func (w *Watcher) Close() error { - w.mu.Lock() - cancel := w.cancel - w.mu.Unlock() - - if cancel != nil { - cancel() - } - - // Wait for the event loop to finish. - <-w.done - return w.watcher.Close() + return watcher.New(watcher.Options{ + Dirs: dirs, + Extensions: []string{".go"}, + OnReload: onReload, + Label: "extensions", + }) } // WatchedDirs returns the directories to watch for extension changes. @@ -146,47 +35,25 @@ func (w *Watcher) Close() error { // point to directories are also included; explicit file paths cause // their parent directory to be watched instead. func WatchedDirs(extraPaths []string) []string { - var dirs []string - seen := make(map[string]bool) - - add := func(dir string) { - abs, err := filepath.Abs(dir) - if err != nil { - return - } - if seen[abs] { - return - } - - // Verify the directory exists. - info, err := os.Stat(abs) - if err != nil || !info.IsDir() { - return - } - - seen[abs] = true - dirs = append(dirs, abs) + standard := []string{ + globalExtensionsDir(), + filepath.Join(".kit", "extensions"), } - // Global extensions dir. - add(globalExtensionsDir()) - - // Project-local extensions dir. - add(filepath.Join(".kit", "extensions")) - - // Explicit paths that are directories. + // Filter explicit paths into directories (passed through) and files + // (parent dir watched) for CollectDirs to dedupe. + var extras []string for _, p := range extraPaths { info, err := os.Stat(p) if err != nil { continue } if info.IsDir() { - add(p) + extras = append(extras, p) } else { - // For explicit files, watch the parent directory. - add(filepath.Dir(p)) + extras = append(extras, filepath.Dir(p)) } } - return dirs + return watcher.CollectDirs(standard, extras) } diff --git a/internal/models/providers.go b/internal/models/providers.go index 415a7a41..f1d2ef52 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -398,6 +398,24 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo } } +// resolveAutoRouteAPIKey looks up the API key for an auto-routed provider, +// returning a uniform error message when none can be resolved. +func resolveAutoRouteAPIKey(config *ProviderConfig, info *ProviderInfo) (string, error) { + apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env) + if apiKey == "" { + return "", fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s", + info.Name, strings.Join(info.Env, " / ")) + } + return apiKey, nil +} + +// wrapProviderErr produces the uniform "failed to create X provider/model: %w" +// error wrap used by every createXxxProvider path. kind is typically +// "provider" or "model". +func wrapProviderErr(name, kind string, err error) error { + return fmt.Errorf("failed to create %s %s: %w", name, kind, err) +} + // createAutoRoutedOpenAICompatProvider creates an openaicompat provider using // the api URL and env vars from models.dev. func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) { @@ -409,10 +427,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC return nil, fmt.Errorf("provider %s requires --provider-url (no API URL in database)", info.ID) } - apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env) - if apiKey == "" { - return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s", - info.Name, strings.Join(info.Env, " / ")) + apiKey, err := resolveAutoRouteAPIKey(config, info) + if err != nil { + return nil, err } var opts []openaicompat.Option @@ -426,12 +443,12 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC p, err := openaicompat.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "provider", err) } model, err := p.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "model", err) } return &ProviderResult{Model: model}, nil @@ -442,10 +459,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) { clearConflictingAnthropicSamplingParams(config) - apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env) - if apiKey == "" { - return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s", - info.Name, strings.Join(info.Env, " / ")) + apiKey, err := resolveAutoRouteAPIKey(config, info) + if err != nil { + return nil, err } var opts []anthropic.Option @@ -464,12 +480,12 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf p, err := anthropic.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "provider", err) } model, err := p.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "model", err) } return &ProviderResult{Model: model}, nil @@ -478,10 +494,9 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf // createAutoRoutedOpenAIProvider creates an openai provider for // third-party providers with openai-compatible APIs. func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) { - apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env) - if apiKey == "" { - return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s", - info.Name, strings.Join(info.Env, " / ")) + apiKey, err := resolveAutoRouteAPIKey(config, info) + if err != nil { + return nil, err } var opts []openai.Option @@ -498,12 +513,12 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, p, err := openai.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "provider", err) } model, err := p.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "model", err) } providerOpts := buildOpenAIProviderOptions(config, modelName) @@ -522,10 +537,9 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, // path that the proxy rejects. In that case we install a transport that // strips the injected segment so the proxy's own version is used. func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) { - apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env) - if apiKey == "" { - return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s", - info.Name, strings.Join(info.Env, " / ")) + apiKey, err := resolveAutoRouteAPIKey(config, info) + if err != nil { + return nil, err } opts := []google.Option{ @@ -550,12 +564,12 @@ func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, p, err := google.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "provider", err) } model, err := p.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err) + return nil, wrapProviderErr(info.Name, "model", err) } return &ProviderResult{Model: model}, nil @@ -859,12 +873,12 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN provider, err := anthropic.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Anthropic provider: %w", err) + return nil, wrapProviderErr("Anthropic", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Anthropic model: %w", err) + return nil, wrapProviderErr("Anthropic", "model", err) } // Build provider options for extended thinking (reasoning budget). @@ -901,12 +915,12 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig, provider, err := anthropic.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err) + return nil, wrapProviderErr("Vertex Anthropic", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err) + return nil, wrapProviderErr("Vertex Anthropic", "model", err) } return &ProviderResult{Model: model}, nil @@ -974,12 +988,12 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName provider, err := openai.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create OpenAI provider: %w", err) + return nil, wrapProviderErr("OpenAI", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create OpenAI model: %w", err) + return nil, wrapProviderErr("OpenAI", "model", err) } // Build provider options for OpenAI Responses API reasoning models. @@ -1015,12 +1029,12 @@ func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, mode provider, err := openai.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err) + return nil, wrapProviderErr("OpenAI Codex", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err) + return nil, wrapProviderErr("OpenAI Codex", "model", err) } providerOpts := buildCodexProviderOptions(config, modelName) @@ -1133,12 +1147,12 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName provider, err := google.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Google provider: %w", err) + return nil, wrapProviderErr("Google", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Google model: %w", err) + return nil, wrapProviderErr("Google", "model", err) } return &ProviderResult{Model: model}, nil @@ -1171,12 +1185,12 @@ func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName provider, err := azure.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err) + return nil, wrapProviderErr("Azure OpenAI", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err) + return nil, wrapProviderErr("Azure OpenAI", "model", err) } return &ProviderResult{Model: model}, nil @@ -1196,12 +1210,12 @@ func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, model provider, err := openrouter.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err) + return nil, wrapProviderErr("OpenRouter", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create OpenRouter model: %w", err) + return nil, wrapProviderErr("OpenRouter", "model", err) } return &ProviderResult{Model: model}, nil @@ -1213,12 +1227,12 @@ func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelNam // Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.) provider, err := bedrock.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Bedrock provider: %w", err) + return nil, wrapProviderErr("Bedrock", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Bedrock model: %w", err) + return nil, wrapProviderErr("Bedrock", "model", err) } return &ProviderResult{Model: model}, nil @@ -1242,12 +1256,12 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName provider, err := vercel.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Vercel provider: %w", err) + return nil, wrapProviderErr("Vercel", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Vercel model: %w", err) + return nil, wrapProviderErr("Vercel", "model", err) } return &ProviderResult{Model: model}, nil @@ -1300,12 +1314,12 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName p, err := openai.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create custom provider: %w", err) + return nil, wrapProviderErr("custom", "provider", err) } model, err := p.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create custom model: %w", err) + return nil, wrapProviderErr("custom", "model", err) } return &ProviderResult{Model: model}, nil @@ -1349,12 +1363,12 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName provider, err := openaicompat.New(opts...) if err != nil { - return nil, fmt.Errorf("failed to create Ollama provider: %w", err) + return nil, wrapProviderErr("Ollama", "provider", err) } model, err := provider.LanguageModel(ctx, modelName) if err != nil { - return nil, fmt.Errorf("failed to create Ollama model: %w", err) + return nil, wrapProviderErr("Ollama", "model", err) } return &ProviderResult{ diff --git a/internal/session/tree_manager.go b/internal/session/tree_manager.go index aa4b76cf..366f388c 100644 --- a/internal/session/tree_manager.go +++ b/internal/session/tree_manager.go @@ -458,11 +458,6 @@ func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) { return tm.AppendMessage(message.FromLLMMessage(msg)) } -// Deprecated: Use AppendLLMMessage instead. -func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) { - return tm.AppendLLMMessage(msg) -} - // AppendModelChange records a model/provider change. func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, error) { tm.mu.Lock() @@ -1170,11 +1165,6 @@ func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error { return tm.flushLocked() } -// Deprecated: Use AddLLMMessages instead. -func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error { - return tm.AddLLMMessages(msgs) -} - // GetLLMMessages builds the context and returns just the messages. // This satisfies the same conceptual role as the old Manager.GetMessages(). func (tm *TreeManager) GetLLMMessages() []fantasy.Message { @@ -1182,11 +1172,6 @@ func (tm *TreeManager) GetLLMMessages() []fantasy.Message { return msgs } -// Deprecated: Use GetLLMMessages instead. -func (tm *TreeManager) GetFantasyMessages() []fantasy.Message { - return tm.GetLLMMessages() -} - // --- Internal helpers --- // addEntryToIndex adds an entry to the in-memory indices. diff --git a/pkg/kit/events.go b/pkg/kit/events.go index 7f82cebf..c5d73223 100644 --- a/pkg/kit/events.go +++ b/pkg/kit/events.go @@ -571,67 +571,56 @@ func (eb *eventBus) emit(event Event) { // Typed convenience subscribers // --------------------------------------------------------------------------- +// subscribeTyped is the generic backbone of all the typed `On` +// convenience methods on *Kit. It wraps Subscribe with a type assertion +// against E so handlers receive a strongly-typed event without each +// public method having to repeat the boilerplate. Returns an unsubscribe +// function. +func subscribeTyped[E Event](k *Kit, handler func(E)) func() { + return k.Subscribe(func(e Event) { + if tev, ok := e.(E); ok { + handler(tev) + } + }) +} + // OnToolCall registers a handler that fires only for ToolCallEvent. // Returns an unsubscribe function. func (m *Kit) OnToolCall(handler func(ToolCallEvent)) func() { - return m.Subscribe(func(e Event) { - if tc, ok := e.(ToolCallEvent); ok { - handler(tc) - } - }) + return subscribeTyped(m, handler) } // OnToolCallStart registers a handler that fires only for ToolCallStartEvent. // This fires when the LLM begins generating tool call arguments — before the // full argument JSON is available. Returns an unsubscribe function. func (m *Kit) OnToolCallStart(handler func(ToolCallStartEvent)) func() { - return m.Subscribe(func(e Event) { - if tcs, ok := e.(ToolCallStartEvent); ok { - handler(tcs) - } - }) + return subscribeTyped(m, handler) } // OnToolCallDelta registers a handler that fires only for ToolCallDeltaEvent. // Each delta contains a JSON fragment of tool call arguments as they stream in. // Returns an unsubscribe function. func (m *Kit) OnToolCallDelta(handler func(ToolCallDeltaEvent)) func() { - return m.Subscribe(func(e Event) { - if tcd, ok := e.(ToolCallDeltaEvent); ok { - handler(tcd) - } - }) + return subscribeTyped(m, handler) } // OnToolCallEnd registers a handler that fires only for ToolCallEndEvent. // This fires when tool argument streaming is complete, before the tool call // is parsed and execution begins. Returns an unsubscribe function. func (m *Kit) OnToolCallEnd(handler func(ToolCallEndEvent)) func() { - return m.Subscribe(func(e Event) { - if tce, ok := e.(ToolCallEndEvent); ok { - handler(tce) - } - }) + return subscribeTyped(m, handler) } // OnToolResult registers a handler that fires only for ToolResultEvent. // Returns an unsubscribe function. func (m *Kit) OnToolResult(handler func(ToolResultEvent)) func() { - return m.Subscribe(func(e Event) { - if tr, ok := e.(ToolResultEvent); ok { - handler(tr) - } - }) + return subscribeTyped(m, handler) } // OnToolOutput registers a handler that fires only for ToolOutputEvent // (streaming tool output chunks, e.g., from bash). Returns an unsubscribe function. func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() { - return m.Subscribe(func(e Event) { - if to, ok := e.(ToolOutputEvent); ok { - handler(to) - } - }) + return subscribeTyped(m, handler) } // OnStreaming registers a handler that fires only for MessageUpdateEvent @@ -646,41 +635,25 @@ func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() { // OnMessageUpdate registers a handler that fires only for MessageUpdateEvent // (streaming text chunks). Returns an unsubscribe function. func (m *Kit) OnMessageUpdate(handler func(MessageUpdateEvent)) func() { - return m.Subscribe(func(e Event) { - if mu, ok := e.(MessageUpdateEvent); ok { - handler(mu) - } - }) + return subscribeTyped(m, handler) } // OnResponse registers a handler that fires only for ResponseEvent. // Returns an unsubscribe function. func (m *Kit) OnResponse(handler func(ResponseEvent)) func() { - return m.Subscribe(func(e Event) { - if r, ok := e.(ResponseEvent); ok { - handler(r) - } - }) + return subscribeTyped(m, handler) } // OnTurnStart registers a handler that fires only for TurnStartEvent. // Returns an unsubscribe function. func (m *Kit) OnTurnStart(handler func(TurnStartEvent)) func() { - return m.Subscribe(func(e Event) { - if ts, ok := e.(TurnStartEvent); ok { - handler(ts) - } - }) + return subscribeTyped(m, handler) } // OnTurnEnd registers a handler that fires only for TurnEndEvent. // Returns an unsubscribe function. func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() { - return m.Subscribe(func(e Event) { - if te, ok := e.(TurnEndEvent); ok { - handler(te) - } - }) + return subscribeTyped(m, handler) } // --------------------------------------------------------------------------- @@ -690,101 +663,61 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() { // OnMessageStart registers a handler that fires only for MessageStartEvent. // Returns an unsubscribe function. func (m *Kit) OnMessageStart(handler func(MessageStartEvent)) func() { - return m.Subscribe(func(e Event) { - if ms, ok := e.(MessageStartEvent); ok { - handler(ms) - } - }) + return subscribeTyped(m, handler) } // OnMessageEnd registers a handler that fires only for MessageEndEvent. // Returns an unsubscribe function. func (m *Kit) OnMessageEnd(handler func(MessageEndEvent)) func() { - return m.Subscribe(func(e Event) { - if me, ok := e.(MessageEndEvent); ok { - handler(me) - } - }) + return subscribeTyped(m, handler) } // OnReasoningDelta registers a handler that fires only for ReasoningDeltaEvent. // Returns an unsubscribe function. func (m *Kit) OnReasoningDelta(handler func(ReasoningDeltaEvent)) func() { - return m.Subscribe(func(e Event) { - if rd, ok := e.(ReasoningDeltaEvent); ok { - handler(rd) - } - }) + return subscribeTyped(m, handler) } // OnReasoningComplete registers a handler that fires only for ReasoningCompleteEvent. // Returns an unsubscribe function. func (m *Kit) OnReasoningComplete(handler func(ReasoningCompleteEvent)) func() { - return m.Subscribe(func(e Event) { - if rc, ok := e.(ReasoningCompleteEvent); ok { - handler(rc) - } - }) + return subscribeTyped(m, handler) } // OnToolExecutionStart registers a handler that fires only for ToolExecutionStartEvent. // Returns an unsubscribe function. func (m *Kit) OnToolExecutionStart(handler func(ToolExecutionStartEvent)) func() { - return m.Subscribe(func(e Event) { - if tes, ok := e.(ToolExecutionStartEvent); ok { - handler(tes) - } - }) + return subscribeTyped(m, handler) } // OnToolExecutionEnd registers a handler that fires only for ToolExecutionEndEvent. // Returns an unsubscribe function. func (m *Kit) OnToolExecutionEnd(handler func(ToolExecutionEndEvent)) func() { - return m.Subscribe(func(e Event) { - if tee, ok := e.(ToolExecutionEndEvent); ok { - handler(tee) - } - }) + return subscribeTyped(m, handler) } // OnToolCallContent registers a handler that fires only for ToolCallContentEvent. // Returns an unsubscribe function. func (m *Kit) OnToolCallContent(handler func(ToolCallContentEvent)) func() { - return m.Subscribe(func(e Event) { - if tcc, ok := e.(ToolCallContentEvent); ok { - handler(tcc) - } - }) + return subscribeTyped(m, handler) } // OnStepUsage registers a handler that fires only for StepUsageEvent. // Returns an unsubscribe function. func (m *Kit) OnStepUsage(handler func(StepUsageEvent)) func() { - return m.Subscribe(func(e Event) { - if su, ok := e.(StepUsageEvent); ok { - handler(su) - } - }) + return subscribeTyped(m, handler) } // OnCompaction registers a handler that fires only for CompactionEvent. // Returns an unsubscribe function. func (m *Kit) OnCompaction(handler func(CompactionEvent)) func() { - return m.Subscribe(func(e Event) { - if ce, ok := e.(CompactionEvent); ok { - handler(ce) - } - }) + return subscribeTyped(m, handler) } // OnSteerConsumed registers a handler that fires only for SteerConsumedEvent. // Returns an unsubscribe function. func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() { - return m.Subscribe(func(e Event) { - if sc, ok := e.(SteerConsumedEvent); ok { - handler(sc) - } - }) + return subscribeTyped(m, handler) } // --------------------------------------------------------------------------- @@ -794,101 +727,61 @@ func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() { // OnStepStart registers a handler that fires only for StepStartEvent. // Returns an unsubscribe function. func (m *Kit) OnStepStart(handler func(StepStartEvent)) func() { - return m.Subscribe(func(e Event) { - if ss, ok := e.(StepStartEvent); ok { - handler(ss) - } - }) + return subscribeTyped(m, handler) } // OnStepFinish registers a handler that fires only for StepFinishEvent. // Returns an unsubscribe function. func (m *Kit) OnStepFinish(handler func(StepFinishEvent)) func() { - return m.Subscribe(func(e Event) { - if sf, ok := e.(StepFinishEvent); ok { - handler(sf) - } - }) + return subscribeTyped(m, handler) } // OnTextStart registers a handler that fires only for TextStartEvent. // Returns an unsubscribe function. func (m *Kit) OnTextStart(handler func(TextStartEvent)) func() { - return m.Subscribe(func(e Event) { - if ts, ok := e.(TextStartEvent); ok { - handler(ts) - } - }) + return subscribeTyped(m, handler) } // OnTextEnd registers a handler that fires only for TextEndEvent. // Returns an unsubscribe function. func (m *Kit) OnTextEnd(handler func(TextEndEvent)) func() { - return m.Subscribe(func(e Event) { - if te, ok := e.(TextEndEvent); ok { - handler(te) - } - }) + return subscribeTyped(m, handler) } // OnReasoningStart registers a handler that fires only for ReasoningStartEvent. // Returns an unsubscribe function. func (m *Kit) OnReasoningStart(handler func(ReasoningStartEvent)) func() { - return m.Subscribe(func(e Event) { - if rs, ok := e.(ReasoningStartEvent); ok { - handler(rs) - } - }) + return subscribeTyped(m, handler) } // OnWarnings registers a handler that fires only for WarningsEvent. // Returns an unsubscribe function. func (m *Kit) OnWarnings(handler func(WarningsEvent)) func() { - return m.Subscribe(func(e Event) { - if w, ok := e.(WarningsEvent); ok { - handler(w) - } - }) + return subscribeTyped(m, handler) } // OnSource registers a handler that fires only for SourceEvent. // Returns an unsubscribe function. func (m *Kit) OnSource(handler func(SourceEvent)) func() { - return m.Subscribe(func(e Event) { - if s, ok := e.(SourceEvent); ok { - handler(s) - } - }) + return subscribeTyped(m, handler) } // OnStreamFinish registers a handler that fires only for StreamFinishEvent. // Returns an unsubscribe function. func (m *Kit) OnStreamFinish(handler func(StreamFinishEvent)) func() { - return m.Subscribe(func(e Event) { - if sf, ok := e.(StreamFinishEvent); ok { - handler(sf) - } - }) + return subscribeTyped(m, handler) } // OnError registers a handler that fires only for ErrorEvent. // Returns an unsubscribe function. func (m *Kit) OnError(handler func(ErrorEvent)) func() { - return m.Subscribe(func(e Event) { - if ee, ok := e.(ErrorEvent); ok { - handler(ee) - } - }) + return subscribeTyped(m, handler) } // OnRetry registers a handler that fires only for RetryEvent. // Returns an unsubscribe function. func (m *Kit) OnRetry(handler func(RetryEvent)) func() { - return m.Subscribe(func(e Event) { - if r, ok := e.(RetryEvent); ok { - handler(r) - } - }) + return subscribeTyped(m, handler) } // --------------------------------------------------------------------------- diff --git a/pkg/kit/extension_api.go b/pkg/kit/extension_api.go index efdc1a82..eea9a9d8 100644 --- a/pkg/kit/extension_api.go +++ b/pkg/kit/extension_api.go @@ -155,17 +155,17 @@ func (m *Kit) Extensions() ExtensionAPI { // Context management -func (e *extensionAPI) SetContext(ctx extensions.Context) { +func (e *extensionAPI) SetContext(ctx ExtensionContext) { if e.kit.extRunner != nil { e.kit.extRunner.SetContext(ctx) } } -func (e *extensionAPI) GetContext() extensions.Context { +func (e *extensionAPI) GetContext() ExtensionContext { if e.kit.extRunner != nil { return e.kit.extRunner.GetContext() } - return extensions.Context{} + return ExtensionContext{} } func (e *extensionAPI) UpdateContextModel(model string) { @@ -178,7 +178,7 @@ func (e *extensionAPI) UpdateContextModel(model string) { // Widgets -func (e *extensionAPI) SetWidget(config extensions.WidgetConfig) { +func (e *extensionAPI) SetWidget(config ExtensionWidgetConfig) { if e.kit.extRunner != nil { e.kit.extRunner.SetWidget(config) } @@ -190,7 +190,7 @@ func (e *extensionAPI) RemoveWidget(id string) { } } -func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig { +func (e *extensionAPI) GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig { if e.kit.extRunner == nil { return nil } @@ -199,7 +199,7 @@ func (e *extensionAPI) GetWidgets(placement extensions.WidgetPlacement) []extens // Header/Footer -func (e *extensionAPI) SetHeader(config extensions.HeaderFooterConfig) { +func (e *extensionAPI) SetHeader(config ExtensionHeaderFooterConfig) { if e.kit.extRunner != nil { e.kit.extRunner.SetHeader(config) } @@ -211,14 +211,14 @@ func (e *extensionAPI) RemoveHeader() { } } -func (e *extensionAPI) GetHeader() *extensions.HeaderFooterConfig { +func (e *extensionAPI) GetHeader() *ExtensionHeaderFooterConfig { if e.kit.extRunner == nil { return nil } return e.kit.extRunner.GetHeader() } -func (e *extensionAPI) SetFooter(config extensions.HeaderFooterConfig) { +func (e *extensionAPI) SetFooter(config ExtensionHeaderFooterConfig) { if e.kit.extRunner != nil { e.kit.extRunner.SetFooter(config) } @@ -230,7 +230,7 @@ func (e *extensionAPI) RemoveFooter() { } } -func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig { +func (e *extensionAPI) GetFooter() *ExtensionHeaderFooterConfig { if e.kit.extRunner == nil { return nil } @@ -239,7 +239,7 @@ func (e *extensionAPI) GetFooter() *extensions.HeaderFooterConfig { // Editor -func (e *extensionAPI) SetEditor(config extensions.EditorConfig) { +func (e *extensionAPI) SetEditor(config ExtensionEditorConfig) { if e.kit.extRunner != nil { e.kit.extRunner.SetEditor(config) } @@ -251,7 +251,7 @@ func (e *extensionAPI) ResetEditor() { } } -func (e *extensionAPI) GetEditor() *extensions.EditorConfig { +func (e *extensionAPI) GetEditor() *ExtensionEditorConfig { if e.kit.extRunner == nil { return nil } @@ -260,13 +260,13 @@ func (e *extensionAPI) GetEditor() *extensions.EditorConfig { // UI Visibility -func (e *extensionAPI) SetUIVisibility(v extensions.UIVisibility) { +func (e *extensionAPI) SetUIVisibility(v ExtensionUIVisibility) { if e.kit.extRunner != nil { e.kit.extRunner.SetUIVisibility(v) } } -func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility { +func (e *extensionAPI) GetUIVisibility() *ExtensionUIVisibility { if e.kit.extRunner == nil { return nil } @@ -275,14 +275,14 @@ func (e *extensionAPI) GetUIVisibility() *extensions.UIVisibility { // Tool rendering -func (e *extensionAPI) GetToolRenderer(toolName string) *extensions.ToolRenderConfig { +func (e *extensionAPI) GetToolRenderer(toolName string) *ExtensionToolRenderConfig { if e.kit.extRunner == nil { return nil } return e.kit.extRunner.GetToolRenderer(toolName) } -func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRendererConfig { +func (e *extensionAPI) GetMessageRenderer(name string) *ExtensionMessageRendererConfig { if e.kit.extRunner == nil { return nil } @@ -291,7 +291,7 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender // Session data -func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage { +func (e *extensionAPI) GetSessionMessages() []ExtensionSessionMessage { if e.kit.session == nil { return nil } @@ -299,8 +299,8 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage { // Try to use the legacy iterBranchMessages for backward compatibility // with the default TreeManager adapter if adapter, ok := e.kit.session.(*treeManagerAdapter); ok { - return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage { - return extensions.SessionMessage{ + return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) ExtensionSessionMessage { + return ExtensionSessionMessage{ ID: me.ID, Role: string(msg.Role), Content: msg.Content(), @@ -311,10 +311,10 @@ func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage { // For custom SessionManagers, use the public interface branch := e.kit.session.GetCurrentBranch() - var result []extensions.SessionMessage + var result []ExtensionSessionMessage for _, entry := range branch { if entry.Type == EntryTypeMessage { - result = append(result, extensions.SessionMessage{ + result = append(result, ExtensionSessionMessage{ ID: entry.ID, Role: entry.Role, Content: entry.Content, @@ -332,14 +332,14 @@ func (e *extensionAPI) AppendEntry(extType, data string) (string, error) { return e.kit.session.AppendExtensionData(extType, data) } -func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry { +func (e *extensionAPI) GetEntries(extType string) []ExtensionEntry { if e.kit.session == nil { return nil } entries := e.kit.session.GetExtensionData(extType) - result := make([]extensions.ExtensionEntry, 0, len(entries)) + result := make([]ExtensionEntry, 0, len(entries)) for _, e := range entries { - result = append(result, extensions.ExtensionEntry{ + result = append(result, ExtensionEntry{ ID: e.ID, EntryType: e.ExtType, Data: e.Data, @@ -351,7 +351,7 @@ func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry { // Status bar -func (e *extensionAPI) SetStatus(entry extensions.StatusBarEntry) { +func (e *extensionAPI) SetStatus(entry ExtensionStatusBarEntry) { if e.kit.extRunner != nil { e.kit.extRunner.SetStatusEntry(entry) } @@ -363,7 +363,7 @@ func (e *extensionAPI) RemoveStatus(key string) { } } -func (e *extensionAPI) GetStatusEntries() []extensions.StatusBarEntry { +func (e *extensionAPI) GetStatusEntries() []ExtensionStatusBarEntry { if e.kit.extRunner == nil { return nil } @@ -394,12 +394,12 @@ func (e *extensionAPI) GetShortcuts() map[string]func() { // Tools -func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo { +func (e *extensionAPI) GetToolInfos() []ExtensionToolInfo { agentTools := e.kit.agent.GetTools() coreCount := e.kit.agent.GetCoreToolCount() mcpCount := e.kit.agent.GetMCPToolCount() - result := make([]extensions.ToolInfo, 0, len(agentTools)) + result := make([]ExtensionToolInfo, 0, len(agentTools)) for i, t := range agentTools { info := t.Info() source := "core" @@ -412,7 +412,7 @@ func (e *extensionAPI) GetToolInfos() []extensions.ToolInfo { if e.kit.extRunner != nil && e.kit.extRunner.IsToolDisabled(info.Name) { enabled = false } - result = append(result, extensions.ToolInfo{ + result = append(result, ExtensionToolInfo{ Name: info.Name, Description: info.Description, Source: source, @@ -505,7 +505,7 @@ func (e *extensionAPI) EmitBeforeSessionSwitch(switchReason string) (cancelled b // Commands -func (e *extensionAPI) Commands() []extensions.CommandDef { +func (e *extensionAPI) Commands() []ExtensionCommandDef { if e.kit.extRunner == nil { return nil }