From e8e99b19a83e1fb59e4107b5bb50a20f0bd1b9cd Mon Sep 17 00:00:00 2001 From: Ed Zynda Date: Thu, 11 Jun 2026 16:13:18 +0300 Subject: [PATCH] refactor: dedupe cross-package logic and remove dead code from audit (#58) * Remove dead code: 5 unused symbols across internal packages - internal/models: LoadModelSettingsFromConfig (zero refs) - internal/prompts: PromptTemplate.ExpandWithArgs (zero refs) - internal/app: NewMessageStore (tests migrated to NewMessageStoreWithMessages) - internal/config: HasEnvVars (+ its test) - internal/core: ContextWithSudoPassword (test migrated to context.WithValue) * pkg/kit: use TreeManager alias in exported signatures NewTreeManagerAdapter and InitTreeSession now spell their signatures with the public kit.TreeManager alias instead of internal/session.TreeManager, so go doc renders domain types rather than internal paths. * Consolidate tool-kind classification into internal/extensions coreToolKinds + toolKindFor were duplicated verbatim in internal/extensions/wrapper.go and pkg/kit/events.go, risking silent divergence between extension events and SDK events. Single source of truth now lives in internal/extensions/toolkinds.go; pkg/kit re-exports the constants. * Consolidate Anthropic OAuth detection and usage-tracker refresh The 'is the active Anthropic credential a stored OAuth token' check was copy-pasted at 5 sites, all prefix-matching the magic string 'stored OAuth' produced in internal/auth. Now: - internal/auth: new CredentialSourceOAuth constant + IsAnthropicOAuth() - internal/ui: new UpdateUsageTrackerForModel(); CreateUsageTracker and SetupCLI share lookupTrackableModel (SetupCLI no longer re-inlines the tracker construction) - cmd/root.go + cmd/extension_context.go: verbatim-duplicated tracker refresh blocks replaced with ui.UpdateUsageTrackerForModel - pkg/kit isAnthropicOAuth delegates to auth.IsAnthropicOAuth - internal/models compares source against the constant * pkg/kit: consolidate model-path helpers and argument tokenizer - ExtractModelFromPath mis-parsed model IDs containing '/' (e.g. 'openrouter/meta/llama' -> 'meta'); it now delegates to RemoveProviderFromModel and is deprecated alongside ExtractProviderFromPath (-> GetCurrentProvider) - parseFields delegated to prompts.ParseCommandArgs so extension argument parsing and builtin prompt-template parsing share one quote/escape grammar; ParseCommandArgs now also splits on tabs (superset of both previous tokenizers) * Unify the two {{variable}} template engines internal/skills and pkg/kit/template_bridge each had their own grammar: skills rejected '{{ name }}' (whitespace) but allowed digit-first names; the bridge was the opposite. A template behaved differently depending on whether it was loaded as a skill prompt or via the extension API. internal/skills is now the single engine using the superset grammar (\{\{\s*(\w+)\s*\}\}); pkg/kit ParseTemplate/RenderTemplate are thin adapters over it. Expand is now regex-based so whitespace placeholders expand consistently; missing variables are still left as-is. * internal/ui: extract switchModel helper for model-switch flow The model-selector handler (ModelSelectedMsg) and /model slash command duplicated the full switch sequence (thinking-level fallback, setModel, display-state update, preference persistence, ModelChange emit) and had already drifted in ordering. Both now call a single switchModel method. Display state is still updated directly (no prog.Send from Update). * extbridge: extract shared BaseContext for extension wiring cmd/extension_context.go and internal/acpserver/session.go each built a giant extensions.Context literal, duplicating ~15 delegation closures (GetContextStats, GetMessages, AppendEntry, options, SetModel core, Complete, SpawnSubagent, ...) that had to be kept in sync by hand. New data-access fields had to be wired in both places or ACP-mode extensions silently got nil function fields. extbridge.BaseContext now provides the headless half; both call sites overlay only their UI-specific closures. As a side effect ACP mode gains previously-missing APIs (state, tree navigation, skills, template parsing, model resolution) that were nil before. The interactive TUI keeps its exact SetModel/ReloadExtensions ordering via overrides. * internal/tools: extract withOAuthRetry and marshalToolResult helpers ExecuteTool repeated the OAuth-error/re-auth/retry stanza verbatim twice (sync and task-augmented paths) and the marshal-and-wrap stanza four times. Both are now single helpers with identical error strings, so a fix to OAuth retry or error categorization applies everywhere at once. * internal/ui: extract buildShareFile with defer-based cleanup handleShareCommand repeated the close/remove/print/return cleanup chain four times across its temp-file write error paths. File assembly now lives in buildShareFile with a single deferred cleanup on error. * cmd: extract flag validation, preference restore, and provider-URL routing from runNormalMode runNormalMode opened with ~150 lines of policy logic (flag-combination validation, persisted model/thinking-level preference restoration, and two subtle --provider-url model-rewrite rules). These are now standalone functions (validateModeFlags, restorePersistedPreferences, applyProviderURLRouting) so the routing policy is independently readable and testable. Behaviour unchanged; ordering preserved. * fix: address review findings on SDK godoc and nil guard - pkg/kit: remove internal package paths from exported godoc on ParseTemplate and the ToolKind* constants (SDK doc surface must not reference internal packages) - internal/tools: guard marshalToolResult against a nil CallToolResult (json.Marshal(nil) succeeds as 'null', then result.IsError panics if a client returns nil result with nil error) Skipped the TreeNode Children deep-copy suggestion: the slice already comes from TreeManager.GetChildren which returns a fresh copy per call into a throwaway intermediate, so no internal state is exposed. --- cmd/extension_context.go | 705 ++++++++++----------------- cmd/root.go | 77 +-- internal/acpserver/session.go | 153 +++--- internal/app/messages.go | 5 - internal/app/messages_test.go | 20 +- internal/auth/credentials.go | 16 +- internal/config/substitution.go | 6 - internal/config/substitution_test.go | 38 -- internal/core/bash.go | 6 - internal/core/bash_test.go | 2 +- internal/extbridge/context.go | 234 +++++++++ internal/extensions/toolkinds.go | 38 ++ internal/extensions/wrapper.go | 23 +- internal/models/custom.go | 7 - internal/models/providers.go | 2 +- internal/prompts/template.go | 13 +- internal/skills/templates.go | 19 +- internal/tools/mcp.go | 117 ++--- internal/ui/factory.go | 64 ++- internal/ui/model.go | 200 ++++---- pkg/kit/adapter.go | 4 +- pkg/kit/events.go | 33 +- pkg/kit/extensions_bridge.go | 11 +- pkg/kit/kit.go | 2 +- pkg/kit/template_bridge.go | 98 +--- 25 files changed, 894 insertions(+), 999 deletions(-) create mode 100644 internal/extbridge/context.go create mode 100644 internal/extensions/toolkinds.go diff --git a/cmd/extension_context.go b/cmd/extension_context.go index e38002bf..17e5b354 100644 --- a/cmd/extension_context.go +++ b/cmd/extension_context.go @@ -4,13 +4,11 @@ import ( "context" "fmt" "os" - "strings" "github.com/spf13/viper" "golang.org/x/term" "github.com/mark3labs/kit/internal/app" - "github.com/mark3labs/kit/internal/auth" "github.com/mark3labs/kit/internal/extbridge" "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/models" @@ -35,451 +33,276 @@ type extensionContextDeps struct { // the three print routes appropriately for their phase (startup buffering // vs. live runtime routing). // -// This consolidates two near-identical 400-line literal expressions that -// previously appeared inline in runNormalMode. +// The headless half (data access, state, options, tree navigation, skills, +// templates, model resolution, subagents) comes from extbridge.BaseContext; +// this function overlays the TUI-specific fields and overrides SetModel / +// ReloadExtensions with TUI-aware versions. func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context { kitInstance := deps.kitInstance appInstance := deps.appInstance usageTracker := deps.usageTracker - ctx := deps.ctx - return extensions.Context{ - CWD: deps.cwd, - Model: deps.modelName, - Interactive: deps.interactive, - PrintBlock: func(opts extensions.PrintBlockOpts) { - appInstance.PrintBlockFromExtension(opts) - }, - SendMessage: func(text string) { appInstance.Run(text) }, - CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) }, - Abort: func() { appInstance.Abort() }, - IsIdle: func() bool { return !appInstance.IsBusy() }, - Compact: func(cfg extensions.CompactConfig) error { - return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError) - }, - SendMultimodalMessage: func(text string, files []extensions.FilePart) { - parts := make([]kit.LLMFilePart, len(files)) - for i, f := range files { - parts[i] = kit.LLMFilePart{ - Filename: f.Filename, - Data: f.Data, - MediaType: f.MediaType, - } - } - appInstance.RunWithFiles(text, parts) - }, - GetSessionUsage: func() extensions.SessionUsage { - if usageTracker == nil { - return extensions.SessionUsage{} - } - stats := usageTracker.GetSessionStats() - return extensions.SessionUsage{ - TotalInputTokens: stats.TotalInputTokens, - TotalOutputTokens: stats.TotalOutputTokens, - TotalCacheReadTokens: stats.TotalCacheReadTokens, - TotalCacheWriteTokens: stats.TotalCacheWriteTokens, - TotalCost: stats.TotalCost, - RequestCount: stats.RequestCount, - } - }, - Exit: func() { appInstance.QuitFromExtension() }, - SetWidget: func(config extensions.WidgetConfig) { - kitInstance.Extensions().SetWidget(config) - go appInstance.NotifyWidgetUpdate() - }, - RemoveWidget: func(id string) { - kitInstance.Extensions().RemoveWidget(id) - go appInstance.NotifyWidgetUpdate() - }, - SetHeader: func(config extensions.HeaderFooterConfig) { - kitInstance.Extensions().SetHeader(config) - go appInstance.NotifyWidgetUpdate() - }, - RemoveHeader: func() { - kitInstance.Extensions().RemoveHeader() - go appInstance.NotifyWidgetUpdate() - }, - SetFooter: func(config extensions.HeaderFooterConfig) { - kitInstance.Extensions().SetFooter(config) - go appInstance.NotifyWidgetUpdate() - }, - RemoveFooter: func() { - kitInstance.Extensions().RemoveFooter() - go appInstance.NotifyWidgetUpdate() - }, - PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult { - ch := make(chan app.PromptResponse, 1) - appInstance.SendPromptRequest(app.PromptRequestEvent{ - PromptType: "select", - Message: config.Message, - Options: config.Options, - ResponseCh: ch, - }) - resp := <-ch - if resp.Cancelled { - return extensions.PromptSelectResult{Cancelled: true} - } - return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index} - }, - PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult { - ch := make(chan app.PromptResponse, 1) - def := "false" - if config.DefaultValue { - def = "true" - } - appInstance.SendPromptRequest(app.PromptRequestEvent{ - PromptType: "confirm", - Message: config.Message, - Default: def, - ResponseCh: ch, - }) - resp := <-ch - if resp.Cancelled { - return extensions.PromptConfirmResult{Cancelled: true} - } - return extensions.PromptConfirmResult{Value: resp.Confirmed} - }, - PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult { - ch := make(chan app.PromptResponse, 1) - appInstance.SendPromptRequest(app.PromptRequestEvent{ - PromptType: "input", - Message: config.Message, - Placeholder: config.Placeholder, - Default: config.Default, - ResponseCh: ch, - }) - resp := <-ch - if resp.Cancelled { - return extensions.PromptInputResult{Cancelled: true} - } - return extensions.PromptInputResult{Value: resp.Value} - }, - SetUIVisibility: func(v extensions.UIVisibility) { - kitInstance.Extensions().SetUIVisibility(v) - go appInstance.NotifyWidgetUpdate() - }, - GetContextStats: func() extensions.ContextStats { - s := kitInstance.GetContextStats() - return extensions.ContextStats{ - EstimatedTokens: s.EstimatedTokens, - ContextLimit: s.ContextLimit, - UsagePercent: s.UsagePercent, - MessageCount: s.MessageCount, - } - }, - SetEditor: func(config extensions.EditorConfig) { - kitInstance.Extensions().SetEditor(config) - // Always use a goroutine for NotifyWidgetUpdate: prog.Send() - // deadlocks if called synchronously from inside BubbleTea's - // Update() handler. All call sites use go-routines uniformly. - go appInstance.NotifyWidgetUpdate() - }, - ResetEditor: func() { - kitInstance.Extensions().ResetEditor() - go appInstance.NotifyWidgetUpdate() - }, - GetMessages: func() []extensions.SessionMessage { - return kitInstance.Extensions().GetSessionMessages() - }, - GetSessionPath: func() string { - return kitInstance.GetSessionPath() - }, - AppendEntry: func(entryType string, data string) (string, error) { - return kitInstance.Extensions().AppendEntry(entryType, data) - }, - GetEntries: func(entryType string) []extensions.ExtensionEntry { - return kitInstance.Extensions().GetEntries(entryType) - }, - SetState: func(key string, value string) { - kitInstance.Extensions().SetState(key, value) - }, - GetState: func(key string) (string, bool) { - return kitInstance.Extensions().GetState(key) - }, - DeleteState: func(key string) { - kitInstance.Extensions().DeleteState(key) - }, - ListState: func() []string { - return kitInstance.Extensions().ListState() - }, - SetEditorText: func(text string) { - appInstance.SetEditorTextFromExtension(text) - }, - SetStatus: func(key string, text string, priority int) { - kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{ - Key: key, - Text: text, - Priority: priority, - }) - go appInstance.NotifyWidgetUpdate() - }, - RemoveStatus: func(key string) { - kitInstance.Extensions().RemoveStatus(key) - go appInstance.NotifyWidgetUpdate() - }, - GetOption: func(name string) string { - return kitInstance.Extensions().GetOption(name) - }, - SetOption: func(name string, value string) { - kitInstance.Extensions().SetOption(name, value) - }, - SetModel: func(modelString string) error { - // Capture previous model for the ModelChange event. - previousModel := kitInstance.Extensions().GetContext().Model - err := kitInstance.SetModel(context.Background(), modelString) - if err != nil { - return err - } - // Notify TUI so it updates model in status bar. - p, m, _ := models.ParseModelString(modelString) - appInstance.NotifyModelChanged(p, m) - // Update the context's Model field so handlers see it. - kitInstance.Extensions().UpdateContextModel(modelString) - // Fire OnModelChange event to extensions. - kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension") - // Update usage tracker with new model info for correct token counting. - if usageTracker != nil { - newProvider, newModel, _ := models.ParseModelString(modelString) - if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" { - registry := models.GetGlobalRegistry() - if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil { - // Check OAuth status for Anthropic models - isOAuth := false - if newProvider == "anthropic" { - _, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key")) - if err == nil && strings.HasPrefix(source, "stored OAuth") { - isOAuth = true - } - } - usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth) - } - } - } - return nil - }, - GetAvailableModels: func() []extensions.ModelInfoEntry { - return kitInstance.GetAvailableModels() - }, - EmitCustomEvent: func(name string, data string) { - kitInstance.Extensions().EmitCustomEvent(name, data) - }, - Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) { - return kitInstance.ExecuteCompletion(context.Background(), req) - }, - SuspendTUI: func(callback func()) error { - return appInstance.SuspendTUI(callback) - }, - RenderMessage: func(rendererName, content string) { - renderer := kitInstance.Extensions().GetMessageRenderer(rendererName) - if renderer == nil || renderer.Render == nil { - appInstance.PrintFromExtension("", content) - return - } - w, _, _ := term.GetSize(int(os.Stdout.Fd())) - if w == 0 { - w = 80 - } - rendered := renderer.Render(content, w) - appInstance.PrintFromExtension("", rendered) - }, - ReloadExtensions: func() error { - err := kitInstance.Extensions().Reload() - if err != nil { - return err - } - // Notify TUI that widgets/status/commands may have changed. - go appInstance.NotifyWidgetUpdate() - return nil - }, - GetAllTools: func() []extensions.ToolInfo { - return kitInstance.Extensions().GetToolInfos() - }, - SetActiveTools: func(names []string) { - kitInstance.Extensions().SetActiveTools(names) - }, - RegisterTheme: func(name string, config extensions.ThemeColorConfig) { - tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} } - ui.RegisterThemeFromConfig(name, - tc(config.Primary), tc(config.Secondary), - tc(config.Success), tc(config.Warning), - tc(config.Error), tc(config.Info), - tc(config.Text), tc(config.Muted), - tc(config.VeryMuted), tc(config.Background), - tc(config.Border), tc(config.MutedBorder), - tc(config.System), tc(config.Tool), - tc(config.Accent), tc(config.Highlight), - tc(config.MdHeading), tc(config.MdLink), - tc(config.MdKeyword), tc(config.MdString), - tc(config.MdNumber), tc(config.MdComment), - ) - }, - SetTheme: func(name string) error { - return ui.ApplyTheme(name) - }, - ListThemes: func() []string { - return ui.ListThemes() - }, - ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult { - ch := make(chan app.OverlayResponse, 1) - appInstance.SendOverlayRequest(app.OverlayRequestEvent{ - Title: config.Title, - Content: config.Content.Text, - Markdown: config.Content.Markdown, - BorderColor: config.Style.BorderColor, - Background: config.Style.Background, - Width: config.Width, - MaxHeight: config.MaxHeight, - Anchor: string(config.Anchor), - Actions: config.Actions, - ResponseCh: ch, - }) - resp := <-ch - if resp.Cancelled { - return extensions.OverlayResult{Cancelled: true, Index: -1} - } - return extensions.OverlayResult{ - Action: resp.Action, - Index: resp.Index, - } - }, - SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) { - return extbridge.SpawnSubagent(ctx, kitInstance, config) - }, - // ------------------------------------------------------------------- - // Tree Navigation API - // ------------------------------------------------------------------- - GetTreeNode: func(entryID string) *extensions.TreeNode { - node := kitInstance.GetTreeNode(entryID) - if node == nil { - return nil - } - return &extensions.TreeNode{ - ID: node.ID, - ParentID: node.ParentID, - Type: node.Type, - Role: node.Role, - Content: node.Content, - Model: node.Model, - Provider: node.Provider, - Timestamp: node.Timestamp, - Children: node.Children, - } - }, - GetCurrentBranch: func() []extensions.TreeNode { - nodes := kitInstance.GetCurrentBranch() - result := make([]extensions.TreeNode, len(nodes)) - for i, n := range nodes { - result[i] = extensions.TreeNode{ - ID: n.ID, - ParentID: n.ParentID, - Type: n.Type, - Role: n.Role, - Content: n.Content, - Model: n.Model, - Provider: n.Provider, - Timestamp: n.Timestamp, - Children: n.Children, - } - } - return result - }, - GetChildren: func(parentID string) []string { - return kitInstance.GetChildren(parentID) - }, - NavigateTo: func(entryID string) extensions.TreeNavigationResult { - err := kitInstance.NavigateTo(entryID) - if err != nil { - return extensions.TreeNavigationResult{Success: false, Error: err.Error()} - } - return extensions.TreeNavigationResult{Success: true} - }, - SummarizeBranch: func(fromID, toID string) string { - summary, _ := kitInstance.SummarizeBranch(fromID, toID) - return summary - }, - CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult { - err := kitInstance.CollapseBranch(fromID, toID, summary) - if err != nil { - return extensions.TreeNavigationResult{Success: false, Error: err.Error()} - } - return extensions.TreeNavigationResult{Success: true} - }, + ec := extbridge.BaseContext(deps.ctx, kitInstance) - // ------------------------------------------------------------------- - // Skill Loading API - // ------------------------------------------------------------------- - LoadSkill: func(path string) (*extensions.Skill, string) { - s, err := kitInstance.LoadSkillForExtension(path) - return s, err - }, - LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult { - return kitInstance.LoadSkillsFromDirForExtension(dir) - }, - DiscoverSkills: func() extensions.SkillLoadResult { - skills := kitInstance.DiscoverSkillsForExtension() - return extensions.SkillLoadResult{Skills: skills} - }, - InjectSkillAsContext: func(skillName string) string { - skills := kitInstance.DiscoverSkillsForExtension() - for _, s := range skills { - if s.Name == skillName { - appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content)) - return "" - } - } - return fmt.Sprintf("skill not found: %s", skillName) - }, - InjectRawSkillAsContext: func(path string) string { - s, err := kitInstance.LoadSkillForExtension(path) - if err != "" { - return err - } - appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content)) - return "" - }, - GetAvailableSkills: func() []extensions.Skill { - return kitInstance.DiscoverSkillsForExtension() - }, + ec.CWD = deps.cwd + ec.Model = deps.modelName + ec.Interactive = deps.interactive - // ------------------------------------------------------------------- - // Template Parsing API - // ------------------------------------------------------------------- - ParseTemplate: func(name, content string) extensions.PromptTemplate { - return kit.ParseTemplate(name, content) - }, - RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string { - return kit.RenderTemplate(tpl, vars) - }, - ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult { - return kit.ParseArguments(input, pattern) - }, - SimpleParseArguments: func(input string, count int) []string { - return kit.SimpleParseArguments(input, count) - }, - EvaluateModelConditional: func(condition string) bool { - return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition) - }, - RenderWithModelConditionals: func(content string) string { - return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model) - }, - - // ------------------------------------------------------------------- - // Model Resolution API - // ------------------------------------------------------------------- - ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult { - return kit.ResolveModelChain(preferences) - }, - GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) { - return kit.GetModelCapabilities(model) - }, - CheckModelAvailable: func(model string) bool { - return kit.CheckModelAvailable(model) - }, - GetCurrentProvider: func() string { - return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model) - }, - GetCurrentModelID: func() string { - return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model) - }, + ec.PrintBlock = func(opts extensions.PrintBlockOpts) { + appInstance.PrintBlockFromExtension(opts) } + ec.SendMessage = func(text string) { appInstance.Run(text) } + ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) } + ec.Abort = func() { appInstance.Abort() } + ec.IsIdle = func() bool { return !appInstance.IsBusy() } + ec.Compact = func(cfg extensions.CompactConfig) error { + return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError) + } + ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) { + parts := make([]kit.LLMFilePart, len(files)) + for i, f := range files { + parts[i] = kit.LLMFilePart{ + Filename: f.Filename, + Data: f.Data, + MediaType: f.MediaType, + } + } + appInstance.RunWithFiles(text, parts) + } + ec.GetSessionUsage = func() extensions.SessionUsage { + if usageTracker == nil { + return extensions.SessionUsage{} + } + stats := usageTracker.GetSessionStats() + return extensions.SessionUsage{ + TotalInputTokens: stats.TotalInputTokens, + TotalOutputTokens: stats.TotalOutputTokens, + TotalCacheReadTokens: stats.TotalCacheReadTokens, + TotalCacheWriteTokens: stats.TotalCacheWriteTokens, + TotalCost: stats.TotalCost, + RequestCount: stats.RequestCount, + } + } + ec.Exit = func() { appInstance.QuitFromExtension() } + + // TUI widgets/chrome — mutate runner state, then notify the TUI. + // Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks + // if called synchronously from inside BubbleTea's Update() handler. + // All call sites use go-routines uniformly. + ec.SetWidget = func(config extensions.WidgetConfig) { + kitInstance.Extensions().SetWidget(config) + go appInstance.NotifyWidgetUpdate() + } + ec.RemoveWidget = func(id string) { + kitInstance.Extensions().RemoveWidget(id) + go appInstance.NotifyWidgetUpdate() + } + ec.SetHeader = func(config extensions.HeaderFooterConfig) { + kitInstance.Extensions().SetHeader(config) + go appInstance.NotifyWidgetUpdate() + } + ec.RemoveHeader = func() { + kitInstance.Extensions().RemoveHeader() + go appInstance.NotifyWidgetUpdate() + } + ec.SetFooter = func(config extensions.HeaderFooterConfig) { + kitInstance.Extensions().SetFooter(config) + go appInstance.NotifyWidgetUpdate() + } + ec.RemoveFooter = func() { + kitInstance.Extensions().RemoveFooter() + go appInstance.NotifyWidgetUpdate() + } + ec.SetUIVisibility = func(v extensions.UIVisibility) { + kitInstance.Extensions().SetUIVisibility(v) + go appInstance.NotifyWidgetUpdate() + } + ec.SetEditor = func(config extensions.EditorConfig) { + kitInstance.Extensions().SetEditor(config) + go appInstance.NotifyWidgetUpdate() + } + ec.ResetEditor = func() { + kitInstance.Extensions().ResetEditor() + go appInstance.NotifyWidgetUpdate() + } + ec.SetEditorText = func(text string) { + appInstance.SetEditorTextFromExtension(text) + } + ec.SetStatus = func(key string, text string, priority int) { + kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{ + Key: key, + Text: text, + Priority: priority, + }) + go appInstance.NotifyWidgetUpdate() + } + ec.RemoveStatus = func(key string) { + kitInstance.Extensions().RemoveStatus(key) + go appInstance.NotifyWidgetUpdate() + } + + // Interactive prompts — channel-based round trips through the TUI. + ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult { + ch := make(chan app.PromptResponse, 1) + appInstance.SendPromptRequest(app.PromptRequestEvent{ + PromptType: "select", + Message: config.Message, + Options: config.Options, + ResponseCh: ch, + }) + resp := <-ch + if resp.Cancelled { + return extensions.PromptSelectResult{Cancelled: true} + } + return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index} + } + ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult { + ch := make(chan app.PromptResponse, 1) + def := "false" + if config.DefaultValue { + def = "true" + } + appInstance.SendPromptRequest(app.PromptRequestEvent{ + PromptType: "confirm", + Message: config.Message, + Default: def, + ResponseCh: ch, + }) + resp := <-ch + if resp.Cancelled { + return extensions.PromptConfirmResult{Cancelled: true} + } + return extensions.PromptConfirmResult{Value: resp.Confirmed} + } + ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult { + ch := make(chan app.PromptResponse, 1) + appInstance.SendPromptRequest(app.PromptRequestEvent{ + PromptType: "input", + Message: config.Message, + Placeholder: config.Placeholder, + Default: config.Default, + ResponseCh: ch, + }) + resp := <-ch + if resp.Cancelled { + return extensions.PromptInputResult{Cancelled: true} + } + return extensions.PromptInputResult{Value: resp.Value} + } + ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult { + ch := make(chan app.OverlayResponse, 1) + appInstance.SendOverlayRequest(app.OverlayRequestEvent{ + Title: config.Title, + Content: config.Content.Text, + Markdown: config.Content.Markdown, + BorderColor: config.Style.BorderColor, + Background: config.Style.Background, + Width: config.Width, + MaxHeight: config.MaxHeight, + Anchor: string(config.Anchor), + Actions: config.Actions, + ResponseCh: ch, + }) + resp := <-ch + if resp.Cancelled { + return extensions.OverlayResult{Cancelled: true, Index: -1} + } + return extensions.OverlayResult{ + Action: resp.Action, + Index: resp.Index, + } + } + ec.SuspendTUI = func(callback func()) error { + return appInstance.SuspendTUI(callback) + } + + // TUI-aware model switch: also notifies the TUI status bar and + // refreshes the usage tracker for correct token counting. + ec.SetModel = func(modelString string) error { + // Capture previous model for the ModelChange event. + previousModel := kitInstance.Extensions().GetContext().Model + err := kitInstance.SetModel(context.Background(), modelString) + if err != nil { + return err + } + // Notify TUI so it updates model in status bar. + p, m, _ := models.ParseModelString(modelString) + appInstance.NotifyModelChanged(p, m) + // Update the context's Model field so handlers see it. + kitInstance.Extensions().UpdateContextModel(modelString) + // Fire OnModelChange event to extensions. + kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension") + // Update usage tracker with new model info for correct token counting. + ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key")) + return nil + } + + ec.RenderMessage = func(rendererName, content string) { + renderer := kitInstance.Extensions().GetMessageRenderer(rendererName) + if renderer == nil || renderer.Render == nil { + appInstance.PrintFromExtension("", content) + return + } + w, _, _ := term.GetSize(int(os.Stdout.Fd())) + if w == 0 { + w = 80 + } + rendered := renderer.Render(content, w) + appInstance.PrintFromExtension("", rendered) + } + ec.ReloadExtensions = func() error { + err := kitInstance.Extensions().Reload() + if err != nil { + return err + } + // Notify TUI that widgets/status/commands may have changed. + go appInstance.NotifyWidgetUpdate() + return nil + } + + // Theme management (TUI only). + ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) { + tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} } + ui.RegisterThemeFromConfig(name, + tc(config.Primary), tc(config.Secondary), + tc(config.Success), tc(config.Warning), + tc(config.Error), tc(config.Info), + tc(config.Text), tc(config.Muted), + tc(config.VeryMuted), tc(config.Background), + tc(config.Border), tc(config.MutedBorder), + tc(config.System), tc(config.Tool), + tc(config.Accent), tc(config.Highlight), + tc(config.MdHeading), tc(config.MdLink), + tc(config.MdKeyword), tc(config.MdString), + tc(config.MdNumber), tc(config.MdComment), + ) + } + ec.SetTheme = func(name string) error { + return ui.ApplyTheme(name) + } + ec.ListThemes = func() []string { + return ui.ListThemes() + } + + // Skill context-injection (drives a new agent turn through the TUI). + ec.InjectSkillAsContext = func(skillName string) string { + skills := kitInstance.DiscoverSkillsForExtension() + for _, s := range skills { + if s.Name == skillName { + appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content)) + return "" + } + } + return fmt.Sprintf("skill not found: %s", skillName) + } + ec.InjectRawSkillAsContext = func(path string) string { + s, err := kitInstance.LoadSkillForExtension(path) + if err != "" { + return err + } + appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content)) + return "" + } + + return ec } diff --git a/cmd/root.go b/cmd/root.go index 99d81c63..37753239 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -12,7 +12,6 @@ import ( tea "charm.land/bubbletea/v2" "github.com/mark3labs/kit/internal/app" - "github.com/mark3labs/kit/internal/auth" "github.com/mark3labs/kit/internal/config" "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/models" @@ -677,8 +676,8 @@ func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() { } } -func runNormalMode(ctx context.Context) error { - // Validate flag combinations +// validateModeFlags rejects invalid flag combinations for the root command. +func validateModeFlags() error { if quietFlag && positionalPrompt == "" { return fmt.Errorf("--quiet requires a prompt (e.g. kit \"your question\" --quiet)") } @@ -691,21 +690,14 @@ func runNormalMode(ctx context.Context) error { if noExitFlag && positionalPrompt == "" { return fmt.Errorf("--no-exit requires a prompt (e.g. kit \"your question\" --no-exit)") } + return nil +} - // Set up logging - if debugMode { - log.SetFlags(log.LstdFlags | log.Lshortfile) - } - - // Update debug mode from viper - if viper.GetBool("debug") && !debugMode { - debugMode = viper.GetBool("debug") - log.SetFlags(log.LstdFlags | log.Lshortfile) - } - - // Restore persisted model preference when no explicit --model flag or - // config file model is set. Precedence: CLI flag > config file > saved - // preference > built-in default. This mirrors how themes are persisted. +// restorePersistedPreferences applies saved model / thinking-level +// preferences into viper when neither a CLI flag nor a config-file value +// takes precedence. Precedence: CLI flag > config file > saved preference > +// built-in default. This mirrors how themes are persisted. +func restorePersistedPreferences() { // Skip custom/* models unless --provider-url is also provided, since the // custom provider requires a URL that was only valid for the previous session. if !modelFlagChanged && !viper.InConfig("model") { @@ -724,6 +716,15 @@ func runNormalMode(ctx context.Context) error { viper.Set("thinking-level", pref) } } +} + +// applyProviderURLRouting rewrites the model in viper when --provider-url +// is set, routing requests through the "custom" (OpenAI-compatible) +// provider. Must run after restorePersistedPreferences. +func applyProviderURLRouting() { + if viper.GetString("provider-url") == "" { + return + } // When --provider-url is set but no explicit --model was provided, // default to "custom/custom" so the user doesn't need to remember a @@ -731,7 +732,7 @@ func runNormalMode(ctx context.Context) error { // This intentionally overrides saved preferences but respects config-file // models — if you specify a model in ~/.kit.yml, it will be used with // custom/custom's provider routing. - if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") { + if !modelFlagChanged && !viper.InConfig("model") { viper.Set("model", "custom/custom") } @@ -746,7 +747,7 @@ func runNormalMode(ctx context.Context) error { // to point a non-OpenAI wire (Anthropic, Google, ...) at a proxy URL, // use the explicit `custom/` form to opt out of the rewrite by // configuring the proxy as that provider in your config file instead. - if viper.GetString("provider-url") != "" && modelFlagChanged { + if modelFlagChanged { model := viper.GetString("model") if model != "" { name := model @@ -758,6 +759,26 @@ func runNormalMode(ctx context.Context) error { } } } +} + +func runNormalMode(ctx context.Context) error { + if err := validateModeFlags(); err != nil { + return err + } + + // Set up logging + if debugMode { + log.SetFlags(log.LstdFlags | log.Lshortfile) + } + + // Update debug mode from viper + if viper.GetBool("debug") && !debugMode { + debugMode = viper.GetBool("debug") + log.SetFlags(log.LstdFlags | log.Lshortfile) + } + + restorePersistedPreferences() + applyProviderURLRouting() // Load MCP configuration. mcpConfig, err := config.LoadAndValidateConfig() @@ -1164,23 +1185,7 @@ func runNormalMode(ctx context.Context) error { // NotifyModelChanged calls prog.Send() which deadlocks. The UI layer // updates m.providerName and m.modelName directly after setModel returns. // Update usage tracker with new model info for correct token counting. - if usageTracker != nil { - newProvider, newModel, _ := models.ParseModelString(modelString) - if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" { - registry := models.GetGlobalRegistry() - if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil { - // Check OAuth status for Anthropic models - isOAuth := false - if newProvider == "anthropic" { - _, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key")) - if err == nil && strings.HasPrefix(source, "stored OAuth") { - isOAuth = true - } - } - usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth) - } - } - } + ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key")) return nil } emitModelChangeForUI := func(newModel, previousModel, source string) { diff --git a/internal/acpserver/session.go b/internal/acpserver/session.go index 937cc6ad..5945ef16 100644 --- a/internal/acpserver/session.go +++ b/internal/acpserver/session.go @@ -73,111 +73,70 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, // Wire extension context with headless implementations so extensions // work in ACP mode. TUI-dependent features (widgets, prompts, editor) - // become no-ops or return cancelled; all data/model/tool APIs work - // identically to interactive mode. + // become no-ops or return cancelled; all data/model/tool APIs come from + // extbridge.BaseContext and work identically to interactive mode. if kitInstance.Extensions().HasExtensions() { - kitInstance.Extensions().SetContext(extensions.Context{ - SessionID: sessionID, - CWD: cwd, - Model: kitInstance.GetModelString(), - Interactive: false, + // Use a background context for subagent spawns: the create() ctx is + // request-scoped and may be cancelled before extensions spawn anything. + ec := extbridge.BaseContext(context.Background(), kitInstance) - // Output — route through structured logger. - Print: func(text string) { log.Debug("extension: print", "text", text) }, - PrintInfo: func(text string) { log.Info("extension: info", "text", text) }, - PrintError: func(text string) { log.Error("extension: error", "text", text) }, - PrintBlock: func(opts extensions.PrintBlockOpts) { - log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text) - }, + ec.SessionID = sessionID + ec.CWD = cwd + ec.Model = kitInstance.GetModelString() + ec.Interactive = false - // Message injection — no-ops for now; ACP clients drive prompts. - SendMessage: func(string) {}, - CancelAndSend: func(string) {}, - Exit: func() {}, + // Output — route through structured logger. + ec.Print = func(text string) { log.Debug("extension: print", "text", text) } + ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) } + ec.PrintError = func(text string) { log.Error("extension: error", "text", text) } + ec.PrintBlock = func(opts extensions.PrintBlockOpts) { + log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text) + } - // TUI widgets/chrome — silent no-ops (no TUI in ACP). - SetWidget: func(extensions.WidgetConfig) {}, - RemoveWidget: func(string) {}, - SetHeader: func(extensions.HeaderFooterConfig) {}, - RemoveHeader: func() {}, - SetFooter: func(extensions.HeaderFooterConfig) {}, - RemoveFooter: func() {}, - SetEditor: func(extensions.EditorConfig) {}, - ResetEditor: func() {}, - SetEditorText: func(string) {}, - SetUIVisibility: func(extensions.UIVisibility) {}, - SetStatus: func(string, string, int) {}, - RemoveStatus: func(string) {}, + // Message injection — no-ops for now; ACP clients drive prompts. + ec.SendMessage = func(string) {} + ec.CancelAndSend = func(string) {} + ec.Exit = func() {} - // Interactive prompts — return cancelled (no user to prompt). - PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult { - return extensions.PromptSelectResult{Cancelled: true} - }, - PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult { - return extensions.PromptConfirmResult{Cancelled: true} - }, - PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult { - return extensions.PromptInputResult{Cancelled: true} - }, - ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult { - return extensions.OverlayResult{Cancelled: true, Index: -1} - }, - SuspendTUI: func(callback func()) error { callback(); return nil }, + // TUI widgets/chrome — silent no-ops (no TUI in ACP). + ec.SetWidget = func(extensions.WidgetConfig) {} + ec.RemoveWidget = func(string) {} + ec.SetHeader = func(extensions.HeaderFooterConfig) {} + ec.RemoveHeader = func() {} + ec.SetFooter = func(extensions.HeaderFooterConfig) {} + ec.RemoveFooter = func() {} + ec.SetEditor = func(extensions.EditorConfig) {} + ec.ResetEditor = func() {} + ec.SetEditorText = func(string) {} + ec.SetUIVisibility = func(extensions.UIVisibility) {} + ec.SetStatus = func(string, string, int) {} + ec.RemoveStatus = func(string) {} - // Data access — delegate to Kit instance. - GetContextStats: func() extensions.ContextStats { - s := kitInstance.GetContextStats() - return extensions.ContextStats{ - EstimatedTokens: s.EstimatedTokens, - ContextLimit: s.ContextLimit, - UsagePercent: s.UsagePercent, - MessageCount: s.MessageCount, - } - }, - GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() }, - GetSessionPath: func() string { return kitInstance.GetSessionPath() }, - AppendEntry: func(entryType, data string) (string, error) { - return kitInstance.Extensions().AppendEntry(entryType, data) - }, - GetEntries: func(entryType string) []extensions.ExtensionEntry { - return kitInstance.Extensions().GetEntries(entryType) - }, + // Interactive prompts — return cancelled (no user to prompt). + ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult { + return extensions.PromptSelectResult{Cancelled: true} + } + ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult { + return extensions.PromptConfirmResult{Cancelled: true} + } + ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult { + return extensions.PromptInputResult{Cancelled: true} + } + ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult { + return extensions.OverlayResult{Cancelled: true, Index: -1} + } + ec.SuspendTUI = func(callback func()) error { callback(); return nil } - // Options, model, and tool management. - GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) }, - SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) }, - SetModel: func(modelString string) error { - previousModel := kitInstance.Extensions().GetContext().Model - if err := kitInstance.SetModel(context.Background(), modelString); err != nil { - return err - } - kitInstance.Extensions().UpdateContextModel(modelString) - kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension") - return nil - }, - GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() }, - EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) }, - GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() }, - SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) }, + // Render — fall back to logging. + ec.RenderMessage = func(name, content string) { + renderer := kitInstance.Extensions().GetMessageRenderer(name) + if renderer != nil && renderer.Render != nil { + content = renderer.Render(content, 80) + } + log.Info("extension: message", "renderer", name, "content", content) + } - // LLM completions and subagents. - Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) { - return kitInstance.ExecuteCompletion(context.Background(), req) - }, - SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) { - return extbridge.SpawnSubagent(context.Background(), kitInstance, config) - }, - - // Render — fall back to logging. - RenderMessage: func(name, content string) { - renderer := kitInstance.Extensions().GetMessageRenderer(name) - if renderer != nil && renderer.Render != nil { - content = renderer.Render(content, 80) - } - log.Info("extension: message", "renderer", name, "content", content) - }, - ReloadExtensions: func() error { return kitInstance.Extensions().Reload() }, - }) + kitInstance.Extensions().SetContext(ec) kitInstance.Extensions().EmitSessionStart() } diff --git a/internal/app/messages.go b/internal/app/messages.go index 0ec94d0e..4eaae954 100644 --- a/internal/app/messages.go +++ b/internal/app/messages.go @@ -13,11 +13,6 @@ type MessageStore struct { messages []kit.LLMMessage } -// NewMessageStore creates an empty MessageStore. -func NewMessageStore() *MessageStore { - return &MessageStore{} -} - // NewMessageStoreWithMessages creates a MessageStore pre-populated with the // given messages. This is used when loading an existing session at startup. func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore { diff --git a/internal/app/messages_test.go b/internal/app/messages_test.go index 8c4ce598..c5339eb3 100644 --- a/internal/app/messages_test.go +++ b/internal/app/messages_test.go @@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string { // -------------------------------------------------------------------------- func TestNewMessageStore_empty(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) if s == nil { t.Fatal("expected non-nil store") } @@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) { // -------------------------------------------------------------------------- func TestAdd_appendsMessage(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) s.Add(makeTextMsg("user", "first")) s.Add(makeTextMsg("assistant", "second")) @@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) { } func TestAdd_preservesOrder(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) texts := []string{"a", "b", "c"} for _, t2 := range texts { s.Add(makeTextMsg("user", t2)) @@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) { // -------------------------------------------------------------------------- func TestReplace_swapsHistory(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) s.Add(makeTextMsg("user", "old")) replacement := []kit.LLMMessage{ @@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) { // Replace must deep-copy the incoming slice. func TestReplace_isolatesInput(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) replacement := []kit.LLMMessage{makeTextMsg("user", "original")} s.Replace(replacement) @@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) { // -------------------------------------------------------------------------- func TestGetAll_returnsCopy(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) s.Add(makeTextMsg("user", "hello")) got := s.GetAll() @@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) { } func TestGetAll_emptyStore(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) got := s.GetAll() if len(got) != 0 { t.Fatalf("expected empty slice, got %d elements", len(got)) @@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) { // -------------------------------------------------------------------------- func TestClear_removesAllMessages(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) s.Add(makeTextMsg("user", "a")) s.Add(makeTextMsg("user", "b")) s.Clear() @@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) { } func TestClear_allowsSubsequentAdds(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) s.Add(makeTextMsg("user", "before")) s.Clear() s.Add(makeTextMsg("user", "after")) @@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) { // -------------------------------------------------------------------------- func TestConcurrentAccess(t *testing.T) { - s := NewMessageStore() + s := NewMessageStoreWithMessages(nil) done := make(chan struct{}) // Writer goroutine. diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go index 44d95ff4..92a62d9a 100644 --- a/internal/auth/credentials.go +++ b/internal/auth/credentials.go @@ -513,6 +513,20 @@ func validateAnthropicAPIKey(apiKey string) error { return nil } +// CredentialSourceOAuth is the source description returned by +// GetAnthropicAPIKey when the key resolves to stored OAuth credentials. +// Consumers should compare against this constant (or use IsAnthropicOAuth) +// rather than matching the string literal. +const CredentialSourceOAuth = "stored OAuth credentials" + +// IsAnthropicOAuth reports whether the active Anthropic credential resolves +// to a stored OAuth token (in which case the user is not billed per-token). +// flagValue is the --provider-api-key flag value (may be empty). +func IsAnthropicOAuth(flagValue string) bool { + _, source, err := GetAnthropicAPIKey(flagValue) + return err == nil && source == CredentialSourceOAuth +} + // GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order: // 1. Command-line flag value (highest priority) // 2. Stored credentials (OAuth or API key) @@ -535,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) { if err != nil { return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err) } - return token, "stored OAuth credentials", nil + return token, CredentialSourceOAuth, nil } else if creds.Type == "api_key" && creds.APIKey != "" { return creds.APIKey, "stored API key", nil } diff --git a/internal/config/substitution.go b/internal/config/substitution.go index b4b143cd..71f06ada 100644 --- a/internal/config/substitution.go +++ b/internal/config/substitution.go @@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) { return result, nil } - -// HasEnvVars checks if content contains environment variable patterns (${env://...}). -// This is useful for determining if substitution is needed before processing. -func HasEnvVars(content string) bool { - return envVarPattern.MatchString(content) -} diff --git a/internal/config/substitution_test.go b/internal/config/substitution_test.go index 10d2eb69..5445672d 100644 --- a/internal/config/substitution_test.go +++ b/internal/config/substitution_test.go @@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) { }) } } - -func TestHasEnvVars(t *testing.T) { - tests := []struct { - name string - content string - expected bool - }{ - { - name: "has env vars", - content: `{"token": "${env://GITHUB_TOKEN}"}`, - expected: true, - }, - { - name: "has env vars with default", - content: `{"debug": "${env://DEBUG:-false}"}`, - expected: true, - }, - { - name: "no env vars", - content: `{"name": "${username}", "normal": "value"}`, - expected: false, - }, - { - name: "empty content", - content: "", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := HasEnvVars(tt.content) - if result != tt.expected { - t.Errorf("Expected %v, got %v", tt.expected, result) - } - }) - } -} diff --git a/internal/core/bash.go b/internal/core/bash.go index a493f06c..3428fb3e 100644 --- a/internal/core/bash.go +++ b/internal/core/bash.go @@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback { return nil } -// ContextWithSudoPassword returns a new context with the sudo password set. -// When present, the bash tool will use sudo -S to pipe this password to sudo commands. -func ContextWithSudoPassword(ctx context.Context, password string) context.Context { - return context.WithValue(ctx, sudoPasswordKey, password) -} - // sudoPasswordFromContext retrieves the sudo password from context. func sudoPasswordFromContext(ctx context.Context) string { if pw, ok := ctx.Value(sudoPasswordKey).(string); ok { diff --git a/internal/core/bash_test.go b/internal/core/bash_test.go index f6ded4a2..5727e648 100644 --- a/internal/core/bash_test.go +++ b/internal/core/bash_test.go @@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) { func TestSudoPasswordFromContext(t *testing.T) { // Test with password in context - ctx := ContextWithSudoPassword(context.Background(), "secret123") + ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123") pw := sudoPasswordFromContext(ctx) if pw != "secret123" { t.Errorf("expected password 'secret123', got %q", pw) diff --git a/internal/extbridge/context.go b/internal/extbridge/context.go new file mode 100644 index 00000000..6f035a47 --- /dev/null +++ b/internal/extbridge/context.go @@ -0,0 +1,234 @@ +package extbridge + +import ( + "context" + + "github.com/mark3labs/kit/internal/extensions" + kit "github.com/mark3labs/kit/pkg/kit" +) + +// BaseContext returns an extensions.Context populated with the headless, +// TUI-independent delegation fields: data access, state, options, +// model/tool management, completions, subagents, tree navigation, skills, +// template parsing, and model resolution. +// +// Callers overlay their UI-specific fields (print routes, widgets, prompts, +// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value: +// cmd/extension_context.go for the interactive TUI and +// internal/acpserver/session.go for headless ACP mode. Keeping the shared +// half here means a new data-access Context field only has to be wired once. +// +// ctx is used for subagent spawns; pass a long-lived context (not a +// per-request one) so later spawns aren't cancelled prematurely. +func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context { + return extensions.Context{ + // ------------------------------------------------------------------- + // Data access + // ------------------------------------------------------------------- + GetContextStats: func() extensions.ContextStats { + s := kitInstance.GetContextStats() + return extensions.ContextStats{ + EstimatedTokens: s.EstimatedTokens, + ContextLimit: s.ContextLimit, + UsagePercent: s.UsagePercent, + MessageCount: s.MessageCount, + } + }, + GetMessages: func() []extensions.SessionMessage { + return kitInstance.Extensions().GetSessionMessages() + }, + GetSessionPath: func() string { + return kitInstance.GetSessionPath() + }, + AppendEntry: func(entryType string, data string) (string, error) { + return kitInstance.Extensions().AppendEntry(entryType, data) + }, + GetEntries: func(entryType string) []extensions.ExtensionEntry { + return kitInstance.Extensions().GetEntries(entryType) + }, + + // ------------------------------------------------------------------- + // Extension state + // ------------------------------------------------------------------- + SetState: func(key string, value string) { + kitInstance.Extensions().SetState(key, value) + }, + GetState: func(key string) (string, bool) { + return kitInstance.Extensions().GetState(key) + }, + DeleteState: func(key string) { + kitInstance.Extensions().DeleteState(key) + }, + ListState: func() []string { + return kitInstance.Extensions().ListState() + }, + + // ------------------------------------------------------------------- + // Options, model, and tool management + // ------------------------------------------------------------------- + GetOption: func(name string) string { + return kitInstance.Extensions().GetOption(name) + }, + SetOption: func(name string, value string) { + kitInstance.Extensions().SetOption(name, value) + }, + // Headless model switch. The interactive TUI overrides this with a + // version that also notifies the TUI and refreshes the usage tracker. + SetModel: func(modelString string) error { + previousModel := kitInstance.Extensions().GetContext().Model + if err := kitInstance.SetModel(context.Background(), modelString); err != nil { + return err + } + kitInstance.Extensions().UpdateContextModel(modelString) + kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension") + return nil + }, + GetAvailableModels: func() []extensions.ModelInfoEntry { + return kitInstance.GetAvailableModels() + }, + EmitCustomEvent: func(name string, data string) { + kitInstance.Extensions().EmitCustomEvent(name, data) + }, + GetAllTools: func() []extensions.ToolInfo { + return kitInstance.Extensions().GetToolInfos() + }, + SetActiveTools: func(names []string) { + kitInstance.Extensions().SetActiveTools(names) + }, + // Headless reload. The interactive TUI overrides this to also + // refresh widgets/status/commands. + ReloadExtensions: func() error { + return kitInstance.Extensions().Reload() + }, + + // ------------------------------------------------------------------- + // LLM completions and subagents + // ------------------------------------------------------------------- + Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) { + return kitInstance.ExecuteCompletion(context.Background(), req) + }, + SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) { + return SpawnSubagent(ctx, kitInstance, config) + }, + + // ------------------------------------------------------------------- + // Tree Navigation API + // ------------------------------------------------------------------- + GetTreeNode: func(entryID string) *extensions.TreeNode { + node := kitInstance.GetTreeNode(entryID) + if node == nil { + return nil + } + return &extensions.TreeNode{ + ID: node.ID, + ParentID: node.ParentID, + Type: node.Type, + Role: node.Role, + Content: node.Content, + Model: node.Model, + Provider: node.Provider, + Timestamp: node.Timestamp, + Children: node.Children, + } + }, + GetCurrentBranch: func() []extensions.TreeNode { + nodes := kitInstance.GetCurrentBranch() + result := make([]extensions.TreeNode, len(nodes)) + for i, n := range nodes { + result[i] = extensions.TreeNode{ + ID: n.ID, + ParentID: n.ParentID, + Type: n.Type, + Role: n.Role, + Content: n.Content, + Model: n.Model, + Provider: n.Provider, + Timestamp: n.Timestamp, + Children: n.Children, + } + } + return result + }, + GetChildren: func(parentID string) []string { + return kitInstance.GetChildren(parentID) + }, + NavigateTo: func(entryID string) extensions.TreeNavigationResult { + err := kitInstance.NavigateTo(entryID) + if err != nil { + return extensions.TreeNavigationResult{Success: false, Error: err.Error()} + } + return extensions.TreeNavigationResult{Success: true} + }, + SummarizeBranch: func(fromID, toID string) string { + summary, _ := kitInstance.SummarizeBranch(fromID, toID) + return summary + }, + CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult { + err := kitInstance.CollapseBranch(fromID, toID, summary) + if err != nil { + return extensions.TreeNavigationResult{Success: false, Error: err.Error()} + } + return extensions.TreeNavigationResult{Success: true} + }, + + // ------------------------------------------------------------------- + // Skill Loading API (context-injection variants are TUI-specific and + // wired by the interactive overlay) + // ------------------------------------------------------------------- + LoadSkill: func(path string) (*extensions.Skill, string) { + s, err := kitInstance.LoadSkillForExtension(path) + return s, err + }, + LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult { + return kitInstance.LoadSkillsFromDirForExtension(dir) + }, + DiscoverSkills: func() extensions.SkillLoadResult { + skills := kitInstance.DiscoverSkillsForExtension() + return extensions.SkillLoadResult{Skills: skills} + }, + GetAvailableSkills: func() []extensions.Skill { + return kitInstance.DiscoverSkillsForExtension() + }, + + // ------------------------------------------------------------------- + // Template Parsing API + // ------------------------------------------------------------------- + ParseTemplate: func(name, content string) extensions.PromptTemplate { + return kit.ParseTemplate(name, content) + }, + RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string { + return kit.RenderTemplate(tpl, vars) + }, + ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult { + return kit.ParseArguments(input, pattern) + }, + SimpleParseArguments: func(input string, count int) []string { + return kit.SimpleParseArguments(input, count) + }, + EvaluateModelConditional: func(condition string) bool { + return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition) + }, + RenderWithModelConditionals: func(content string) string { + return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model) + }, + + // ------------------------------------------------------------------- + // Model Resolution API + // ------------------------------------------------------------------- + ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult { + return kit.ResolveModelChain(preferences) + }, + GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) { + return kit.GetModelCapabilities(model) + }, + CheckModelAvailable: func(model string) bool { + return kit.CheckModelAvailable(model) + }, + GetCurrentProvider: func() string { + return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model) + }, + GetCurrentModelID: func() string { + return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model) + }, + } +} diff --git a/internal/extensions/toolkinds.go b/internal/extensions/toolkinds.go new file mode 100644 index 00000000..0ac4395b --- /dev/null +++ b/internal/extensions/toolkinds.go @@ -0,0 +1,38 @@ +package extensions + +// ToolKind constants classify what a tool does, enabling UIs to render +// appropriate visualizations (e.g. diff view for edit tools, command+output +// for execute tools) and file trackers to identify which results contain +// modifications. +// +// This is the single source of truth for tool-kind classification; the +// pkg/kit SDK re-exports these constants. +const ( + ToolKindExecute = "execute" // Shell execution (bash) + ToolKindEdit = "edit" // File modification (edit, write) + ToolKindRead = "read" // File reading (read, ls) + ToolKindSearch = "search" // Content/file search (grep, find) + ToolKindSubagent = "agent" // Subagent spawning (subagent) +) + +// coreToolKinds maps built-in tool names to their kind classification. +// MCP and extension tools without an entry default to ToolKindExecute. +var coreToolKinds = map[string]string{ + "bash": ToolKindExecute, + "edit": ToolKindEdit, + "write": ToolKindEdit, + "read": ToolKindRead, + "ls": ToolKindRead, + "grep": ToolKindSearch, + "find": ToolKindSearch, + "subagent": ToolKindSubagent, +} + +// ToolKindFor returns the ToolKind for a given tool name, defaulting to +// ToolKindExecute for unknown tools (including MCP tools). +func ToolKindFor(toolName string) string { + if kind, ok := coreToolKinds[toolName]; ok { + return kind + } + return ToolKindExecute +} diff --git a/internal/extensions/wrapper.go b/internal/extensions/wrapper.go index 6af25692..8119e4c7 100644 --- a/internal/extensions/wrapper.go +++ b/internal/extensions/wrapper.go @@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo return tools } -// coreToolKinds maps built-in tool names to their kind classification. -var coreToolKinds = map[string]string{ - "bash": "execute", - "edit": "edit", - "write": "edit", - "read": "read", - "ls": "read", - "grep": "search", - "find": "search", - "subagent": "agent", -} - -// toolKindFor returns the ToolKind for a given tool name, defaulting to -// "execute" for unknown tools (including MCP tools). -func toolKindFor(toolName string) string { - if kind, ok := coreToolKinds[toolName]; ok { - return kind - } - return "execute" -} - // parseToolArgsJSON attempts to parse JSON-encoded tool args into a map. // Returns nil on failure (non-fatal convenience parsing). func parseToolArgsJSON(input string) map[string]any { @@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil } - kind := toolKindFor(toolName) + kind := ToolKindFor(toolName) // 1. Emit ToolCall — extensions can block execution. if w.runner.HasHandlers(ToolCall) { diff --git a/internal/models/custom.go b/internal/models/custom.go index cce5f5a8..b384312a 100644 --- a/internal/models/custom.go +++ b/internal/models/custom.go @@ -69,13 +69,6 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo { return info } -// LoadModelSettingsFromConfig loads per-model generation parameter overrides -// from the process-global viper store. Keys are "provider/model" strings. -// Returns nil if no model settings are configured. -func LoadModelSettingsFromConfig() map[string]*GenerationParams { - return LoadModelSettingsFrom(viper.GetViper()) -} - // LoadModelSettingsFrom loads per-model generation parameter overrides from the // supplied per-instance store. When v is nil the process-global store is used. // Keys are "provider/model" strings. Returns nil if no model settings are diff --git a/internal/models/providers.go b/internal/models/providers.go index 6995f6dd..c9a16807 100644 --- a/internal/models/providers.go +++ b/internal/models/providers.go @@ -932,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN } // Handle OAuth vs API key authentication - if strings.HasPrefix(source, "stored OAuth") { + if source == auth.CredentialSourceOAuth { httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify) opts = append(opts, anthropic.WithHTTPClient(httpClient)) // Note: For OAuth, the API key is set as a placeholder; the transport handles auth diff --git a/internal/prompts/template.go b/internal/prompts/template.go index cb9eb311..87810a2f 100644 --- a/internal/prompts/template.go +++ b/internal/prompts/template.go @@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) { } // ParseCommandArgs splits a command line into arguments respecting quotes. -// It handles single quotes, double quotes, and backslash escaping. +// It handles single quotes, double quotes, backslash escaping, and splits on +// spaces and tabs. func ParseCommandArgs(input string) []string { var args []string var current strings.Builder @@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string { inDoubleQuote := false escaped := false - for i, r := range input { + for _, r := range input { if escaped { current.WriteRune(r) escaped = false @@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string { continue } - if r == ' ' && !inSingleQuote && !inDoubleQuote { + if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote { if current.Len() > 0 { args = append(args, current.String()) current.Reset() @@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string { } current.WriteRune(r) - _ = i // silence unused warning when we need position later } if current.Len() > 0 { @@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string { args := ParseCommandArgs(argsInput) return SubstituteArgs(t.Content, args) } - -// ExpandWithArgs substitutes the provided arguments into the template content. -func (t *PromptTemplate) ExpandWithArgs(args []string) string { - return SubstituteArgs(t.Content, args) -} diff --git a/internal/skills/templates.go b/internal/skills/templates.go index 7902d662..b3c5c963 100644 --- a/internal/skills/templates.go +++ b/internal/skills/templates.go @@ -18,8 +18,11 @@ type PromptTemplate struct { Variables []string } -// variableRe matches {{variable_name}} placeholders. -var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`) +// variableRe matches {{variable_name}} placeholders, tolerating surrounding +// whitespace inside the braces (e.g. {{ name }}). This is the canonical +// template grammar shared by skill prompts and the extension template API +// (pkg/kit ParseTemplate/RenderTemplate delegate here). +var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`) // NewPromptTemplate creates a PromptTemplate, automatically extracting // variable names from {{...}} placeholders in content. @@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) { // Expand replaces all {{variable}} placeholders with values from the // provided map. Missing variables are left as-is (no error). func (t *PromptTemplate) Expand(values map[string]string) string { - result := t.Content - for k, v := range values { - result = strings.ReplaceAll(result, "{{"+k+"}}", v) - } - return result + return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string { + name := variableRe.FindStringSubmatch(m)[1] + if v, ok := values[name]; ok { + return v + } + return m + }) } // ExpandStrict replaces all {{variable}} placeholders and returns an error diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go index f4a4a66b..d4ab10f0 100644 --- a/internal/tools/mcp.go +++ b/internal/tools/mcp.go @@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO Request: mcp.Request{Method: "tools/call"}, Params: callParams, } - result, callErr := conn.client.CallTool(ctx, callRequest) - if callErr != nil { - if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) { - if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil { - return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr) - } - result, callErr = conn.client.CallTool(ctx, callRequest) - if callErr != nil { - m.connectionPool.HandleConnectionError(mapping.serverName, callErr) - return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr) - } - } else { - m.connectionPool.HandleConnectionError(mapping.serverName, callErr) - return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) - } + var result *mcp.CallToolResult + err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error { + var callErr error + result, callErr = conn.client.CallTool(ctx, callRequest) + return callErr + }) + if err != nil { + return nil, err } - marshaledResult, mErr := json.Marshal(result) - if mErr != nil { - return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) - } - return &MCPToolResult{ - Content: string(marshaledResult), - IsError: result.IsError, - }, nil + return marshalToolResult(result) } // Task-augmented path. Bypass the upstream CallTool helper because its @@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO m.connectionPool.HandleConnectionError(mapping.serverName, callErr) return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) } - marshaledResult, mErr := json.Marshal(result) - if mErr != nil { - return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) - } - return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil + return marshalToolResult(result) } - callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams) - if callErr != nil { - if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) { - if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil { - return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr) - } - callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams) - if callErr != nil { - m.connectionPool.HandleConnectionError(mapping.serverName, callErr) - return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr) - } - } else { - m.connectionPool.HandleConnectionError(mapping.serverName, callErr) - return nil, fmt.Errorf("failed to call mcp tool: %w", callErr) - } + var ( + callResult *mcp.CallToolResult + taskResult *mcp.CreateTaskResult + ) + err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error { + var callErr error + callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams) + return callErr + }) + if err != nil { + return nil, err } // Server chose to answer synchronously — same shape as the no-task path. if callResult != nil { - marshaledResult, mErr := json.Marshal(callResult) - if mErr != nil { - return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) - } - return &MCPToolResult{ - Content: string(marshaledResult), - IsError: callResult.IsError, - }, nil + return marshalToolResult(callResult) } // Asynchronous task path: poll until terminal, then return the result. @@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO } // Adapt TaskResultResult → CallToolResult for downstream JSON shape parity. - adapted := &mcp.CallToolResult{ + return marshalToolResult(&mcp.CallToolResult{ Content: final.Content, StructuredContent: final.StructuredContent, IsError: final.IsError, + }) +} + +// withOAuthRetry runs call once; when it fails with an OAuth error and an +// OAuth flow is configured, it re-authorizes the server and retries once. +// Connection failures are reported to the pool and wrapped uniformly. This +// consolidates the retry/error chain shared by the synchronous and +// task-augmented tool-call paths. +func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error { + callErr := call() + if callErr == nil { + return nil } - marshaledResult, mErr := json.Marshal(adapted) - if mErr != nil { - return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr) + if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) { + if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil { + return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr) + } + if callErr = call(); callErr != nil { + m.connectionPool.HandleConnectionError(serverName, callErr) + return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr) + } + return nil + } + m.connectionPool.HandleConnectionError(serverName, callErr) + return fmt.Errorf("failed to call mcp tool: %w", callErr) +} + +// marshalToolResult converts an MCP CallToolResult into the JSON-encoded +// MCPToolResult shape returned to the agent. +func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) { + if result == nil { + return nil, errors.New("mcp tool call returned nil result") + } + marshaled, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err) } return &MCPToolResult{ - Content: string(marshaledResult), - IsError: final.IsError, + Content: string(marshaled), + IsError: result.IsError, }, nil } diff --git a/internal/ui/factory.go b/internal/ui/factory.go index 7a8e85c9..4cb618d7 100644 --- a/internal/ui/factory.go +++ b/internal/ui/factory.go @@ -2,7 +2,6 @@ package ui import ( "fmt" - "strings" "github.com/mark3labs/kit/internal/auth" "github.com/mark3labs/kit/internal/models" @@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) { // ollama or unrecognised models). This is used by the interactive TUI path // which doesn't go through SetupCLI. func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker { - provider, model := parseModelName(modelString) - if provider == "unknown" || model == "unknown" || provider == "ollama" { - return nil - } - - registry := models.GetGlobalRegistry() - modelInfo := registry.LookupModel(provider, model) + modelInfo, provider := lookupTrackableModel(modelString) if modelInfo == nil { return nil } - - isOAuth := false - if provider == "anthropic" { - _, source, err := auth.GetAnthropicAPIKey(providerAPIKey) - if err == nil && strings.HasPrefix(source, "stored OAuth") { - isOAuth = true - } - } - + isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey) return NewUsageTracker(modelInfo, provider, 80, isOAuth) } +// UpdateUsageTrackerForModel refreshes an existing tracker after a model +// switch so token counting and cost reporting use the new model's metadata. +// No-op for a nil tracker or untrackable models (unknown/ollama). +func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) { + if t == nil { + return + } + modelInfo, provider := lookupTrackableModel(modelString) + if modelInfo == nil { + return + } + isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey) + t.UpdateModelInfo(modelInfo, provider, isOAuth) +} + +// lookupTrackableModel resolves a model string to registry metadata, returning +// nil for models without usage tracking support (unknown or ollama models). +func lookupTrackableModel(modelString string) (*models.ModelInfo, string) { + provider, model := parseModelName(modelString) + if provider == "unknown" || model == "unknown" || provider == "ollama" { + return nil, provider + } + return models.GetGlobalRegistry().LookupModel(provider, model), provider +} + // SetupCLI creates, configures, and initializes a CLI instance with the provided // options. It sets up model display, usage tracking for supported providers, and // shows initial loading information. Returns nil in quiet mode or an initialized @@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) { } // Set up usage tracking for supported providers - if provider != "unknown" && model != "unknown" { - // Skip usage tracking for ollama as it's not in models.dev - if provider != "ollama" { - registry := models.GetGlobalRegistry() - if modelInfo := registry.LookupModel(provider, model); modelInfo != nil { - // Check if OAuth credentials are being used for Anthropic models - isOAuth := false - if provider == "anthropic" { - _, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey) - if err == nil && strings.HasPrefix(source, "stored OAuth") { - isOAuth = true - } - } - - usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width - cli.SetUsageTracker(usageTracker) - } - } + if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil { + cli.SetUsageTracker(usageTracker) } // Display model info (the system message block provides its own spacing). diff --git a/internal/ui/model.go b/internal/ui/model.go index a9a870d4..8e367887 100644 --- a/internal/ui/model.go +++ b/internal/ui/model.go @@ -1208,53 +1208,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.modelSelector = nil m.state = stateInput if m.setModel != nil { - previousModel := m.providerName + "/" + m.modelName - - // Check if thinking level needs adjustment for the new model. - // Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none". - if m.thinkingLevel != "" && m.thinkingLevel != "off" { - parts := strings.SplitN(msg.ModelString, "/", 2) - if len(parts) == 2 { - modelName := parts[1] - currentLevel := models.ParseThinkingLevel(m.thinkingLevel) - if !models.IsValidThinkingLevelForModel(currentLevel, modelName) { - fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName) - if fallback != models.ThinkingOff { - m.printSystemMessage(fmt.Sprintf( - "Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.", - modelName, currentLevel, fallback, - )) - m.thinkingLevel = string(fallback) - if m.setThinkingLevel != nil { - _ = m.setThinkingLevel(string(fallback)) - } - go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }() - } - } - } - } - - if err := m.setModel(msg.ModelString); err != nil { - m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err)) - } else { - // Update display state directly — we cannot use - // NotifyModelChanged (prog.Send) from inside Update() - // without deadlocking BubbleTea. - parts := strings.SplitN(msg.ModelString, "/", 2) - if len(parts) == 2 { - m.providerName = parts[0] - m.modelName = parts[1] - } - m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString)) - // Persist model selection for next launch. - go func() { _ = prefs.SaveModelPreference(msg.ModelString) }() - if m.emitModelChange != nil { - emit := m.emitModelChange - newModel := msg.ModelString - prev := previousModel - go emit(newModel, prev, "user") - } - } + m.switchModel(msg.ModelString) } return m, tea.Batch(cmds...) @@ -4211,11 +4165,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd { return nil } + // Direct model switch with the provided model string. + m.switchModel(args) + return nil +} + +// switchModel performs a direct model switch, shared by the model selector +// overlay and the /model slash command: it adjusts the thinking level when +// the new model doesn't support the current one, calls the setModel +// callback, updates display state, persists preferences, and emits the +// ModelChange extension event. +// +// Display state is updated directly — we cannot use NotifyModelChanged +// (prog.Send) from inside Update() without deadlocking BubbleTea. +func (m *AppModel) switchModel(modelString string) { + if m.setModel == nil { + m.printSystemMessage("Model switching is not available.") + return + } + + previousModel := m.providerName + "/" + m.modelName + // Check if thinking level needs adjustment for the new model. // Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none". if m.thinkingLevel != "" && m.thinkingLevel != "off" { - parts := strings.SplitN(args, "/", 2) - if len(parts) == 2 { + if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 { modelName := parts[1] currentLevel := models.ParseThinkingLevel(m.thinkingLevel) if !models.IsValidThinkingLevelForModel(currentLevel, modelName) { @@ -4235,32 +4209,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd { } } - // Direct model switch with the provided model string. - previousModel := m.providerName + "/" + m.modelName - if err := m.setModel(args); err != nil { + if err := m.setModel(modelString); err != nil { m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err)) - return nil + return } // Update display state directly (cannot use prog.Send from Update). - parts := strings.SplitN(args, "/", 2) - if len(parts) == 2 { + if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 { m.providerName = parts[0] m.modelName = parts[1] } - if m.emitModelChange != nil { - emit := m.emitModelChange - prev := previousModel - newModel := args - go emit(newModel, prev, "user") - } + m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString)) // Persist model selection for next launch. - go func() { _ = prefs.SaveModelPreference(args) }() + go func() { _ = prefs.SaveModelPreference(modelString) }() - m.printSystemMessage(fmt.Sprintf("Switched to %s", args)) - return nil + if m.emitModelChange != nil { + emit := m.emitModelChange + go emit(modelString, previousModel, "user") + } } // -------------------------------------------------------------------------- @@ -4827,61 +4795,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd { return r }, name) - tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name)) + tmpPath, err := buildShareFile(name, data, sysPromptJSON) if err != nil { - m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err)) + m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err)) return nil } - tmpPath := tmpFile.Name() - - // Write the session data with the system prompt entry inserted after the header. - // The header is the first line, so we write: - // 1. First line (header) from original data - // 2. System prompt entry - // 3. Remaining lines from original data - lines := strings.Split(string(data), "\n") - if len(lines) > 0 && lines[len(lines)-1] == "" { - lines = lines[:len(lines)-1] // Remove trailing empty line - } - - if len(lines) > 0 { - // Write header (first line) - if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil { - _ = tmpFile.Close() - _ = os.Remove(tmpPath) - m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err)) - return nil - } - - // Write system prompt entry - if _, err := tmpFile.Write(sysPromptJSON); err != nil { - _ = tmpFile.Close() - _ = os.Remove(tmpPath) - m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err)) - return nil - } - if _, err := tmpFile.WriteString("\n"); err != nil { - _ = tmpFile.Close() - _ = os.Remove(tmpPath) - m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err)) - return nil - } - - // Write remaining lines - for i := 1; i < len(lines); i++ { - if lines[i] == "" { - continue // Skip empty lines - } - if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil { - _ = tmpFile.Close() - _ = os.Remove(tmpPath) - m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err)) - return nil - } - } - } - - _ = tmpFile.Close() m.printSystemMessage("Uploading session to GitHub Gist...") @@ -4907,6 +4825,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd { } } +// buildShareFile assembles a temp JSONL file containing the session data +// with the system-prompt entry inserted after the header line. On success +// the caller owns the returned file and must remove it when done; on error +// any partially-written temp file has already been cleaned up. +func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) { + tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name)) + if err != nil { + return "", fmt.Errorf("create temp file: %w", err) + } + tmpPath = tmpFile.Name() + defer func() { + _ = tmpFile.Close() + if err != nil { + _ = os.Remove(tmpPath) + } + }() + + // Write the session data with the system prompt entry inserted after the + // header. The header is the first line, so we write: + // 1. First line (header) from original data + // 2. System prompt entry + // 3. Remaining lines from original data + lines := strings.Split(string(data), "\n") + if len(lines) > 0 && lines[len(lines)-1] == "" { + lines = lines[:len(lines)-1] // Remove trailing empty line + } + if len(lines) == 0 { + return tmpPath, nil + } + + if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil { + return "", fmt.Errorf("write temp file: %w", err) + } + if _, err = tmpFile.Write(sysPromptJSON); err != nil { + return "", fmt.Errorf("write system prompt: %w", err) + } + if _, err = tmpFile.WriteString("\n"); err != nil { + return "", fmt.Errorf("write temp file: %w", err) + } + for i := 1; i < len(lines); i++ { + if lines[i] == "" { + continue // Skip empty lines + } + if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil { + return "", fmt.Errorf("write temp file: %w", err) + } + } + return tmpPath, nil +} + // handleImportCommand imports a session from a JSONL file. // Usage: /import path.jsonl func (m *AppModel) handleImportCommand(args string) tea.Cmd { diff --git a/pkg/kit/adapter.go b/pkg/kit/adapter.go index 9e97e0c2..61add1c4 100644 --- a/pkg/kit/adapter.go +++ b/pkg/kit/adapter.go @@ -11,12 +11,12 @@ import ( // treeManagerAdapter adapts TreeManager to SessionManager interface. // This is unexported - users don't interact with it directly. type treeManagerAdapter struct { - inner *session.TreeManager + inner *TreeManager } // NewTreeManagerAdapter creates an adapter (exported for use in New function). // This is used by the SDK when no custom SessionManager is provided. -func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager { +func NewTreeManagerAdapter(tm *TreeManager) SessionManager { return &treeManagerAdapter{inner: tm} } diff --git a/pkg/kit/events.go b/pkg/kit/events.go index c5d73223..d662f1df 100644 --- a/pkg/kit/events.go +++ b/pkg/kit/events.go @@ -3,6 +3,8 @@ package kit import ( "encoding/json" "sync" + + "github.com/mark3labs/kit/internal/extensions" ) // --------------------------------------------------------------------------- @@ -103,34 +105,21 @@ type Event interface { // appropriate visualizations (e.g. diff view for edit tools, command+output // for execute tools) and file trackers to identify which results contain // modifications. +// +// These constants re-export the canonical classification used by extension +// events, so SDK events and extension events always agree. const ( - ToolKindExecute = "execute" // Shell execution (bash) - ToolKindEdit = "edit" // File modification (edit, write) - ToolKindRead = "read" // File reading (read, ls) - ToolKindSearch = "search" // Content/file search (grep, find) - ToolKindSubagent = "agent" // Subagent spawning (subagent) + ToolKindExecute = extensions.ToolKindExecute // Shell execution (bash) + ToolKindEdit = extensions.ToolKindEdit // File modification (edit, write) + ToolKindRead = extensions.ToolKindRead // File reading (read, ls) + ToolKindSearch = extensions.ToolKindSearch // Content/file search (grep, find) + ToolKindSubagent = extensions.ToolKindSubagent // Subagent spawning (subagent) ) -// coreToolKinds maps built-in tool names to their kind. MCP and extension -// tools without an entry default to ToolKindExecute. -var coreToolKinds = map[string]string{ - "bash": ToolKindExecute, - "edit": ToolKindEdit, - "write": ToolKindEdit, - "read": ToolKindRead, - "ls": ToolKindRead, - "grep": ToolKindSearch, - "find": ToolKindSearch, - "subagent": ToolKindSubagent, -} - // toolKindFor returns the ToolKind for a given tool name, defaulting to // ToolKindExecute for unknown tools. func toolKindFor(toolName string) string { - if kind, ok := coreToolKinds[toolName]; ok { - return kind - } - return ToolKindExecute + return extensions.ToolKindFor(toolName) } // parseToolArgs attempts to parse a JSON-encoded tool args string into a map. diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go index 36e069bb..d80dc927 100644 --- a/pkg/kit/extensions_bridge.go +++ b/pkg/kit/extensions_bridge.go @@ -578,18 +578,13 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6 } // isAnthropicOAuth reports whether the current Anthropic credential resolves -// to a stored OAuth token (in which case the user is not billed per-token). -// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker -// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage(). +// to a stored OAuth token (in which case the user is not billed per-token), +// so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage(). func isAnthropicOAuth(m *Kit, provider string) bool { if m == nil || provider != "anthropic" { return false } - _, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key")) - if err != nil { - return false - } - return strings.HasPrefix(source, "stored OAuth") + return auth.IsAnthropicOAuth(m.v.GetString("provider-api-key")) } // llmToContextMessages converts a slice of LLM messages to extension diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go index f46b8c2a..12724842 100644 --- a/pkg/kit/kit.go +++ b/pkg/kit/kit.go @@ -1160,7 +1160,7 @@ type CLIOptions struct { // - Continue: resume most recent session for SessionDir (or cwd) // - SessionPath: open a specific JSONL session file // - default: create a new tree session for SessionDir (or cwd) -func InitTreeSession(opts *Options) (*session.TreeManager, error) { +func InitTreeSession(opts *Options) (*TreeManager, error) { if opts == nil { opts = &Options{} } diff --git a/pkg/kit/template_bridge.go b/pkg/kit/template_bridge.go index 98307b11..ba0b0c58 100644 --- a/pkg/kit/template_bridge.go +++ b/pkg/kit/template_bridge.go @@ -7,45 +7,36 @@ import ( "github.com/mark3labs/kit/internal/extensions" "github.com/mark3labs/kit/internal/models" + "github.com/mark3labs/kit/internal/prompts" + "github.com/mark3labs/kit/internal/skills" ) // --------------------------------------------------------------------------- // Template Parsing Bridge for Extensions (Phase 3) // --------------------------------------------------------------------------- -// varRegex matches {{variable}} placeholders in templates. -var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`) - -// ParseTemplate extracts {{variables}} from template content. +// ParseTemplate extracts {{variables}} from template content. The template +// grammar is shared with skill prompt templates, so a template parses +// identically regardless of which API loads it. func ParseTemplate(name, content string) extensions.PromptTemplate { - matches := varRegex.FindAllStringSubmatch(content, -1) - vars := make([]string, 0, len(matches)) - seen := make(map[string]bool) - for _, m := range matches { - if len(m) > 1 && !seen[m[1]] { - seen[m[1]] = true - vars = append(vars, m[1]) - } + tpl := skills.NewPromptTemplate(name, content) + vars := tpl.Variables + if vars == nil { + vars = []string{} } return extensions.PromptTemplate{ - Name: name, - Content: content, + Name: tpl.Name, + Content: tpl.Content, Variables: vars, } } // RenderTemplate substitutes variables into template content. -// Handles {{name}} and {{ name }} (any whitespace) placeholders. +// Handles {{name}} and {{ name }} (any whitespace) placeholders; missing +// variables are left as-is. func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string { - return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string { - sub := varRegex.FindStringSubmatch(m) - if len(sub) > 1 { - if v, ok := vars[sub[1]]; ok { - return v - } - } - return m - }) + t := skills.PromptTemplate{Content: tpl.Content} + return t.Expand(vars) } // ParseArguments parses command-line style arguments. @@ -183,44 +174,12 @@ func SimpleParseArguments(input string, count int) []string { return result } -// parseFields splits input respecting quoted strings. +// parseFields splits input into arguments respecting quoted strings and +// backslash escaping. It delegates to the canonical tokenizer in +// internal/prompts so extension argument parsing and builtin prompt-template +// parsing agree on grammar. func parseFields(input string) []string { - var fields []string - var current strings.Builder - inQuote := false - quoteChar := rune(0) - - for _, r := range input { - switch r { - case '"', '\'': - if !inQuote { - inQuote = true - quoteChar = r - } else if r == quoteChar { - inQuote = false - quoteChar = 0 - } else { - current.WriteRune(r) - } - case ' ', '\t': - if inQuote { - current.WriteRune(r) - } else { - if current.Len() > 0 { - fields = append(fields, current.String()) - current.Reset() - } - } - default: - current.WriteRune(r) - } - } - - if current.Len() > 0 { - fields = append(fields, current.String()) - } - - return fields + return prompts.ParseCommandArgs(input) } // EvaluateModelConditional checks if condition matches current model. @@ -417,21 +376,18 @@ func MatchModelGlob(model, pattern string) bool { } // ExtractProviderFromPath extracts provider from a path-like model string. +// +// Deprecated: Use GetCurrentProvider instead. func ExtractProviderFromPath(model string) string { - parts := strings.Split(model, "/") - if len(parts) >= 2 { - return parts[0] - } - return "" + return GetCurrentProvider(model) } // ExtractModelFromPath extracts model ID from a path-like model string. +// +// Deprecated: Use RemoveProviderFromModel instead, which correctly handles +// model IDs containing "/" (e.g. "openrouter/meta/llama"). func ExtractModelFromPath(model string) string { - parts := strings.Split(model, "/") - if len(parts) >= 2 { - return parts[1] - } - return model + return RemoveProviderFromModel(model) } // IsBareModelID checks if a string is a bare model ID (no provider).