diff --git a/cmd/extension_context.go b/cmd/extension_context.go
index e38002bf..17e5b354 100644
--- a/cmd/extension_context.go
+++ b/cmd/extension_context.go
@@ -4,13 +4,11 @@ import (
"context"
"fmt"
"os"
- "strings"
"github.com/spf13/viper"
"golang.org/x/term"
"github.com/mark3labs/kit/internal/app"
- "github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/extbridge"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
@@ -35,451 +33,276 @@ type extensionContextDeps struct {
// the three print routes appropriately for their phase (startup buffering
// vs. live runtime routing).
//
-// This consolidates two near-identical 400-line literal expressions that
-// previously appeared inline in runNormalMode.
+// The headless half (data access, state, options, tree navigation, skills,
+// templates, model resolution, subagents) comes from extbridge.BaseContext;
+// this function overlays the TUI-specific fields and overrides SetModel /
+// ReloadExtensions with TUI-aware versions.
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
kitInstance := deps.kitInstance
appInstance := deps.appInstance
usageTracker := deps.usageTracker
- ctx := deps.ctx
- return extensions.Context{
- CWD: deps.cwd,
- Model: deps.modelName,
- Interactive: deps.interactive,
- PrintBlock: func(opts extensions.PrintBlockOpts) {
- appInstance.PrintBlockFromExtension(opts)
- },
- SendMessage: func(text string) { appInstance.Run(text) },
- CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
- Abort: func() { appInstance.Abort() },
- IsIdle: func() bool { return !appInstance.IsBusy() },
- Compact: func(cfg extensions.CompactConfig) error {
- return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
- },
- SendMultimodalMessage: func(text string, files []extensions.FilePart) {
- parts := make([]kit.LLMFilePart, len(files))
- for i, f := range files {
- parts[i] = kit.LLMFilePart{
- Filename: f.Filename,
- Data: f.Data,
- MediaType: f.MediaType,
- }
- }
- appInstance.RunWithFiles(text, parts)
- },
- GetSessionUsage: func() extensions.SessionUsage {
- if usageTracker == nil {
- return extensions.SessionUsage{}
- }
- stats := usageTracker.GetSessionStats()
- return extensions.SessionUsage{
- TotalInputTokens: stats.TotalInputTokens,
- TotalOutputTokens: stats.TotalOutputTokens,
- TotalCacheReadTokens: stats.TotalCacheReadTokens,
- TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
- TotalCost: stats.TotalCost,
- RequestCount: stats.RequestCount,
- }
- },
- Exit: func() { appInstance.QuitFromExtension() },
- SetWidget: func(config extensions.WidgetConfig) {
- kitInstance.Extensions().SetWidget(config)
- go appInstance.NotifyWidgetUpdate()
- },
- RemoveWidget: func(id string) {
- kitInstance.Extensions().RemoveWidget(id)
- go appInstance.NotifyWidgetUpdate()
- },
- SetHeader: func(config extensions.HeaderFooterConfig) {
- kitInstance.Extensions().SetHeader(config)
- go appInstance.NotifyWidgetUpdate()
- },
- RemoveHeader: func() {
- kitInstance.Extensions().RemoveHeader()
- go appInstance.NotifyWidgetUpdate()
- },
- SetFooter: func(config extensions.HeaderFooterConfig) {
- kitInstance.Extensions().SetFooter(config)
- go appInstance.NotifyWidgetUpdate()
- },
- RemoveFooter: func() {
- kitInstance.Extensions().RemoveFooter()
- go appInstance.NotifyWidgetUpdate()
- },
- PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
- ch := make(chan app.PromptResponse, 1)
- appInstance.SendPromptRequest(app.PromptRequestEvent{
- PromptType: "select",
- Message: config.Message,
- Options: config.Options,
- ResponseCh: ch,
- })
- resp := <-ch
- if resp.Cancelled {
- return extensions.PromptSelectResult{Cancelled: true}
- }
- return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
- },
- PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
- ch := make(chan app.PromptResponse, 1)
- def := "false"
- if config.DefaultValue {
- def = "true"
- }
- appInstance.SendPromptRequest(app.PromptRequestEvent{
- PromptType: "confirm",
- Message: config.Message,
- Default: def,
- ResponseCh: ch,
- })
- resp := <-ch
- if resp.Cancelled {
- return extensions.PromptConfirmResult{Cancelled: true}
- }
- return extensions.PromptConfirmResult{Value: resp.Confirmed}
- },
- PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
- ch := make(chan app.PromptResponse, 1)
- appInstance.SendPromptRequest(app.PromptRequestEvent{
- PromptType: "input",
- Message: config.Message,
- Placeholder: config.Placeholder,
- Default: config.Default,
- ResponseCh: ch,
- })
- resp := <-ch
- if resp.Cancelled {
- return extensions.PromptInputResult{Cancelled: true}
- }
- return extensions.PromptInputResult{Value: resp.Value}
- },
- SetUIVisibility: func(v extensions.UIVisibility) {
- kitInstance.Extensions().SetUIVisibility(v)
- go appInstance.NotifyWidgetUpdate()
- },
- GetContextStats: func() extensions.ContextStats {
- s := kitInstance.GetContextStats()
- return extensions.ContextStats{
- EstimatedTokens: s.EstimatedTokens,
- ContextLimit: s.ContextLimit,
- UsagePercent: s.UsagePercent,
- MessageCount: s.MessageCount,
- }
- },
- SetEditor: func(config extensions.EditorConfig) {
- kitInstance.Extensions().SetEditor(config)
- // Always use a goroutine for NotifyWidgetUpdate: prog.Send()
- // deadlocks if called synchronously from inside BubbleTea's
- // Update() handler. All call sites use go-routines uniformly.
- go appInstance.NotifyWidgetUpdate()
- },
- ResetEditor: func() {
- kitInstance.Extensions().ResetEditor()
- go appInstance.NotifyWidgetUpdate()
- },
- GetMessages: func() []extensions.SessionMessage {
- return kitInstance.Extensions().GetSessionMessages()
- },
- GetSessionPath: func() string {
- return kitInstance.GetSessionPath()
- },
- AppendEntry: func(entryType string, data string) (string, error) {
- return kitInstance.Extensions().AppendEntry(entryType, data)
- },
- GetEntries: func(entryType string) []extensions.ExtensionEntry {
- return kitInstance.Extensions().GetEntries(entryType)
- },
- SetState: func(key string, value string) {
- kitInstance.Extensions().SetState(key, value)
- },
- GetState: func(key string) (string, bool) {
- return kitInstance.Extensions().GetState(key)
- },
- DeleteState: func(key string) {
- kitInstance.Extensions().DeleteState(key)
- },
- ListState: func() []string {
- return kitInstance.Extensions().ListState()
- },
- SetEditorText: func(text string) {
- appInstance.SetEditorTextFromExtension(text)
- },
- SetStatus: func(key string, text string, priority int) {
- kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
- Key: key,
- Text: text,
- Priority: priority,
- })
- go appInstance.NotifyWidgetUpdate()
- },
- RemoveStatus: func(key string) {
- kitInstance.Extensions().RemoveStatus(key)
- go appInstance.NotifyWidgetUpdate()
- },
- GetOption: func(name string) string {
- return kitInstance.Extensions().GetOption(name)
- },
- SetOption: func(name string, value string) {
- kitInstance.Extensions().SetOption(name, value)
- },
- SetModel: func(modelString string) error {
- // Capture previous model for the ModelChange event.
- previousModel := kitInstance.Extensions().GetContext().Model
- err := kitInstance.SetModel(context.Background(), modelString)
- if err != nil {
- return err
- }
- // Notify TUI so it updates model in status bar.
- p, m, _ := models.ParseModelString(modelString)
- appInstance.NotifyModelChanged(p, m)
- // Update the context's Model field so handlers see it.
- kitInstance.Extensions().UpdateContextModel(modelString)
- // Fire OnModelChange event to extensions.
- kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
- // Update usage tracker with new model info for correct token counting.
- if usageTracker != nil {
- newProvider, newModel, _ := models.ParseModelString(modelString)
- if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
- registry := models.GetGlobalRegistry()
- if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
- // Check OAuth status for Anthropic models
- isOAuth := false
- if newProvider == "anthropic" {
- _, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
- if err == nil && strings.HasPrefix(source, "stored OAuth") {
- isOAuth = true
- }
- }
- usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
- }
- }
- }
- return nil
- },
- GetAvailableModels: func() []extensions.ModelInfoEntry {
- return kitInstance.GetAvailableModels()
- },
- EmitCustomEvent: func(name string, data string) {
- kitInstance.Extensions().EmitCustomEvent(name, data)
- },
- Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
- return kitInstance.ExecuteCompletion(context.Background(), req)
- },
- SuspendTUI: func(callback func()) error {
- return appInstance.SuspendTUI(callback)
- },
- RenderMessage: func(rendererName, content string) {
- renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
- if renderer == nil || renderer.Render == nil {
- appInstance.PrintFromExtension("", content)
- return
- }
- w, _, _ := term.GetSize(int(os.Stdout.Fd()))
- if w == 0 {
- w = 80
- }
- rendered := renderer.Render(content, w)
- appInstance.PrintFromExtension("", rendered)
- },
- ReloadExtensions: func() error {
- err := kitInstance.Extensions().Reload()
- if err != nil {
- return err
- }
- // Notify TUI that widgets/status/commands may have changed.
- go appInstance.NotifyWidgetUpdate()
- return nil
- },
- GetAllTools: func() []extensions.ToolInfo {
- return kitInstance.Extensions().GetToolInfos()
- },
- SetActiveTools: func(names []string) {
- kitInstance.Extensions().SetActiveTools(names)
- },
- RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
- tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
- ui.RegisterThemeFromConfig(name,
- tc(config.Primary), tc(config.Secondary),
- tc(config.Success), tc(config.Warning),
- tc(config.Error), tc(config.Info),
- tc(config.Text), tc(config.Muted),
- tc(config.VeryMuted), tc(config.Background),
- tc(config.Border), tc(config.MutedBorder),
- tc(config.System), tc(config.Tool),
- tc(config.Accent), tc(config.Highlight),
- tc(config.MdHeading), tc(config.MdLink),
- tc(config.MdKeyword), tc(config.MdString),
- tc(config.MdNumber), tc(config.MdComment),
- )
- },
- SetTheme: func(name string) error {
- return ui.ApplyTheme(name)
- },
- ListThemes: func() []string {
- return ui.ListThemes()
- },
- ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
- ch := make(chan app.OverlayResponse, 1)
- appInstance.SendOverlayRequest(app.OverlayRequestEvent{
- Title: config.Title,
- Content: config.Content.Text,
- Markdown: config.Content.Markdown,
- BorderColor: config.Style.BorderColor,
- Background: config.Style.Background,
- Width: config.Width,
- MaxHeight: config.MaxHeight,
- Anchor: string(config.Anchor),
- Actions: config.Actions,
- ResponseCh: ch,
- })
- resp := <-ch
- if resp.Cancelled {
- return extensions.OverlayResult{Cancelled: true, Index: -1}
- }
- return extensions.OverlayResult{
- Action: resp.Action,
- Index: resp.Index,
- }
- },
- SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
- return extbridge.SpawnSubagent(ctx, kitInstance, config)
- },
- // -------------------------------------------------------------------
- // Tree Navigation API
- // -------------------------------------------------------------------
- GetTreeNode: func(entryID string) *extensions.TreeNode {
- node := kitInstance.GetTreeNode(entryID)
- if node == nil {
- return nil
- }
- return &extensions.TreeNode{
- ID: node.ID,
- ParentID: node.ParentID,
- Type: node.Type,
- Role: node.Role,
- Content: node.Content,
- Model: node.Model,
- Provider: node.Provider,
- Timestamp: node.Timestamp,
- Children: node.Children,
- }
- },
- GetCurrentBranch: func() []extensions.TreeNode {
- nodes := kitInstance.GetCurrentBranch()
- result := make([]extensions.TreeNode, len(nodes))
- for i, n := range nodes {
- result[i] = extensions.TreeNode{
- ID: n.ID,
- ParentID: n.ParentID,
- Type: n.Type,
- Role: n.Role,
- Content: n.Content,
- Model: n.Model,
- Provider: n.Provider,
- Timestamp: n.Timestamp,
- Children: n.Children,
- }
- }
- return result
- },
- GetChildren: func(parentID string) []string {
- return kitInstance.GetChildren(parentID)
- },
- NavigateTo: func(entryID string) extensions.TreeNavigationResult {
- err := kitInstance.NavigateTo(entryID)
- if err != nil {
- return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
- }
- return extensions.TreeNavigationResult{Success: true}
- },
- SummarizeBranch: func(fromID, toID string) string {
- summary, _ := kitInstance.SummarizeBranch(fromID, toID)
- return summary
- },
- CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
- err := kitInstance.CollapseBranch(fromID, toID, summary)
- if err != nil {
- return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
- }
- return extensions.TreeNavigationResult{Success: true}
- },
+ ec := extbridge.BaseContext(deps.ctx, kitInstance)
- // -------------------------------------------------------------------
- // Skill Loading API
- // -------------------------------------------------------------------
- LoadSkill: func(path string) (*extensions.Skill, string) {
- s, err := kitInstance.LoadSkillForExtension(path)
- return s, err
- },
- LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
- return kitInstance.LoadSkillsFromDirForExtension(dir)
- },
- DiscoverSkills: func() extensions.SkillLoadResult {
- skills := kitInstance.DiscoverSkillsForExtension()
- return extensions.SkillLoadResult{Skills: skills}
- },
- InjectSkillAsContext: func(skillName string) string {
- skills := kitInstance.DiscoverSkillsForExtension()
- for _, s := range skills {
- if s.Name == skillName {
- appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content))
- return ""
- }
- }
- return fmt.Sprintf("skill not found: %s", skillName)
- },
- InjectRawSkillAsContext: func(path string) string {
- s, err := kitInstance.LoadSkillForExtension(path)
- if err != "" {
- return err
- }
- appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content))
- return ""
- },
- GetAvailableSkills: func() []extensions.Skill {
- return kitInstance.DiscoverSkillsForExtension()
- },
+ ec.CWD = deps.cwd
+ ec.Model = deps.modelName
+ ec.Interactive = deps.interactive
- // -------------------------------------------------------------------
- // Template Parsing API
- // -------------------------------------------------------------------
- ParseTemplate: func(name, content string) extensions.PromptTemplate {
- return kit.ParseTemplate(name, content)
- },
- RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
- return kit.RenderTemplate(tpl, vars)
- },
- ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
- return kit.ParseArguments(input, pattern)
- },
- SimpleParseArguments: func(input string, count int) []string {
- return kit.SimpleParseArguments(input, count)
- },
- EvaluateModelConditional: func(condition string) bool {
- return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
- },
- RenderWithModelConditionals: func(content string) string {
- return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
- },
-
- // -------------------------------------------------------------------
- // Model Resolution API
- // -------------------------------------------------------------------
- ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
- return kit.ResolveModelChain(preferences)
- },
- GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
- return kit.GetModelCapabilities(model)
- },
- CheckModelAvailable: func(model string) bool {
- return kit.CheckModelAvailable(model)
- },
- GetCurrentProvider: func() string {
- return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
- },
- GetCurrentModelID: func() string {
- return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
- },
+ ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
+ appInstance.PrintBlockFromExtension(opts)
}
+ ec.SendMessage = func(text string) { appInstance.Run(text) }
+ ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) }
+ ec.Abort = func() { appInstance.Abort() }
+ ec.IsIdle = func() bool { return !appInstance.IsBusy() }
+ ec.Compact = func(cfg extensions.CompactConfig) error {
+ return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
+ }
+ ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) {
+ parts := make([]kit.LLMFilePart, len(files))
+ for i, f := range files {
+ parts[i] = kit.LLMFilePart{
+ Filename: f.Filename,
+ Data: f.Data,
+ MediaType: f.MediaType,
+ }
+ }
+ appInstance.RunWithFiles(text, parts)
+ }
+ ec.GetSessionUsage = func() extensions.SessionUsage {
+ if usageTracker == nil {
+ return extensions.SessionUsage{}
+ }
+ stats := usageTracker.GetSessionStats()
+ return extensions.SessionUsage{
+ TotalInputTokens: stats.TotalInputTokens,
+ TotalOutputTokens: stats.TotalOutputTokens,
+ TotalCacheReadTokens: stats.TotalCacheReadTokens,
+ TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
+ TotalCost: stats.TotalCost,
+ RequestCount: stats.RequestCount,
+ }
+ }
+ ec.Exit = func() { appInstance.QuitFromExtension() }
+
+ // TUI widgets/chrome — mutate runner state, then notify the TUI.
+ // Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks
+ // if called synchronously from inside BubbleTea's Update() handler.
+ // All call sites use go-routines uniformly.
+ ec.SetWidget = func(config extensions.WidgetConfig) {
+ kitInstance.Extensions().SetWidget(config)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.RemoveWidget = func(id string) {
+ kitInstance.Extensions().RemoveWidget(id)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.SetHeader = func(config extensions.HeaderFooterConfig) {
+ kitInstance.Extensions().SetHeader(config)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.RemoveHeader = func() {
+ kitInstance.Extensions().RemoveHeader()
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.SetFooter = func(config extensions.HeaderFooterConfig) {
+ kitInstance.Extensions().SetFooter(config)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.RemoveFooter = func() {
+ kitInstance.Extensions().RemoveFooter()
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.SetUIVisibility = func(v extensions.UIVisibility) {
+ kitInstance.Extensions().SetUIVisibility(v)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.SetEditor = func(config extensions.EditorConfig) {
+ kitInstance.Extensions().SetEditor(config)
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.ResetEditor = func() {
+ kitInstance.Extensions().ResetEditor()
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.SetEditorText = func(text string) {
+ appInstance.SetEditorTextFromExtension(text)
+ }
+ ec.SetStatus = func(key string, text string, priority int) {
+ kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
+ Key: key,
+ Text: text,
+ Priority: priority,
+ })
+ go appInstance.NotifyWidgetUpdate()
+ }
+ ec.RemoveStatus = func(key string) {
+ kitInstance.Extensions().RemoveStatus(key)
+ go appInstance.NotifyWidgetUpdate()
+ }
+
+ // Interactive prompts — channel-based round trips through the TUI.
+ ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
+ ch := make(chan app.PromptResponse, 1)
+ appInstance.SendPromptRequest(app.PromptRequestEvent{
+ PromptType: "select",
+ Message: config.Message,
+ Options: config.Options,
+ ResponseCh: ch,
+ })
+ resp := <-ch
+ if resp.Cancelled {
+ return extensions.PromptSelectResult{Cancelled: true}
+ }
+ return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
+ }
+ ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
+ ch := make(chan app.PromptResponse, 1)
+ def := "false"
+ if config.DefaultValue {
+ def = "true"
+ }
+ appInstance.SendPromptRequest(app.PromptRequestEvent{
+ PromptType: "confirm",
+ Message: config.Message,
+ Default: def,
+ ResponseCh: ch,
+ })
+ resp := <-ch
+ if resp.Cancelled {
+ return extensions.PromptConfirmResult{Cancelled: true}
+ }
+ return extensions.PromptConfirmResult{Value: resp.Confirmed}
+ }
+ ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult {
+ ch := make(chan app.PromptResponse, 1)
+ appInstance.SendPromptRequest(app.PromptRequestEvent{
+ PromptType: "input",
+ Message: config.Message,
+ Placeholder: config.Placeholder,
+ Default: config.Default,
+ ResponseCh: ch,
+ })
+ resp := <-ch
+ if resp.Cancelled {
+ return extensions.PromptInputResult{Cancelled: true}
+ }
+ return extensions.PromptInputResult{Value: resp.Value}
+ }
+ ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult {
+ ch := make(chan app.OverlayResponse, 1)
+ appInstance.SendOverlayRequest(app.OverlayRequestEvent{
+ Title: config.Title,
+ Content: config.Content.Text,
+ Markdown: config.Content.Markdown,
+ BorderColor: config.Style.BorderColor,
+ Background: config.Style.Background,
+ Width: config.Width,
+ MaxHeight: config.MaxHeight,
+ Anchor: string(config.Anchor),
+ Actions: config.Actions,
+ ResponseCh: ch,
+ })
+ resp := <-ch
+ if resp.Cancelled {
+ return extensions.OverlayResult{Cancelled: true, Index: -1}
+ }
+ return extensions.OverlayResult{
+ Action: resp.Action,
+ Index: resp.Index,
+ }
+ }
+ ec.SuspendTUI = func(callback func()) error {
+ return appInstance.SuspendTUI(callback)
+ }
+
+ // TUI-aware model switch: also notifies the TUI status bar and
+ // refreshes the usage tracker for correct token counting.
+ ec.SetModel = func(modelString string) error {
+ // Capture previous model for the ModelChange event.
+ previousModel := kitInstance.Extensions().GetContext().Model
+ err := kitInstance.SetModel(context.Background(), modelString)
+ if err != nil {
+ return err
+ }
+ // Notify TUI so it updates model in status bar.
+ p, m, _ := models.ParseModelString(modelString)
+ appInstance.NotifyModelChanged(p, m)
+ // Update the context's Model field so handlers see it.
+ kitInstance.Extensions().UpdateContextModel(modelString)
+ // Fire OnModelChange event to extensions.
+ kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
+ // Update usage tracker with new model info for correct token counting.
+ ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
+ return nil
+ }
+
+ ec.RenderMessage = func(rendererName, content string) {
+ renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
+ if renderer == nil || renderer.Render == nil {
+ appInstance.PrintFromExtension("", content)
+ return
+ }
+ w, _, _ := term.GetSize(int(os.Stdout.Fd()))
+ if w == 0 {
+ w = 80
+ }
+ rendered := renderer.Render(content, w)
+ appInstance.PrintFromExtension("", rendered)
+ }
+ ec.ReloadExtensions = func() error {
+ err := kitInstance.Extensions().Reload()
+ if err != nil {
+ return err
+ }
+ // Notify TUI that widgets/status/commands may have changed.
+ go appInstance.NotifyWidgetUpdate()
+ return nil
+ }
+
+ // Theme management (TUI only).
+ ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) {
+ tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
+ ui.RegisterThemeFromConfig(name,
+ tc(config.Primary), tc(config.Secondary),
+ tc(config.Success), tc(config.Warning),
+ tc(config.Error), tc(config.Info),
+ tc(config.Text), tc(config.Muted),
+ tc(config.VeryMuted), tc(config.Background),
+ tc(config.Border), tc(config.MutedBorder),
+ tc(config.System), tc(config.Tool),
+ tc(config.Accent), tc(config.Highlight),
+ tc(config.MdHeading), tc(config.MdLink),
+ tc(config.MdKeyword), tc(config.MdString),
+ tc(config.MdNumber), tc(config.MdComment),
+ )
+ }
+ ec.SetTheme = func(name string) error {
+ return ui.ApplyTheme(name)
+ }
+ ec.ListThemes = func() []string {
+ return ui.ListThemes()
+ }
+
+ // Skill context-injection (drives a new agent turn through the TUI).
+ ec.InjectSkillAsContext = func(skillName string) string {
+ skills := kitInstance.DiscoverSkillsForExtension()
+ for _, s := range skills {
+ if s.Name == skillName {
+ appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content))
+ return ""
+ }
+ }
+ return fmt.Sprintf("skill not found: %s", skillName)
+ }
+ ec.InjectRawSkillAsContext = func(path string) string {
+ s, err := kitInstance.LoadSkillForExtension(path)
+ if err != "" {
+ return err
+ }
+ appInstance.Run(fmt.Sprintf("\n%s\n", s.Name, s.Content))
+ return ""
+ }
+
+ return ec
}
diff --git a/cmd/root.go b/cmd/root.go
index 99d81c63..37753239 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -12,7 +12,6 @@ import (
tea "charm.land/bubbletea/v2"
"github.com/mark3labs/kit/internal/app"
- "github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
@@ -677,8 +676,8 @@ func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
}
}
-func runNormalMode(ctx context.Context) error {
- // Validate flag combinations
+// validateModeFlags rejects invalid flag combinations for the root command.
+func validateModeFlags() error {
if quietFlag && positionalPrompt == "" {
return fmt.Errorf("--quiet requires a prompt (e.g. kit \"your question\" --quiet)")
}
@@ -691,21 +690,14 @@ func runNormalMode(ctx context.Context) error {
if noExitFlag && positionalPrompt == "" {
return fmt.Errorf("--no-exit requires a prompt (e.g. kit \"your question\" --no-exit)")
}
+ return nil
+}
- // Set up logging
- if debugMode {
- log.SetFlags(log.LstdFlags | log.Lshortfile)
- }
-
- // Update debug mode from viper
- if viper.GetBool("debug") && !debugMode {
- debugMode = viper.GetBool("debug")
- log.SetFlags(log.LstdFlags | log.Lshortfile)
- }
-
- // Restore persisted model preference when no explicit --model flag or
- // config file model is set. Precedence: CLI flag > config file > saved
- // preference > built-in default. This mirrors how themes are persisted.
+// restorePersistedPreferences applies saved model / thinking-level
+// preferences into viper when neither a CLI flag nor a config-file value
+// takes precedence. Precedence: CLI flag > config file > saved preference >
+// built-in default. This mirrors how themes are persisted.
+func restorePersistedPreferences() {
// Skip custom/* models unless --provider-url is also provided, since the
// custom provider requires a URL that was only valid for the previous session.
if !modelFlagChanged && !viper.InConfig("model") {
@@ -724,6 +716,15 @@ func runNormalMode(ctx context.Context) error {
viper.Set("thinking-level", pref)
}
}
+}
+
+// applyProviderURLRouting rewrites the model in viper when --provider-url
+// is set, routing requests through the "custom" (OpenAI-compatible)
+// provider. Must run after restorePersistedPreferences.
+func applyProviderURLRouting() {
+ if viper.GetString("provider-url") == "" {
+ return
+ }
// When --provider-url is set but no explicit --model was provided,
// default to "custom/custom" so the user doesn't need to remember a
@@ -731,7 +732,7 @@ func runNormalMode(ctx context.Context) error {
// This intentionally overrides saved preferences but respects config-file
// models — if you specify a model in ~/.kit.yml, it will be used with
// custom/custom's provider routing.
- if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") {
+ if !modelFlagChanged && !viper.InConfig("model") {
viper.Set("model", "custom/custom")
}
@@ -746,7 +747,7 @@ func runNormalMode(ctx context.Context) error {
// to point a non-OpenAI wire (Anthropic, Google, ...) at a proxy URL,
// use the explicit `custom/` form to opt out of the rewrite by
// configuring the proxy as that provider in your config file instead.
- if viper.GetString("provider-url") != "" && modelFlagChanged {
+ if modelFlagChanged {
model := viper.GetString("model")
if model != "" {
name := model
@@ -758,6 +759,26 @@ func runNormalMode(ctx context.Context) error {
}
}
}
+}
+
+func runNormalMode(ctx context.Context) error {
+ if err := validateModeFlags(); err != nil {
+ return err
+ }
+
+ // Set up logging
+ if debugMode {
+ log.SetFlags(log.LstdFlags | log.Lshortfile)
+ }
+
+ // Update debug mode from viper
+ if viper.GetBool("debug") && !debugMode {
+ debugMode = viper.GetBool("debug")
+ log.SetFlags(log.LstdFlags | log.Lshortfile)
+ }
+
+ restorePersistedPreferences()
+ applyProviderURLRouting()
// Load MCP configuration.
mcpConfig, err := config.LoadAndValidateConfig()
@@ -1164,23 +1185,7 @@ func runNormalMode(ctx context.Context) error {
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
// updates m.providerName and m.modelName directly after setModel returns.
// Update usage tracker with new model info for correct token counting.
- if usageTracker != nil {
- newProvider, newModel, _ := models.ParseModelString(modelString)
- if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
- registry := models.GetGlobalRegistry()
- if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
- // Check OAuth status for Anthropic models
- isOAuth := false
- if newProvider == "anthropic" {
- _, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
- if err == nil && strings.HasPrefix(source, "stored OAuth") {
- isOAuth = true
- }
- }
- usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
- }
- }
- }
+ ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
return nil
}
emitModelChangeForUI := func(newModel, previousModel, source string) {
diff --git a/internal/acpserver/session.go b/internal/acpserver/session.go
index 937cc6ad..5945ef16 100644
--- a/internal/acpserver/session.go
+++ b/internal/acpserver/session.go
@@ -73,111 +73,70 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
// Wire extension context with headless implementations so extensions
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
- // become no-ops or return cancelled; all data/model/tool APIs work
- // identically to interactive mode.
+ // become no-ops or return cancelled; all data/model/tool APIs come from
+ // extbridge.BaseContext and work identically to interactive mode.
if kitInstance.Extensions().HasExtensions() {
- kitInstance.Extensions().SetContext(extensions.Context{
- SessionID: sessionID,
- CWD: cwd,
- Model: kitInstance.GetModelString(),
- Interactive: false,
+ // Use a background context for subagent spawns: the create() ctx is
+ // request-scoped and may be cancelled before extensions spawn anything.
+ ec := extbridge.BaseContext(context.Background(), kitInstance)
- // Output — route through structured logger.
- Print: func(text string) { log.Debug("extension: print", "text", text) },
- PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
- PrintError: func(text string) { log.Error("extension: error", "text", text) },
- PrintBlock: func(opts extensions.PrintBlockOpts) {
- log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
- },
+ ec.SessionID = sessionID
+ ec.CWD = cwd
+ ec.Model = kitInstance.GetModelString()
+ ec.Interactive = false
- // Message injection — no-ops for now; ACP clients drive prompts.
- SendMessage: func(string) {},
- CancelAndSend: func(string) {},
- Exit: func() {},
+ // Output — route through structured logger.
+ ec.Print = func(text string) { log.Debug("extension: print", "text", text) }
+ ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) }
+ ec.PrintError = func(text string) { log.Error("extension: error", "text", text) }
+ ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
+ log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
+ }
- // TUI widgets/chrome — silent no-ops (no TUI in ACP).
- SetWidget: func(extensions.WidgetConfig) {},
- RemoveWidget: func(string) {},
- SetHeader: func(extensions.HeaderFooterConfig) {},
- RemoveHeader: func() {},
- SetFooter: func(extensions.HeaderFooterConfig) {},
- RemoveFooter: func() {},
- SetEditor: func(extensions.EditorConfig) {},
- ResetEditor: func() {},
- SetEditorText: func(string) {},
- SetUIVisibility: func(extensions.UIVisibility) {},
- SetStatus: func(string, string, int) {},
- RemoveStatus: func(string) {},
+ // Message injection — no-ops for now; ACP clients drive prompts.
+ ec.SendMessage = func(string) {}
+ ec.CancelAndSend = func(string) {}
+ ec.Exit = func() {}
- // Interactive prompts — return cancelled (no user to prompt).
- PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
- return extensions.PromptSelectResult{Cancelled: true}
- },
- PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
- return extensions.PromptConfirmResult{Cancelled: true}
- },
- PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
- return extensions.PromptInputResult{Cancelled: true}
- },
- ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
- return extensions.OverlayResult{Cancelled: true, Index: -1}
- },
- SuspendTUI: func(callback func()) error { callback(); return nil },
+ // TUI widgets/chrome — silent no-ops (no TUI in ACP).
+ ec.SetWidget = func(extensions.WidgetConfig) {}
+ ec.RemoveWidget = func(string) {}
+ ec.SetHeader = func(extensions.HeaderFooterConfig) {}
+ ec.RemoveHeader = func() {}
+ ec.SetFooter = func(extensions.HeaderFooterConfig) {}
+ ec.RemoveFooter = func() {}
+ ec.SetEditor = func(extensions.EditorConfig) {}
+ ec.ResetEditor = func() {}
+ ec.SetEditorText = func(string) {}
+ ec.SetUIVisibility = func(extensions.UIVisibility) {}
+ ec.SetStatus = func(string, string, int) {}
+ ec.RemoveStatus = func(string) {}
- // Data access — delegate to Kit instance.
- GetContextStats: func() extensions.ContextStats {
- s := kitInstance.GetContextStats()
- return extensions.ContextStats{
- EstimatedTokens: s.EstimatedTokens,
- ContextLimit: s.ContextLimit,
- UsagePercent: s.UsagePercent,
- MessageCount: s.MessageCount,
- }
- },
- GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
- GetSessionPath: func() string { return kitInstance.GetSessionPath() },
- AppendEntry: func(entryType, data string) (string, error) {
- return kitInstance.Extensions().AppendEntry(entryType, data)
- },
- GetEntries: func(entryType string) []extensions.ExtensionEntry {
- return kitInstance.Extensions().GetEntries(entryType)
- },
+ // Interactive prompts — return cancelled (no user to prompt).
+ ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
+ return extensions.PromptSelectResult{Cancelled: true}
+ }
+ ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
+ return extensions.PromptConfirmResult{Cancelled: true}
+ }
+ ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult {
+ return extensions.PromptInputResult{Cancelled: true}
+ }
+ ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult {
+ return extensions.OverlayResult{Cancelled: true, Index: -1}
+ }
+ ec.SuspendTUI = func(callback func()) error { callback(); return nil }
- // Options, model, and tool management.
- GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
- SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
- SetModel: func(modelString string) error {
- previousModel := kitInstance.Extensions().GetContext().Model
- if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
- return err
- }
- kitInstance.Extensions().UpdateContextModel(modelString)
- kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
- return nil
- },
- GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
- EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
- GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
- SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
+ // Render — fall back to logging.
+ ec.RenderMessage = func(name, content string) {
+ renderer := kitInstance.Extensions().GetMessageRenderer(name)
+ if renderer != nil && renderer.Render != nil {
+ content = renderer.Render(content, 80)
+ }
+ log.Info("extension: message", "renderer", name, "content", content)
+ }
- // LLM completions and subagents.
- Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
- return kitInstance.ExecuteCompletion(context.Background(), req)
- },
- SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
- return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
- },
-
- // Render — fall back to logging.
- RenderMessage: func(name, content string) {
- renderer := kitInstance.Extensions().GetMessageRenderer(name)
- if renderer != nil && renderer.Render != nil {
- content = renderer.Render(content, 80)
- }
- log.Info("extension: message", "renderer", name, "content", content)
- },
- ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
- })
+ kitInstance.Extensions().SetContext(ec)
kitInstance.Extensions().EmitSessionStart()
}
diff --git a/internal/app/messages.go b/internal/app/messages.go
index 0ec94d0e..4eaae954 100644
--- a/internal/app/messages.go
+++ b/internal/app/messages.go
@@ -13,11 +13,6 @@ type MessageStore struct {
messages []kit.LLMMessage
}
-// NewMessageStore creates an empty MessageStore.
-func NewMessageStore() *MessageStore {
- return &MessageStore{}
-}
-
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
// given messages. This is used when loading an existing session at startup.
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
diff --git a/internal/app/messages_test.go b/internal/app/messages_test.go
index 8c4ce598..c5339eb3 100644
--- a/internal/app/messages_test.go
+++ b/internal/app/messages_test.go
@@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string {
// --------------------------------------------------------------------------
func TestNewMessageStore_empty(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
if s == nil {
t.Fatal("expected non-nil store")
}
@@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
// --------------------------------------------------------------------------
func TestAdd_appendsMessage(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "first"))
s.Add(makeTextMsg("assistant", "second"))
@@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) {
}
func TestAdd_preservesOrder(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
texts := []string{"a", "b", "c"}
for _, t2 := range texts {
s.Add(makeTextMsg("user", t2))
@@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) {
// --------------------------------------------------------------------------
func TestReplace_swapsHistory(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "old"))
replacement := []kit.LLMMessage{
@@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) {
// Replace must deep-copy the incoming slice.
func TestReplace_isolatesInput(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
s.Replace(replacement)
@@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) {
// --------------------------------------------------------------------------
func TestGetAll_returnsCopy(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "hello"))
got := s.GetAll()
@@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) {
}
func TestGetAll_emptyStore(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
got := s.GetAll()
if len(got) != 0 {
t.Fatalf("expected empty slice, got %d elements", len(got))
@@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) {
// --------------------------------------------------------------------------
func TestClear_removesAllMessages(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "a"))
s.Add(makeTextMsg("user", "b"))
s.Clear()
@@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) {
}
func TestClear_allowsSubsequentAdds(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
s.Add(makeTextMsg("user", "before"))
s.Clear()
s.Add(makeTextMsg("user", "after"))
@@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
// --------------------------------------------------------------------------
func TestConcurrentAccess(t *testing.T) {
- s := NewMessageStore()
+ s := NewMessageStoreWithMessages(nil)
done := make(chan struct{})
// Writer goroutine.
diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go
index 44d95ff4..92a62d9a 100644
--- a/internal/auth/credentials.go
+++ b/internal/auth/credentials.go
@@ -513,6 +513,20 @@ func validateAnthropicAPIKey(apiKey string) error {
return nil
}
+// CredentialSourceOAuth is the source description returned by
+// GetAnthropicAPIKey when the key resolves to stored OAuth credentials.
+// Consumers should compare against this constant (or use IsAnthropicOAuth)
+// rather than matching the string literal.
+const CredentialSourceOAuth = "stored OAuth credentials"
+
+// IsAnthropicOAuth reports whether the active Anthropic credential resolves
+// to a stored OAuth token (in which case the user is not billed per-token).
+// flagValue is the --provider-api-key flag value (may be empty).
+func IsAnthropicOAuth(flagValue string) bool {
+ _, source, err := GetAnthropicAPIKey(flagValue)
+ return err == nil && source == CredentialSourceOAuth
+}
+
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
// 1. Command-line flag value (highest priority)
// 2. Stored credentials (OAuth or API key)
@@ -535,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
if err != nil {
return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err)
}
- return token, "stored OAuth credentials", nil
+ return token, CredentialSourceOAuth, nil
} else if creds.Type == "api_key" && creds.APIKey != "" {
return creds.APIKey, "stored API key", nil
}
diff --git a/internal/config/substitution.go b/internal/config/substitution.go
index b4b143cd..71f06ada 100644
--- a/internal/config/substitution.go
+++ b/internal/config/substitution.go
@@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
return result, nil
}
-
-// HasEnvVars checks if content contains environment variable patterns (${env://...}).
-// This is useful for determining if substitution is needed before processing.
-func HasEnvVars(content string) bool {
- return envVarPattern.MatchString(content)
-}
diff --git a/internal/config/substitution_test.go b/internal/config/substitution_test.go
index 10d2eb69..5445672d 100644
--- a/internal/config/substitution_test.go
+++ b/internal/config/substitution_test.go
@@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) {
})
}
}
-
-func TestHasEnvVars(t *testing.T) {
- tests := []struct {
- name string
- content string
- expected bool
- }{
- {
- name: "has env vars",
- content: `{"token": "${env://GITHUB_TOKEN}"}`,
- expected: true,
- },
- {
- name: "has env vars with default",
- content: `{"debug": "${env://DEBUG:-false}"}`,
- expected: true,
- },
- {
- name: "no env vars",
- content: `{"name": "${username}", "normal": "value"}`,
- expected: false,
- },
- {
- name: "empty content",
- content: "",
- expected: false,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result := HasEnvVars(tt.content)
- if result != tt.expected {
- t.Errorf("Expected %v, got %v", tt.expected, result)
- }
- })
- }
-}
diff --git a/internal/core/bash.go b/internal/core/bash.go
index a493f06c..3428fb3e 100644
--- a/internal/core/bash.go
+++ b/internal/core/bash.go
@@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
return nil
}
-// ContextWithSudoPassword returns a new context with the sudo password set.
-// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
-func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
- return context.WithValue(ctx, sudoPasswordKey, password)
-}
-
// sudoPasswordFromContext retrieves the sudo password from context.
func sudoPasswordFromContext(ctx context.Context) string {
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
diff --git a/internal/core/bash_test.go b/internal/core/bash_test.go
index f6ded4a2..5727e648 100644
--- a/internal/core/bash_test.go
+++ b/internal/core/bash_test.go
@@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) {
func TestSudoPasswordFromContext(t *testing.T) {
// Test with password in context
- ctx := ContextWithSudoPassword(context.Background(), "secret123")
+ ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123")
pw := sudoPasswordFromContext(ctx)
if pw != "secret123" {
t.Errorf("expected password 'secret123', got %q", pw)
diff --git a/internal/extbridge/context.go b/internal/extbridge/context.go
new file mode 100644
index 00000000..6f035a47
--- /dev/null
+++ b/internal/extbridge/context.go
@@ -0,0 +1,234 @@
+package extbridge
+
+import (
+ "context"
+
+ "github.com/mark3labs/kit/internal/extensions"
+ kit "github.com/mark3labs/kit/pkg/kit"
+)
+
+// BaseContext returns an extensions.Context populated with the headless,
+// TUI-independent delegation fields: data access, state, options,
+// model/tool management, completions, subagents, tree navigation, skills,
+// template parsing, and model resolution.
+//
+// Callers overlay their UI-specific fields (print routes, widgets, prompts,
+// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value:
+// cmd/extension_context.go for the interactive TUI and
+// internal/acpserver/session.go for headless ACP mode. Keeping the shared
+// half here means a new data-access Context field only has to be wired once.
+//
+// ctx is used for subagent spawns; pass a long-lived context (not a
+// per-request one) so later spawns aren't cancelled prematurely.
+func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context {
+ return extensions.Context{
+ // -------------------------------------------------------------------
+ // Data access
+ // -------------------------------------------------------------------
+ GetContextStats: func() extensions.ContextStats {
+ s := kitInstance.GetContextStats()
+ return extensions.ContextStats{
+ EstimatedTokens: s.EstimatedTokens,
+ ContextLimit: s.ContextLimit,
+ UsagePercent: s.UsagePercent,
+ MessageCount: s.MessageCount,
+ }
+ },
+ GetMessages: func() []extensions.SessionMessage {
+ return kitInstance.Extensions().GetSessionMessages()
+ },
+ GetSessionPath: func() string {
+ return kitInstance.GetSessionPath()
+ },
+ AppendEntry: func(entryType string, data string) (string, error) {
+ return kitInstance.Extensions().AppendEntry(entryType, data)
+ },
+ GetEntries: func(entryType string) []extensions.ExtensionEntry {
+ return kitInstance.Extensions().GetEntries(entryType)
+ },
+
+ // -------------------------------------------------------------------
+ // Extension state
+ // -------------------------------------------------------------------
+ SetState: func(key string, value string) {
+ kitInstance.Extensions().SetState(key, value)
+ },
+ GetState: func(key string) (string, bool) {
+ return kitInstance.Extensions().GetState(key)
+ },
+ DeleteState: func(key string) {
+ kitInstance.Extensions().DeleteState(key)
+ },
+ ListState: func() []string {
+ return kitInstance.Extensions().ListState()
+ },
+
+ // -------------------------------------------------------------------
+ // Options, model, and tool management
+ // -------------------------------------------------------------------
+ GetOption: func(name string) string {
+ return kitInstance.Extensions().GetOption(name)
+ },
+ SetOption: func(name string, value string) {
+ kitInstance.Extensions().SetOption(name, value)
+ },
+ // Headless model switch. The interactive TUI overrides this with a
+ // version that also notifies the TUI and refreshes the usage tracker.
+ SetModel: func(modelString string) error {
+ previousModel := kitInstance.Extensions().GetContext().Model
+ if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
+ return err
+ }
+ kitInstance.Extensions().UpdateContextModel(modelString)
+ kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
+ return nil
+ },
+ GetAvailableModels: func() []extensions.ModelInfoEntry {
+ return kitInstance.GetAvailableModels()
+ },
+ EmitCustomEvent: func(name string, data string) {
+ kitInstance.Extensions().EmitCustomEvent(name, data)
+ },
+ GetAllTools: func() []extensions.ToolInfo {
+ return kitInstance.Extensions().GetToolInfos()
+ },
+ SetActiveTools: func(names []string) {
+ kitInstance.Extensions().SetActiveTools(names)
+ },
+ // Headless reload. The interactive TUI overrides this to also
+ // refresh widgets/status/commands.
+ ReloadExtensions: func() error {
+ return kitInstance.Extensions().Reload()
+ },
+
+ // -------------------------------------------------------------------
+ // LLM completions and subagents
+ // -------------------------------------------------------------------
+ Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
+ return kitInstance.ExecuteCompletion(context.Background(), req)
+ },
+ SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
+ return SpawnSubagent(ctx, kitInstance, config)
+ },
+
+ // -------------------------------------------------------------------
+ // Tree Navigation API
+ // -------------------------------------------------------------------
+ GetTreeNode: func(entryID string) *extensions.TreeNode {
+ node := kitInstance.GetTreeNode(entryID)
+ if node == nil {
+ return nil
+ }
+ return &extensions.TreeNode{
+ ID: node.ID,
+ ParentID: node.ParentID,
+ Type: node.Type,
+ Role: node.Role,
+ Content: node.Content,
+ Model: node.Model,
+ Provider: node.Provider,
+ Timestamp: node.Timestamp,
+ Children: node.Children,
+ }
+ },
+ GetCurrentBranch: func() []extensions.TreeNode {
+ nodes := kitInstance.GetCurrentBranch()
+ result := make([]extensions.TreeNode, len(nodes))
+ for i, n := range nodes {
+ result[i] = extensions.TreeNode{
+ ID: n.ID,
+ ParentID: n.ParentID,
+ Type: n.Type,
+ Role: n.Role,
+ Content: n.Content,
+ Model: n.Model,
+ Provider: n.Provider,
+ Timestamp: n.Timestamp,
+ Children: n.Children,
+ }
+ }
+ return result
+ },
+ GetChildren: func(parentID string) []string {
+ return kitInstance.GetChildren(parentID)
+ },
+ NavigateTo: func(entryID string) extensions.TreeNavigationResult {
+ err := kitInstance.NavigateTo(entryID)
+ if err != nil {
+ return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
+ }
+ return extensions.TreeNavigationResult{Success: true}
+ },
+ SummarizeBranch: func(fromID, toID string) string {
+ summary, _ := kitInstance.SummarizeBranch(fromID, toID)
+ return summary
+ },
+ CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
+ err := kitInstance.CollapseBranch(fromID, toID, summary)
+ if err != nil {
+ return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
+ }
+ return extensions.TreeNavigationResult{Success: true}
+ },
+
+ // -------------------------------------------------------------------
+ // Skill Loading API (context-injection variants are TUI-specific and
+ // wired by the interactive overlay)
+ // -------------------------------------------------------------------
+ LoadSkill: func(path string) (*extensions.Skill, string) {
+ s, err := kitInstance.LoadSkillForExtension(path)
+ return s, err
+ },
+ LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
+ return kitInstance.LoadSkillsFromDirForExtension(dir)
+ },
+ DiscoverSkills: func() extensions.SkillLoadResult {
+ skills := kitInstance.DiscoverSkillsForExtension()
+ return extensions.SkillLoadResult{Skills: skills}
+ },
+ GetAvailableSkills: func() []extensions.Skill {
+ return kitInstance.DiscoverSkillsForExtension()
+ },
+
+ // -------------------------------------------------------------------
+ // Template Parsing API
+ // -------------------------------------------------------------------
+ ParseTemplate: func(name, content string) extensions.PromptTemplate {
+ return kit.ParseTemplate(name, content)
+ },
+ RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
+ return kit.RenderTemplate(tpl, vars)
+ },
+ ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
+ return kit.ParseArguments(input, pattern)
+ },
+ SimpleParseArguments: func(input string, count int) []string {
+ return kit.SimpleParseArguments(input, count)
+ },
+ EvaluateModelConditional: func(condition string) bool {
+ return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
+ },
+ RenderWithModelConditionals: func(content string) string {
+ return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
+ },
+
+ // -------------------------------------------------------------------
+ // Model Resolution API
+ // -------------------------------------------------------------------
+ ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
+ return kit.ResolveModelChain(preferences)
+ },
+ GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
+ return kit.GetModelCapabilities(model)
+ },
+ CheckModelAvailable: func(model string) bool {
+ return kit.CheckModelAvailable(model)
+ },
+ GetCurrentProvider: func() string {
+ return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
+ },
+ GetCurrentModelID: func() string {
+ return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
+ },
+ }
+}
diff --git a/internal/extensions/toolkinds.go b/internal/extensions/toolkinds.go
new file mode 100644
index 00000000..0ac4395b
--- /dev/null
+++ b/internal/extensions/toolkinds.go
@@ -0,0 +1,38 @@
+package extensions
+
+// ToolKind constants classify what a tool does, enabling UIs to render
+// appropriate visualizations (e.g. diff view for edit tools, command+output
+// for execute tools) and file trackers to identify which results contain
+// modifications.
+//
+// This is the single source of truth for tool-kind classification; the
+// pkg/kit SDK re-exports these constants.
+const (
+ ToolKindExecute = "execute" // Shell execution (bash)
+ ToolKindEdit = "edit" // File modification (edit, write)
+ ToolKindRead = "read" // File reading (read, ls)
+ ToolKindSearch = "search" // Content/file search (grep, find)
+ ToolKindSubagent = "agent" // Subagent spawning (subagent)
+)
+
+// coreToolKinds maps built-in tool names to their kind classification.
+// MCP and extension tools without an entry default to ToolKindExecute.
+var coreToolKinds = map[string]string{
+ "bash": ToolKindExecute,
+ "edit": ToolKindEdit,
+ "write": ToolKindEdit,
+ "read": ToolKindRead,
+ "ls": ToolKindRead,
+ "grep": ToolKindSearch,
+ "find": ToolKindSearch,
+ "subagent": ToolKindSubagent,
+}
+
+// ToolKindFor returns the ToolKind for a given tool name, defaulting to
+// ToolKindExecute for unknown tools (including MCP tools).
+func ToolKindFor(toolName string) string {
+ if kind, ok := coreToolKinds[toolName]; ok {
+ return kind
+ }
+ return ToolKindExecute
+}
diff --git a/internal/extensions/wrapper.go b/internal/extensions/wrapper.go
index 6af25692..8119e4c7 100644
--- a/internal/extensions/wrapper.go
+++ b/internal/extensions/wrapper.go
@@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo
return tools
}
-// coreToolKinds maps built-in tool names to their kind classification.
-var coreToolKinds = map[string]string{
- "bash": "execute",
- "edit": "edit",
- "write": "edit",
- "read": "read",
- "ls": "read",
- "grep": "search",
- "find": "search",
- "subagent": "agent",
-}
-
-// toolKindFor returns the ToolKind for a given tool name, defaulting to
-// "execute" for unknown tools (including MCP tools).
-func toolKindFor(toolName string) string {
- if kind, ok := coreToolKinds[toolName]; ok {
- return kind
- }
- return "execute"
-}
-
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
// Returns nil on failure (non-fatal convenience parsing).
func parseToolArgsJSON(input string) map[string]any {
@@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
}
- kind := toolKindFor(toolName)
+ kind := ToolKindFor(toolName)
// 1. Emit ToolCall — extensions can block execution.
if w.runner.HasHandlers(ToolCall) {
diff --git a/internal/models/custom.go b/internal/models/custom.go
index cce5f5a8..b384312a 100644
--- a/internal/models/custom.go
+++ b/internal/models/custom.go
@@ -69,13 +69,6 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
return info
}
-// LoadModelSettingsFromConfig loads per-model generation parameter overrides
-// from the process-global viper store. Keys are "provider/model" strings.
-// Returns nil if no model settings are configured.
-func LoadModelSettingsFromConfig() map[string]*GenerationParams {
- return LoadModelSettingsFrom(viper.GetViper())
-}
-
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
// supplied per-instance store. When v is nil the process-global store is used.
// Keys are "provider/model" strings. Returns nil if no model settings are
diff --git a/internal/models/providers.go b/internal/models/providers.go
index 6995f6dd..c9a16807 100644
--- a/internal/models/providers.go
+++ b/internal/models/providers.go
@@ -932,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
}
// Handle OAuth vs API key authentication
- if strings.HasPrefix(source, "stored OAuth") {
+ if source == auth.CredentialSourceOAuth {
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
opts = append(opts, anthropic.WithHTTPClient(httpClient))
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
diff --git a/internal/prompts/template.go b/internal/prompts/template.go
index cb9eb311..87810a2f 100644
--- a/internal/prompts/template.go
+++ b/internal/prompts/template.go
@@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) {
}
// ParseCommandArgs splits a command line into arguments respecting quotes.
-// It handles single quotes, double quotes, and backslash escaping.
+// It handles single quotes, double quotes, backslash escaping, and splits on
+// spaces and tabs.
func ParseCommandArgs(input string) []string {
var args []string
var current strings.Builder
@@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string {
inDoubleQuote := false
escaped := false
- for i, r := range input {
+ for _, r := range input {
if escaped {
current.WriteRune(r)
escaped = false
@@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string {
continue
}
- if r == ' ' && !inSingleQuote && !inDoubleQuote {
+ if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote {
if current.Len() > 0 {
args = append(args, current.String())
current.Reset()
@@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string {
}
current.WriteRune(r)
- _ = i // silence unused warning when we need position later
}
if current.Len() > 0 {
@@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string {
args := ParseCommandArgs(argsInput)
return SubstituteArgs(t.Content, args)
}
-
-// ExpandWithArgs substitutes the provided arguments into the template content.
-func (t *PromptTemplate) ExpandWithArgs(args []string) string {
- return SubstituteArgs(t.Content, args)
-}
diff --git a/internal/skills/templates.go b/internal/skills/templates.go
index 7902d662..b3c5c963 100644
--- a/internal/skills/templates.go
+++ b/internal/skills/templates.go
@@ -18,8 +18,11 @@ type PromptTemplate struct {
Variables []string
}
-// variableRe matches {{variable_name}} placeholders.
-var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`)
+// variableRe matches {{variable_name}} placeholders, tolerating surrounding
+// whitespace inside the braces (e.g. {{ name }}). This is the canonical
+// template grammar shared by skill prompts and the extension template API
+// (pkg/kit ParseTemplate/RenderTemplate delegate here).
+var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`)
// NewPromptTemplate creates a PromptTemplate, automatically extracting
// variable names from {{...}} placeholders in content.
@@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
// Expand replaces all {{variable}} placeholders with values from the
// provided map. Missing variables are left as-is (no error).
func (t *PromptTemplate) Expand(values map[string]string) string {
- result := t.Content
- for k, v := range values {
- result = strings.ReplaceAll(result, "{{"+k+"}}", v)
- }
- return result
+ return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string {
+ name := variableRe.FindStringSubmatch(m)[1]
+ if v, ok := values[name]; ok {
+ return v
+ }
+ return m
+ })
}
// ExpandStrict replaces all {{variable}} placeholders and returns an error
diff --git a/internal/tools/mcp.go b/internal/tools/mcp.go
index f4a4a66b..d4ab10f0 100644
--- a/internal/tools/mcp.go
+++ b/internal/tools/mcp.go
@@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
Request: mcp.Request{Method: "tools/call"},
Params: callParams,
}
- result, callErr := conn.client.CallTool(ctx, callRequest)
- if callErr != nil {
- if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
- if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
- return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
- }
- result, callErr = conn.client.CallTool(ctx, callRequest)
- if callErr != nil {
- m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
- return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
- }
- } else {
- m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
- return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
- }
+ var result *mcp.CallToolResult
+ err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
+ var callErr error
+ result, callErr = conn.client.CallTool(ctx, callRequest)
+ return callErr
+ })
+ if err != nil {
+ return nil, err
}
- marshaledResult, mErr := json.Marshal(result)
- if mErr != nil {
- return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
- }
- return &MCPToolResult{
- Content: string(marshaledResult),
- IsError: result.IsError,
- }, nil
+ return marshalToolResult(result)
}
// Task-augmented path. Bypass the upstream CallTool helper because its
@@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
}
- marshaledResult, mErr := json.Marshal(result)
- if mErr != nil {
- return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
- }
- return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
+ return marshalToolResult(result)
}
- callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
- if callErr != nil {
- if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
- if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
- return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
- }
- callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
- if callErr != nil {
- m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
- return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
- }
- } else {
- m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
- return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
- }
+ var (
+ callResult *mcp.CallToolResult
+ taskResult *mcp.CreateTaskResult
+ )
+ err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
+ var callErr error
+ callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
+ return callErr
+ })
+ if err != nil {
+ return nil, err
}
// Server chose to answer synchronously — same shape as the no-task path.
if callResult != nil {
- marshaledResult, mErr := json.Marshal(callResult)
- if mErr != nil {
- return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
- }
- return &MCPToolResult{
- Content: string(marshaledResult),
- IsError: callResult.IsError,
- }, nil
+ return marshalToolResult(callResult)
}
// Asynchronous task path: poll until terminal, then return the result.
@@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
}
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
- adapted := &mcp.CallToolResult{
+ return marshalToolResult(&mcp.CallToolResult{
Content: final.Content,
StructuredContent: final.StructuredContent,
IsError: final.IsError,
+ })
+}
+
+// withOAuthRetry runs call once; when it fails with an OAuth error and an
+// OAuth flow is configured, it re-authorizes the server and retries once.
+// Connection failures are reported to the pool and wrapped uniformly. This
+// consolidates the retry/error chain shared by the synchronous and
+// task-augmented tool-call paths.
+func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error {
+ callErr := call()
+ if callErr == nil {
+ return nil
}
- marshaledResult, mErr := json.Marshal(adapted)
- if mErr != nil {
- return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
+ if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
+ if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil {
+ return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr)
+ }
+ if callErr = call(); callErr != nil {
+ m.connectionPool.HandleConnectionError(serverName, callErr)
+ return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
+ }
+ return nil
+ }
+ m.connectionPool.HandleConnectionError(serverName, callErr)
+ return fmt.Errorf("failed to call mcp tool: %w", callErr)
+}
+
+// marshalToolResult converts an MCP CallToolResult into the JSON-encoded
+// MCPToolResult shape returned to the agent.
+func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) {
+ if result == nil {
+ return nil, errors.New("mcp tool call returned nil result")
+ }
+ marshaled, err := json.Marshal(result)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
}
return &MCPToolResult{
- Content: string(marshaledResult),
- IsError: final.IsError,
+ Content: string(marshaled),
+ IsError: result.IsError,
}, nil
}
diff --git a/internal/ui/factory.go b/internal/ui/factory.go
index 7a8e85c9..4cb618d7 100644
--- a/internal/ui/factory.go
+++ b/internal/ui/factory.go
@@ -2,7 +2,6 @@ package ui
import (
"fmt"
- "strings"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/models"
@@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) {
// ollama or unrecognised models). This is used by the interactive TUI path
// which doesn't go through SetupCLI.
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
- provider, model := parseModelName(modelString)
- if provider == "unknown" || model == "unknown" || provider == "ollama" {
- return nil
- }
-
- registry := models.GetGlobalRegistry()
- modelInfo := registry.LookupModel(provider, model)
+ modelInfo, provider := lookupTrackableModel(modelString)
if modelInfo == nil {
return nil
}
-
- isOAuth := false
- if provider == "anthropic" {
- _, source, err := auth.GetAnthropicAPIKey(providerAPIKey)
- if err == nil && strings.HasPrefix(source, "stored OAuth") {
- isOAuth = true
- }
- }
-
+ isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
return NewUsageTracker(modelInfo, provider, 80, isOAuth)
}
+// UpdateUsageTrackerForModel refreshes an existing tracker after a model
+// switch so token counting and cost reporting use the new model's metadata.
+// No-op for a nil tracker or untrackable models (unknown/ollama).
+func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
+ if t == nil {
+ return
+ }
+ modelInfo, provider := lookupTrackableModel(modelString)
+ if modelInfo == nil {
+ return
+ }
+ isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
+ t.UpdateModelInfo(modelInfo, provider, isOAuth)
+}
+
+// lookupTrackableModel resolves a model string to registry metadata, returning
+// nil for models without usage tracking support (unknown or ollama models).
+func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
+ provider, model := parseModelName(modelString)
+ if provider == "unknown" || model == "unknown" || provider == "ollama" {
+ return nil, provider
+ }
+ return models.GetGlobalRegistry().LookupModel(provider, model), provider
+}
+
// SetupCLI creates, configures, and initializes a CLI instance with the provided
// options. It sets up model display, usage tracking for supported providers, and
// shows initial loading information. Returns nil in quiet mode or an initialized
@@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
}
// Set up usage tracking for supported providers
- if provider != "unknown" && model != "unknown" {
- // Skip usage tracking for ollama as it's not in models.dev
- if provider != "ollama" {
- registry := models.GetGlobalRegistry()
- if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
- // Check if OAuth credentials are being used for Anthropic models
- isOAuth := false
- if provider == "anthropic" {
- _, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
- if err == nil && strings.HasPrefix(source, "stored OAuth") {
- isOAuth = true
- }
- }
-
- usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
- cli.SetUsageTracker(usageTracker)
- }
- }
+ if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
+ cli.SetUsageTracker(usageTracker)
}
// Display model info (the system message block provides its own spacing).
diff --git a/internal/ui/model.go b/internal/ui/model.go
index a9a870d4..8e367887 100644
--- a/internal/ui/model.go
+++ b/internal/ui/model.go
@@ -1208,53 +1208,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.modelSelector = nil
m.state = stateInput
if m.setModel != nil {
- previousModel := m.providerName + "/" + m.modelName
-
- // Check if thinking level needs adjustment for the new model.
- // Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
- if m.thinkingLevel != "" && m.thinkingLevel != "off" {
- parts := strings.SplitN(msg.ModelString, "/", 2)
- if len(parts) == 2 {
- modelName := parts[1]
- currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
- if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
- fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName)
- if fallback != models.ThinkingOff {
- m.printSystemMessage(fmt.Sprintf(
- "Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.",
- modelName, currentLevel, fallback,
- ))
- m.thinkingLevel = string(fallback)
- if m.setThinkingLevel != nil {
- _ = m.setThinkingLevel(string(fallback))
- }
- go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }()
- }
- }
- }
- }
-
- if err := m.setModel(msg.ModelString); err != nil {
- m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
- } else {
- // Update display state directly — we cannot use
- // NotifyModelChanged (prog.Send) from inside Update()
- // without deadlocking BubbleTea.
- parts := strings.SplitN(msg.ModelString, "/", 2)
- if len(parts) == 2 {
- m.providerName = parts[0]
- m.modelName = parts[1]
- }
- m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
- // Persist model selection for next launch.
- go func() { _ = prefs.SaveModelPreference(msg.ModelString) }()
- if m.emitModelChange != nil {
- emit := m.emitModelChange
- newModel := msg.ModelString
- prev := previousModel
- go emit(newModel, prev, "user")
- }
- }
+ m.switchModel(msg.ModelString)
}
return m, tea.Batch(cmds...)
@@ -4211,11 +4165,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
return nil
}
+ // Direct model switch with the provided model string.
+ m.switchModel(args)
+ return nil
+}
+
+// switchModel performs a direct model switch, shared by the model selector
+// overlay and the /model slash command: it adjusts the thinking level when
+// the new model doesn't support the current one, calls the setModel
+// callback, updates display state, persists preferences, and emits the
+// ModelChange extension event.
+//
+// Display state is updated directly — we cannot use NotifyModelChanged
+// (prog.Send) from inside Update() without deadlocking BubbleTea.
+func (m *AppModel) switchModel(modelString string) {
+ if m.setModel == nil {
+ m.printSystemMessage("Model switching is not available.")
+ return
+ }
+
+ previousModel := m.providerName + "/" + m.modelName
+
// Check if thinking level needs adjustment for the new model.
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
- parts := strings.SplitN(args, "/", 2)
- if len(parts) == 2 {
+ if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
modelName := parts[1]
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
@@ -4235,32 +4209,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
}
}
- // Direct model switch with the provided model string.
- previousModel := m.providerName + "/" + m.modelName
- if err := m.setModel(args); err != nil {
+ if err := m.setModel(modelString); err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
- return nil
+ return
}
// Update display state directly (cannot use prog.Send from Update).
- parts := strings.SplitN(args, "/", 2)
- if len(parts) == 2 {
+ if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
m.providerName = parts[0]
m.modelName = parts[1]
}
- if m.emitModelChange != nil {
- emit := m.emitModelChange
- prev := previousModel
- newModel := args
- go emit(newModel, prev, "user")
- }
+ m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString))
// Persist model selection for next launch.
- go func() { _ = prefs.SaveModelPreference(args) }()
+ go func() { _ = prefs.SaveModelPreference(modelString) }()
- m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
- return nil
+ if m.emitModelChange != nil {
+ emit := m.emitModelChange
+ go emit(modelString, previousModel, "user")
+ }
}
// --------------------------------------------------------------------------
@@ -4827,61 +4795,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
return r
}, name)
- tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
+ tmpPath, err := buildShareFile(name, data, sysPromptJSON)
if err != nil {
- m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err))
+ m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err))
return nil
}
- tmpPath := tmpFile.Name()
-
- // Write the session data with the system prompt entry inserted after the header.
- // The header is the first line, so we write:
- // 1. First line (header) from original data
- // 2. System prompt entry
- // 3. Remaining lines from original data
- lines := strings.Split(string(data), "\n")
- if len(lines) > 0 && lines[len(lines)-1] == "" {
- lines = lines[:len(lines)-1] // Remove trailing empty line
- }
-
- if len(lines) > 0 {
- // Write header (first line)
- if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
- _ = tmpFile.Close()
- _ = os.Remove(tmpPath)
- m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
- return nil
- }
-
- // Write system prompt entry
- if _, err := tmpFile.Write(sysPromptJSON); err != nil {
- _ = tmpFile.Close()
- _ = os.Remove(tmpPath)
- m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
- return nil
- }
- if _, err := tmpFile.WriteString("\n"); err != nil {
- _ = tmpFile.Close()
- _ = os.Remove(tmpPath)
- m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
- return nil
- }
-
- // Write remaining lines
- for i := 1; i < len(lines); i++ {
- if lines[i] == "" {
- continue // Skip empty lines
- }
- if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
- _ = tmpFile.Close()
- _ = os.Remove(tmpPath)
- m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
- return nil
- }
- }
- }
-
- _ = tmpFile.Close()
m.printSystemMessage("Uploading session to GitHub Gist...")
@@ -4907,6 +4825,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
}
}
+// buildShareFile assembles a temp JSONL file containing the session data
+// with the system-prompt entry inserted after the header line. On success
+// the caller owns the returned file and must remove it when done; on error
+// any partially-written temp file has already been cleaned up.
+func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) {
+ tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
+ if err != nil {
+ return "", fmt.Errorf("create temp file: %w", err)
+ }
+ tmpPath = tmpFile.Name()
+ defer func() {
+ _ = tmpFile.Close()
+ if err != nil {
+ _ = os.Remove(tmpPath)
+ }
+ }()
+
+ // Write the session data with the system prompt entry inserted after the
+ // header. The header is the first line, so we write:
+ // 1. First line (header) from original data
+ // 2. System prompt entry
+ // 3. Remaining lines from original data
+ lines := strings.Split(string(data), "\n")
+ if len(lines) > 0 && lines[len(lines)-1] == "" {
+ lines = lines[:len(lines)-1] // Remove trailing empty line
+ }
+ if len(lines) == 0 {
+ return tmpPath, nil
+ }
+
+ if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil {
+ return "", fmt.Errorf("write temp file: %w", err)
+ }
+ if _, err = tmpFile.Write(sysPromptJSON); err != nil {
+ return "", fmt.Errorf("write system prompt: %w", err)
+ }
+ if _, err = tmpFile.WriteString("\n"); err != nil {
+ return "", fmt.Errorf("write temp file: %w", err)
+ }
+ for i := 1; i < len(lines); i++ {
+ if lines[i] == "" {
+ continue // Skip empty lines
+ }
+ if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil {
+ return "", fmt.Errorf("write temp file: %w", err)
+ }
+ }
+ return tmpPath, nil
+}
+
// handleImportCommand imports a session from a JSONL file.
// Usage: /import path.jsonl
func (m *AppModel) handleImportCommand(args string) tea.Cmd {
diff --git a/pkg/kit/adapter.go b/pkg/kit/adapter.go
index 9e97e0c2..61add1c4 100644
--- a/pkg/kit/adapter.go
+++ b/pkg/kit/adapter.go
@@ -11,12 +11,12 @@ import (
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
- inner *session.TreeManager
+ inner *TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
-func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
+func NewTreeManagerAdapter(tm *TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
diff --git a/pkg/kit/events.go b/pkg/kit/events.go
index c5d73223..d662f1df 100644
--- a/pkg/kit/events.go
+++ b/pkg/kit/events.go
@@ -3,6 +3,8 @@ package kit
import (
"encoding/json"
"sync"
+
+ "github.com/mark3labs/kit/internal/extensions"
)
// ---------------------------------------------------------------------------
@@ -103,34 +105,21 @@ type Event interface {
// appropriate visualizations (e.g. diff view for edit tools, command+output
// for execute tools) and file trackers to identify which results contain
// modifications.
+//
+// These constants re-export the canonical classification used by extension
+// events, so SDK events and extension events always agree.
const (
- ToolKindExecute = "execute" // Shell execution (bash)
- ToolKindEdit = "edit" // File modification (edit, write)
- ToolKindRead = "read" // File reading (read, ls)
- ToolKindSearch = "search" // Content/file search (grep, find)
- ToolKindSubagent = "agent" // Subagent spawning (subagent)
+ ToolKindExecute = extensions.ToolKindExecute // Shell execution (bash)
+ ToolKindEdit = extensions.ToolKindEdit // File modification (edit, write)
+ ToolKindRead = extensions.ToolKindRead // File reading (read, ls)
+ ToolKindSearch = extensions.ToolKindSearch // Content/file search (grep, find)
+ ToolKindSubagent = extensions.ToolKindSubagent // Subagent spawning (subagent)
)
-// coreToolKinds maps built-in tool names to their kind. MCP and extension
-// tools without an entry default to ToolKindExecute.
-var coreToolKinds = map[string]string{
- "bash": ToolKindExecute,
- "edit": ToolKindEdit,
- "write": ToolKindEdit,
- "read": ToolKindRead,
- "ls": ToolKindRead,
- "grep": ToolKindSearch,
- "find": ToolKindSearch,
- "subagent": ToolKindSubagent,
-}
-
// toolKindFor returns the ToolKind for a given tool name, defaulting to
// ToolKindExecute for unknown tools.
func toolKindFor(toolName string) string {
- if kind, ok := coreToolKinds[toolName]; ok {
- return kind
- }
- return ToolKindExecute
+ return extensions.ToolKindFor(toolName)
}
// parseToolArgs attempts to parse a JSON-encoded tool args string into a map.
diff --git a/pkg/kit/extensions_bridge.go b/pkg/kit/extensions_bridge.go
index 36e069bb..d80dc927 100644
--- a/pkg/kit/extensions_bridge.go
+++ b/pkg/kit/extensions_bridge.go
@@ -578,18 +578,13 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
}
// isAnthropicOAuth reports whether the current Anthropic credential resolves
-// to a stored OAuth token (in which case the user is not billed per-token).
-// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker
-// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
+// to a stored OAuth token (in which case the user is not billed per-token),
+// so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
func isAnthropicOAuth(m *Kit, provider string) bool {
if m == nil || provider != "anthropic" {
return false
}
- _, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key"))
- if err != nil {
- return false
- }
- return strings.HasPrefix(source, "stored OAuth")
+ return auth.IsAnthropicOAuth(m.v.GetString("provider-api-key"))
}
// llmToContextMessages converts a slice of LLM messages to extension
diff --git a/pkg/kit/kit.go b/pkg/kit/kit.go
index f46b8c2a..12724842 100644
--- a/pkg/kit/kit.go
+++ b/pkg/kit/kit.go
@@ -1160,7 +1160,7 @@ type CLIOptions struct {
// - Continue: resume most recent session for SessionDir (or cwd)
// - SessionPath: open a specific JSONL session file
// - default: create a new tree session for SessionDir (or cwd)
-func InitTreeSession(opts *Options) (*session.TreeManager, error) {
+func InitTreeSession(opts *Options) (*TreeManager, error) {
if opts == nil {
opts = &Options{}
}
diff --git a/pkg/kit/template_bridge.go b/pkg/kit/template_bridge.go
index 98307b11..ba0b0c58 100644
--- a/pkg/kit/template_bridge.go
+++ b/pkg/kit/template_bridge.go
@@ -7,45 +7,36 @@ import (
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
+ "github.com/mark3labs/kit/internal/prompts"
+ "github.com/mark3labs/kit/internal/skills"
)
// ---------------------------------------------------------------------------
// Template Parsing Bridge for Extensions (Phase 3)
// ---------------------------------------------------------------------------
-// varRegex matches {{variable}} placeholders in templates.
-var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
-
-// ParseTemplate extracts {{variables}} from template content.
+// ParseTemplate extracts {{variables}} from template content. The template
+// grammar is shared with skill prompt templates, so a template parses
+// identically regardless of which API loads it.
func ParseTemplate(name, content string) extensions.PromptTemplate {
- matches := varRegex.FindAllStringSubmatch(content, -1)
- vars := make([]string, 0, len(matches))
- seen := make(map[string]bool)
- for _, m := range matches {
- if len(m) > 1 && !seen[m[1]] {
- seen[m[1]] = true
- vars = append(vars, m[1])
- }
+ tpl := skills.NewPromptTemplate(name, content)
+ vars := tpl.Variables
+ if vars == nil {
+ vars = []string{}
}
return extensions.PromptTemplate{
- Name: name,
- Content: content,
+ Name: tpl.Name,
+ Content: tpl.Content,
Variables: vars,
}
}
// RenderTemplate substitutes variables into template content.
-// Handles {{name}} and {{ name }} (any whitespace) placeholders.
+// Handles {{name}} and {{ name }} (any whitespace) placeholders; missing
+// variables are left as-is.
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
- return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
- sub := varRegex.FindStringSubmatch(m)
- if len(sub) > 1 {
- if v, ok := vars[sub[1]]; ok {
- return v
- }
- }
- return m
- })
+ t := skills.PromptTemplate{Content: tpl.Content}
+ return t.Expand(vars)
}
// ParseArguments parses command-line style arguments.
@@ -183,44 +174,12 @@ func SimpleParseArguments(input string, count int) []string {
return result
}
-// parseFields splits input respecting quoted strings.
+// parseFields splits input into arguments respecting quoted strings and
+// backslash escaping. It delegates to the canonical tokenizer in
+// internal/prompts so extension argument parsing and builtin prompt-template
+// parsing agree on grammar.
func parseFields(input string) []string {
- var fields []string
- var current strings.Builder
- inQuote := false
- quoteChar := rune(0)
-
- for _, r := range input {
- switch r {
- case '"', '\'':
- if !inQuote {
- inQuote = true
- quoteChar = r
- } else if r == quoteChar {
- inQuote = false
- quoteChar = 0
- } else {
- current.WriteRune(r)
- }
- case ' ', '\t':
- if inQuote {
- current.WriteRune(r)
- } else {
- if current.Len() > 0 {
- fields = append(fields, current.String())
- current.Reset()
- }
- }
- default:
- current.WriteRune(r)
- }
- }
-
- if current.Len() > 0 {
- fields = append(fields, current.String())
- }
-
- return fields
+ return prompts.ParseCommandArgs(input)
}
// EvaluateModelConditional checks if condition matches current model.
@@ -417,21 +376,18 @@ func MatchModelGlob(model, pattern string) bool {
}
// ExtractProviderFromPath extracts provider from a path-like model string.
+//
+// Deprecated: Use GetCurrentProvider instead.
func ExtractProviderFromPath(model string) string {
- parts := strings.Split(model, "/")
- if len(parts) >= 2 {
- return parts[0]
- }
- return ""
+ return GetCurrentProvider(model)
}
// ExtractModelFromPath extracts model ID from a path-like model string.
+//
+// Deprecated: Use RemoveProviderFromModel instead, which correctly handles
+// model IDs containing "/" (e.g. "openrouter/meta/llama").
func ExtractModelFromPath(model string) string {
- parts := strings.Split(model, "/")
- if len(parts) >= 2 {
- return parts[1]
- }
- return model
+ return RemoveProviderFromModel(model)
}
// IsBareModelID checks if a string is a bare model ID (no provider).