mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f366eab84 | |||
| e8e99b19a8 |
@@ -128,6 +128,12 @@ temperature: 0.7
|
||||
stream: true
|
||||
thinking-level: off # off, none, minimal, low, medium, high
|
||||
no-core-tools: false # set to true to disable all built-in core tools
|
||||
|
||||
# Skills — all three keys are optional
|
||||
no-skills: false # set to true to disable all skill loading
|
||||
skill: # explicit skill files/dirs (disables auto-discovery)
|
||||
- /path/to/skill.md
|
||||
skills-dir: "" # override project-local directory for auto-discovery
|
||||
```
|
||||
|
||||
All of the above keys can also be set programmatically via the SDK
|
||||
@@ -203,6 +209,11 @@ mcpServers:
|
||||
--prompt-template Load a specific prompt template by name
|
||||
--no-prompt-templates Disable prompt template loading
|
||||
|
||||
# Skills
|
||||
--skill Load skill file or directory (repeatable)
|
||||
--skills-dir Override the project-local skills directory for auto-discovery
|
||||
--no-skills Disable skill loading (auto-discovery and explicit)
|
||||
|
||||
# Generation parameters
|
||||
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
|
||||
--temperature Randomness 0.0-1.0 (default: 0.7)
|
||||
|
||||
+264
-441
@@ -4,13 +4,11 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -35,451 +33,276 @@ type extensionContextDeps struct {
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
// The headless half (data access, state, options, tree navigation, skills,
|
||||
// templates, model resolution, subagents) comes from extbridge.BaseContext;
|
||||
// this function overlays the TUI-specific fields and overrides SetModel /
|
||||
// ReloadExtensions with TUI-aware versions.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetState: func(key string, value string) {
|
||||
kitInstance.Extensions().SetState(key, value)
|
||||
},
|
||||
GetState: func(key string) (string, bool) {
|
||||
return kitInstance.Extensions().GetState(key)
|
||||
},
|
||||
DeleteState: func(key string) {
|
||||
kitInstance.Extensions().DeleteState(key)
|
||||
},
|
||||
ListState: func() []string {
|
||||
return kitInstance.Extensions().ListState()
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
ec := extbridge.BaseContext(deps.ctx, kitInstance)
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
ec.CWD = deps.cwd
|
||||
ec.Model = deps.modelName
|
||||
ec.Interactive = deps.interactive
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
}
|
||||
ec.SendMessage = func(text string) { appInstance.Run(text) }
|
||||
ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) }
|
||||
ec.Abort = func() { appInstance.Abort() }
|
||||
ec.IsIdle = func() bool { return !appInstance.IsBusy() }
|
||||
ec.Compact = func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
}
|
||||
ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
}
|
||||
ec.GetSessionUsage = func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
}
|
||||
ec.Exit = func() { appInstance.QuitFromExtension() }
|
||||
|
||||
// TUI widgets/chrome — mutate runner state, then notify the TUI.
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks
|
||||
// if called synchronously from inside BubbleTea's Update() handler.
|
||||
// All call sites use go-routines uniformly.
|
||||
ec.SetWidget = func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveWidget = func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetHeader = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveHeader = func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetFooter = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveFooter = func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetUIVisibility = func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditor = func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.ResetEditor = func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditorText = func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
}
|
||||
ec.SetStatus = func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveStatus = func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
|
||||
// Interactive prompts — channel-based round trips through the TUI.
|
||||
ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
}
|
||||
ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
}
|
||||
ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
}
|
||||
ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
}
|
||||
|
||||
// TUI-aware model switch: also notifies the TUI status bar and
|
||||
// refreshes the usage tracker for correct token counting.
|
||||
ec.SetModel = func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
|
||||
return nil
|
||||
}
|
||||
|
||||
ec.RenderMessage = func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
}
|
||||
ec.ReloadExtensions = func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Theme management (TUI only).
|
||||
ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
}
|
||||
ec.SetTheme = func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
}
|
||||
ec.ListThemes = func() []string {
|
||||
return ui.ListThemes()
|
||||
}
|
||||
|
||||
// Skill context-injection (drives a new agent turn through the TUI).
|
||||
ec.InjectSkillAsContext = func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
}
|
||||
ec.InjectRawSkillAsContext = func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
|
||||
return ec
|
||||
}
|
||||
|
||||
+60
-36
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -74,6 +73,11 @@ var (
|
||||
noCoreToolsFlag bool
|
||||
extensionPaths []string
|
||||
|
||||
// Skills control
|
||||
noSkillsFlag bool
|
||||
skillsPaths []string
|
||||
skillsDir string
|
||||
|
||||
// TLS configuration
|
||||
tlsSkipVerify bool
|
||||
|
||||
@@ -284,6 +288,14 @@ func init() {
|
||||
rootCmd.PersistentFlags().
|
||||
StringSliceVarP(&extensionPaths, "extension", "e", nil, "load additional extension file(s)")
|
||||
|
||||
// Skills flags
|
||||
rootCmd.PersistentFlags().
|
||||
BoolVar(&noSkillsFlag, "no-skills", false, "disable skill loading (auto-discovery and explicit)")
|
||||
rootCmd.PersistentFlags().
|
||||
StringSliceVar(&skillsPaths, "skill", nil, "load skill file or directory (repeatable)")
|
||||
rootCmd.PersistentFlags().
|
||||
StringVar(&skillsDir, "skills-dir", "", "override the project-local skills directory for auto-discovery")
|
||||
|
||||
flags := rootCmd.PersistentFlags()
|
||||
flags.StringVar(&providerURL, "provider-url", "", "base URL for the provider API (applies to OpenAI, Anthropic, Ollama, and Google)")
|
||||
flags.StringVar(&providerAPIKey, "provider-api-key", "", "API key for the provider (applies to OpenAI, Anthropic, and Google)")
|
||||
@@ -334,6 +346,9 @@ func init() {
|
||||
_ = viper.BindPFlag("extension", rootCmd.PersistentFlags().Lookup("extension"))
|
||||
_ = viper.BindPFlag("prompt-template", rootCmd.PersistentFlags().Lookup("prompt-template"))
|
||||
_ = viper.BindPFlag("no-prompt-templates", rootCmd.PersistentFlags().Lookup("no-prompt-templates"))
|
||||
_ = viper.BindPFlag("no-skills", rootCmd.PersistentFlags().Lookup("no-skills"))
|
||||
_ = viper.BindPFlag("skill", rootCmd.PersistentFlags().Lookup("skill"))
|
||||
_ = viper.BindPFlag("skills-dir", rootCmd.PersistentFlags().Lookup("skills-dir"))
|
||||
|
||||
// Defaults are already set in flag definitions, no need to duplicate in viper
|
||||
|
||||
@@ -677,8 +692,8 @@ func globalShortcutsProviderForUI(k *kit.Kit) func() map[string]func() {
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
// Validate flag combinations
|
||||
// validateModeFlags rejects invalid flag combinations for the root command.
|
||||
func validateModeFlags() error {
|
||||
if quietFlag && positionalPrompt == "" {
|
||||
return fmt.Errorf("--quiet requires a prompt (e.g. kit \"your question\" --quiet)")
|
||||
}
|
||||
@@ -691,21 +706,14 @@ func runNormalMode(ctx context.Context) error {
|
||||
if noExitFlag && positionalPrompt == "" {
|
||||
return fmt.Errorf("--no-exit requires a prompt (e.g. kit \"your question\" --no-exit)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
if debugMode {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Update debug mode from viper
|
||||
if viper.GetBool("debug") && !debugMode {
|
||||
debugMode = viper.GetBool("debug")
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Restore persisted model preference when no explicit --model flag or
|
||||
// config file model is set. Precedence: CLI flag > config file > saved
|
||||
// preference > built-in default. This mirrors how themes are persisted.
|
||||
// restorePersistedPreferences applies saved model / thinking-level
|
||||
// preferences into viper when neither a CLI flag nor a config-file value
|
||||
// takes precedence. Precedence: CLI flag > config file > saved preference >
|
||||
// built-in default. This mirrors how themes are persisted.
|
||||
func restorePersistedPreferences() {
|
||||
// Skip custom/* models unless --provider-url is also provided, since the
|
||||
// custom provider requires a URL that was only valid for the previous session.
|
||||
if !modelFlagChanged && !viper.InConfig("model") {
|
||||
@@ -724,6 +732,15 @@ func runNormalMode(ctx context.Context) error {
|
||||
viper.Set("thinking-level", pref)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyProviderURLRouting rewrites the model in viper when --provider-url
|
||||
// is set, routing requests through the "custom" (OpenAI-compatible)
|
||||
// provider. Must run after restorePersistedPreferences.
|
||||
func applyProviderURLRouting() {
|
||||
if viper.GetString("provider-url") == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// When --provider-url is set but no explicit --model was provided,
|
||||
// default to "custom/custom" so the user doesn't need to remember a
|
||||
@@ -731,7 +748,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
// This intentionally overrides saved preferences but respects config-file
|
||||
// models — if you specify a model in ~/.kit.yml, it will be used with
|
||||
// custom/custom's provider routing.
|
||||
if viper.GetString("provider-url") != "" && !modelFlagChanged && !viper.InConfig("model") {
|
||||
if !modelFlagChanged && !viper.InConfig("model") {
|
||||
viper.Set("model", "custom/custom")
|
||||
}
|
||||
|
||||
@@ -746,7 +763,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
// to point a non-OpenAI wire (Anthropic, Google, ...) at a proxy URL,
|
||||
// use the explicit `custom/<name>` form to opt out of the rewrite by
|
||||
// configuring the proxy as that provider in your config file instead.
|
||||
if viper.GetString("provider-url") != "" && modelFlagChanged {
|
||||
if modelFlagChanged {
|
||||
model := viper.GetString("model")
|
||||
if model != "" {
|
||||
name := model
|
||||
@@ -758,6 +775,26 @@ func runNormalMode(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runNormalMode(ctx context.Context) error {
|
||||
if err := validateModeFlags(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up logging
|
||||
if debugMode {
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
// Update debug mode from viper
|
||||
if viper.GetBool("debug") && !debugMode {
|
||||
debugMode = viper.GetBool("debug")
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
}
|
||||
|
||||
restorePersistedPreferences()
|
||||
applyProviderURLRouting()
|
||||
|
||||
// Load MCP configuration.
|
||||
mcpConfig, err := config.LoadAndValidateConfig()
|
||||
@@ -799,6 +836,9 @@ func runNormalMode(ctx context.Context) error {
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
DisableCoreTools: viper.GetBool("no-core-tools"),
|
||||
NoSkills: noSkillsFlag,
|
||||
Skills: skillsPaths,
|
||||
SkillsDir: skillsDir,
|
||||
// This callback is called when each MCP server finishes loading.
|
||||
// We use a closure that captures appInstancePtr which is set after
|
||||
// app.New() is called below.
|
||||
@@ -1164,23 +1204,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
// NotifyModelChanged calls prog.Send() which deadlocks. The UI layer
|
||||
// updates m.providerName and m.modelName directly after setModel returns.
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
|
||||
return nil
|
||||
}
|
||||
emitModelChangeForUI := func(newModel, previousModel, source string) {
|
||||
|
||||
@@ -73,111 +73,70 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Wire extension context with headless implementations so extensions
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
// become no-ops or return cancelled; all data/model/tool APIs come from
|
||||
// extbridge.BaseContext and work identically to interactive mode.
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
Interactive: false,
|
||||
// Use a background context for subagent spawns: the create() ctx is
|
||||
// request-scoped and may be cancelled before extensions spawn anything.
|
||||
ec := extbridge.BaseContext(context.Background(), kitInstance)
|
||||
|
||||
// Output — route through structured logger.
|
||||
Print: func(text string) { log.Debug("extension: print", "text", text) },
|
||||
PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
|
||||
PrintError: func(text string) { log.Error("extension: error", "text", text) },
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
},
|
||||
ec.SessionID = sessionID
|
||||
ec.CWD = cwd
|
||||
ec.Model = kitInstance.GetModelString()
|
||||
ec.Interactive = false
|
||||
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
SendMessage: func(string) {},
|
||||
CancelAndSend: func(string) {},
|
||||
Exit: func() {},
|
||||
// Output — route through structured logger.
|
||||
ec.Print = func(text string) { log.Debug("extension: print", "text", text) }
|
||||
ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) }
|
||||
ec.PrintError = func(text string) { log.Error("extension: error", "text", text) }
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
}
|
||||
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
SetWidget: func(extensions.WidgetConfig) {},
|
||||
RemoveWidget: func(string) {},
|
||||
SetHeader: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveHeader: func() {},
|
||||
SetFooter: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveFooter: func() {},
|
||||
SetEditor: func(extensions.EditorConfig) {},
|
||||
ResetEditor: func() {},
|
||||
SetEditorText: func(string) {},
|
||||
SetUIVisibility: func(extensions.UIVisibility) {},
|
||||
SetStatus: func(string, string, int) {},
|
||||
RemoveStatus: func(string) {},
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
ec.SendMessage = func(string) {}
|
||||
ec.CancelAndSend = func(string) {}
|
||||
ec.Exit = func() {}
|
||||
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
},
|
||||
PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
},
|
||||
PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
},
|
||||
ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
},
|
||||
SuspendTUI: func(callback func()) error { callback(); return nil },
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
ec.SetWidget = func(extensions.WidgetConfig) {}
|
||||
ec.RemoveWidget = func(string) {}
|
||||
ec.SetHeader = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveHeader = func() {}
|
||||
ec.SetFooter = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveFooter = func() {}
|
||||
ec.SetEditor = func(extensions.EditorConfig) {}
|
||||
ec.ResetEditor = func() {}
|
||||
ec.SetEditorText = func(string) {}
|
||||
ec.SetUIVisibility = func(extensions.UIVisibility) {}
|
||||
ec.SetStatus = func(string, string, int) {}
|
||||
ec.RemoveStatus = func(string) {}
|
||||
|
||||
// Data access — delegate to Kit instance.
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
// Render — fall back to logging.
|
||||
ec.RenderMessage = func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
}
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.Extensions().SetContext(ec)
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
|
||||
@@ -13,11 +13,6 @@ type MessageStore struct {
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
func NewMessageStore() *MessageStore {
|
||||
return &MessageStore{}
|
||||
}
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
|
||||
@@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestNewMessageStore_empty(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil store")
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestAdd_appendsMessage(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "first"))
|
||||
s.Add(makeTextMsg("assistant", "second"))
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAdd_preservesOrder(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
texts := []string{"a", "b", "c"}
|
||||
for _, t2 := range texts {
|
||||
s.Add(makeTextMsg("user", t2))
|
||||
@@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []kit.LLMMessage{
|
||||
@@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestGetAll_returnsCopy(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "hello"))
|
||||
|
||||
got := s.GetAll()
|
||||
@@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetAll_emptyStore(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
got := s.GetAll()
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice, got %d elements", len(got))
|
||||
@@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestClear_removesAllMessages(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "a"))
|
||||
s.Add(makeTextMsg("user", "b"))
|
||||
s.Clear()
|
||||
@@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "before"))
|
||||
s.Clear()
|
||||
s.Add(makeTextMsg("user", "after"))
|
||||
@@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
done := make(chan struct{})
|
||||
|
||||
// Writer goroutine.
|
||||
|
||||
@@ -513,6 +513,20 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialSourceOAuth is the source description returned by
|
||||
// GetAnthropicAPIKey when the key resolves to stored OAuth credentials.
|
||||
// Consumers should compare against this constant (or use IsAnthropicOAuth)
|
||||
// rather than matching the string literal.
|
||||
const CredentialSourceOAuth = "stored OAuth credentials"
|
||||
|
||||
// IsAnthropicOAuth reports whether the active Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token).
|
||||
// flagValue is the --provider-api-key flag value (may be empty).
|
||||
func IsAnthropicOAuth(flagValue string) bool {
|
||||
_, source, err := GetAnthropicAPIKey(flagValue)
|
||||
return err == nil && source == CredentialSourceOAuth
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
@@ -535,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
return token, "stored OAuth credentials", nil
|
||||
return token, CredentialSourceOAuth, nil
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
return creds.APIKey, "stored API key", nil
|
||||
}
|
||||
|
||||
@@ -493,6 +493,12 @@ mcpServers:
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# Skills configuration (all optional)
|
||||
# no-skills: false # Set to true to disable all skill loading
|
||||
# skill: # Explicit skill files/dirs (disables auto-discovery)
|
||||
# - "/path/to/skill.md"
|
||||
# skills-dir: "/path/to/skills" # Override project-local directory for auto-discovery
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
@@ -205,6 +205,9 @@ func TestEnsureConfigExists(t *testing.T) {
|
||||
"type: \"local\"",
|
||||
"type: \"remote\"",
|
||||
"Core tools",
|
||||
"# Skills configuration",
|
||||
"no-skills:",
|
||||
"skills-dir:",
|
||||
}
|
||||
|
||||
for _, expected := range expectedSections {
|
||||
|
||||
@@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
|
||||
// This is useful for determining if substitution is needed before processing.
|
||||
func HasEnvVars(content string) bool {
|
||||
return envVarPattern.MatchString(content)
|
||||
}
|
||||
|
||||
@@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has env vars",
|
||||
content: `{"token": "${env://GITHUB_TOKEN}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has env vars with default",
|
||||
content: `{"debug": "${env://DEBUG:-false}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no env vars",
|
||||
content: `{"name": "${username}", "normal": "value"}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasEnvVars(tt.content)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
|
||||
@@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) {
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// BaseContext returns an extensions.Context populated with the headless,
|
||||
// TUI-independent delegation fields: data access, state, options,
|
||||
// model/tool management, completions, subagents, tree navigation, skills,
|
||||
// template parsing, and model resolution.
|
||||
//
|
||||
// Callers overlay their UI-specific fields (print routes, widgets, prompts,
|
||||
// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value:
|
||||
// cmd/extension_context.go for the interactive TUI and
|
||||
// internal/acpserver/session.go for headless ACP mode. Keeping the shared
|
||||
// half here means a new data-access Context field only has to be wired once.
|
||||
//
|
||||
// ctx is used for subagent spawns; pass a long-lived context (not a
|
||||
// per-request one) so later spawns aren't cancelled prematurely.
|
||||
func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context {
|
||||
return extensions.Context{
|
||||
// -------------------------------------------------------------------
|
||||
// Data access
|
||||
// -------------------------------------------------------------------
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Extension state
|
||||
// -------------------------------------------------------------------
|
||||
SetState: func(key string, value string) {
|
||||
kitInstance.Extensions().SetState(key, value)
|
||||
},
|
||||
GetState: func(key string) (string, bool) {
|
||||
return kitInstance.Extensions().GetState(key)
|
||||
},
|
||||
DeleteState: func(key string) {
|
||||
kitInstance.Extensions().DeleteState(key)
|
||||
},
|
||||
ListState: func() []string {
|
||||
return kitInstance.Extensions().ListState()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Options, model, and tool management
|
||||
// -------------------------------------------------------------------
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
// Headless model switch. The interactive TUI overrides this with a
|
||||
// version that also notifies the TUI and refreshes the usage tracker.
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
// Headless reload. The interactive TUI overrides this to also
|
||||
// refresh widgets/status/commands.
|
||||
ReloadExtensions: func() error {
|
||||
return kitInstance.Extensions().Reload()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// LLM completions and subagents
|
||||
// -------------------------------------------------------------------
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API (context-injection variants are TUI-specific and
|
||||
// wired by the interactive overlay)
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package extensions
|
||||
|
||||
// ToolKind constants classify what a tool does, enabling UIs to render
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
//
|
||||
// This is the single source of truth for tool-kind classification; the
|
||||
// pkg/kit SDK re-exports these constants.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
// MCP and extension tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// ToolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools (including MCP tools).
|
||||
func ToolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
}
|
||||
@@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo
|
||||
return tools
|
||||
}
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// "execute" for unknown tools (including MCP tools).
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return "execute"
|
||||
}
|
||||
|
||||
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
|
||||
// Returns nil on failure (non-fatal convenience parsing).
|
||||
func parseToolArgsJSON(input string) map[string]any {
|
||||
@@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
kind := ToolKindFor(toolName)
|
||||
|
||||
// 1. Emit ToolCall — extensions can block execution.
|
||||
if w.runner.HasHandlers(ToolCall) {
|
||||
|
||||
@@ -69,13 +69,6 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the process-global viper store. Keys are "provider/model" strings.
|
||||
// Returns nil if no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
return LoadModelSettingsFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
|
||||
// supplied per-instance store. When v is nil the process-global store is used.
|
||||
// Keys are "provider/model" strings. Returns nil if no model settings are
|
||||
|
||||
@@ -932,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
|
||||
}
|
||||
|
||||
// Handle OAuth vs API key authentication
|
||||
if strings.HasPrefix(source, "stored OAuth") {
|
||||
if source == auth.CredentialSourceOAuth {
|
||||
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
|
||||
opts = append(opts, anthropic.WithHTTPClient(httpClient))
|
||||
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
|
||||
|
||||
@@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) {
|
||||
}
|
||||
|
||||
// ParseCommandArgs splits a command line into arguments respecting quotes.
|
||||
// It handles single quotes, double quotes, and backslash escaping.
|
||||
// It handles single quotes, double quotes, backslash escaping, and splits on
|
||||
// spaces and tabs.
|
||||
func ParseCommandArgs(input string) []string {
|
||||
var args []string
|
||||
var current strings.Builder
|
||||
@@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string {
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range input {
|
||||
for _, r := range input {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
@@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == ' ' && !inSingleQuote && !inDoubleQuote {
|
||||
if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote {
|
||||
if current.Len() > 0 {
|
||||
args = append(args, current.String())
|
||||
current.Reset()
|
||||
@@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string {
|
||||
}
|
||||
|
||||
current.WriteRune(r)
|
||||
_ = i // silence unused warning when we need position later
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
@@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
args := ParseCommandArgs(argsInput)
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
// ExpandWithArgs substitutes the provided arguments into the template content.
|
||||
func (t *PromptTemplate) ExpandWithArgs(args []string) string {
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
@@ -18,8 +18,11 @@ type PromptTemplate struct {
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// variableRe matches {{variable_name}} placeholders.
|
||||
var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`)
|
||||
// variableRe matches {{variable_name}} placeholders, tolerating surrounding
|
||||
// whitespace inside the braces (e.g. {{ name }}). This is the canonical
|
||||
// template grammar shared by skill prompts and the extension template API
|
||||
// (pkg/kit ParseTemplate/RenderTemplate delegate here).
|
||||
var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`)
|
||||
|
||||
// NewPromptTemplate creates a PromptTemplate, automatically extracting
|
||||
// variable names from {{...}} placeholders in content.
|
||||
@@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
|
||||
// Expand replaces all {{variable}} placeholders with values from the
|
||||
// provided map. Missing variables are left as-is (no error).
|
||||
func (t *PromptTemplate) Expand(values map[string]string) string {
|
||||
result := t.Content
|
||||
for k, v := range values {
|
||||
result = strings.ReplaceAll(result, "{{"+k+"}}", v)
|
||||
}
|
||||
return result
|
||||
return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string {
|
||||
name := variableRe.FindStringSubmatch(m)[1]
|
||||
if v, ok := values[name]; ok {
|
||||
return v
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// ExpandStrict replaces all {{variable}} placeholders and returns an error
|
||||
|
||||
+60
-57
@@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var result *mcp.CallToolResult
|
||||
err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
// Task-augmented path. Bypass the upstream CallTool helper because its
|
||||
@@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var (
|
||||
callResult *mcp.CallToolResult
|
||||
taskResult *mcp.CreateTaskResult
|
||||
)
|
||||
err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Server chose to answer synchronously — same shape as the no-task path.
|
||||
if callResult != nil {
|
||||
marshaledResult, mErr := json.Marshal(callResult)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: callResult.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(callResult)
|
||||
}
|
||||
|
||||
// Asynchronous task path: poll until terminal, then return the result.
|
||||
@@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
}
|
||||
|
||||
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
||||
adapted := &mcp.CallToolResult{
|
||||
return marshalToolResult(&mcp.CallToolResult{
|
||||
Content: final.Content,
|
||||
StructuredContent: final.StructuredContent,
|
||||
IsError: final.IsError,
|
||||
})
|
||||
}
|
||||
|
||||
// withOAuthRetry runs call once; when it fails with an OAuth error and an
|
||||
// OAuth flow is configured, it re-authorizes the server and retries once.
|
||||
// Connection failures are reported to the pool and wrapped uniformly. This
|
||||
// consolidates the retry/error chain shared by the synchronous and
|
||||
// task-augmented tool-call paths.
|
||||
func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error {
|
||||
callErr := call()
|
||||
if callErr == nil {
|
||||
return nil
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(adapted)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil {
|
||||
return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr)
|
||||
}
|
||||
if callErr = call(); callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
|
||||
// marshalToolResult converts an MCP CallToolResult into the JSON-encoded
|
||||
// MCPToolResult shape returned to the agent.
|
||||
func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) {
|
||||
if result == nil {
|
||||
return nil, errors.New("mcp tool call returned nil result")
|
||||
}
|
||||
marshaled, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: final.IsError,
|
||||
Content: string(marshaled),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+29
-35
@@ -2,7 +2,6 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) {
|
||||
// ollama or unrecognised models). This is used by the interactive TUI path
|
||||
// which doesn't go through SetupCLI.
|
||||
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel(provider, model)
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(providerAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
return NewUsageTracker(modelInfo, provider, 80, isOAuth)
|
||||
}
|
||||
|
||||
// UpdateUsageTrackerForModel refreshes an existing tracker after a model
|
||||
// switch so token counting and cost reporting use the new model's metadata.
|
||||
// No-op for a nil tracker or untrackable models (unknown/ollama).
|
||||
func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return
|
||||
}
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
t.UpdateModelInfo(modelInfo, provider, isOAuth)
|
||||
}
|
||||
|
||||
// lookupTrackableModel resolves a model string to registry metadata, returning
|
||||
// nil for models without usage tracking support (unknown or ollama models).
|
||||
func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil, provider
|
||||
}
|
||||
return models.GetGlobalRegistry().LookupModel(provider, model), provider
|
||||
}
|
||||
|
||||
// SetupCLI creates, configures, and initializes a CLI instance with the provided
|
||||
// options. It sets up model display, usage tracking for supported providers, and
|
||||
// shows initial loading information. Returns nil in quiet mode or an initialized
|
||||
@@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
|
||||
// Set up usage tracking for supported providers
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
}
|
||||
if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
|
||||
+84
-116
@@ -1208,53 +1208,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.modelSelector = nil
|
||||
m.state = stateInput
|
||||
if m.setModel != nil {
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName)
|
||||
if fallback != models.ThinkingOff {
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.",
|
||||
modelName, currentLevel, fallback,
|
||||
))
|
||||
m.thinkingLevel = string(fallback)
|
||||
if m.setThinkingLevel != nil {
|
||||
_ = m.setThinkingLevel(string(fallback))
|
||||
}
|
||||
go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.setModel(msg.ModelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
} else {
|
||||
// Update display state directly — we cannot use
|
||||
// NotifyModelChanged (prog.Send) from inside Update()
|
||||
// without deadlocking BubbleTea.
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(msg.ModelString) }()
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
newModel := msg.ModelString
|
||||
prev := previousModel
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
}
|
||||
m.switchModel(msg.ModelString)
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
|
||||
@@ -4211,11 +4165,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
m.switchModel(args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// switchModel performs a direct model switch, shared by the model selector
|
||||
// overlay and the /model slash command: it adjusts the thinking level when
|
||||
// the new model doesn't support the current one, calls the setModel
|
||||
// callback, updates display state, persists preferences, and emits the
|
||||
// ModelChange extension event.
|
||||
//
|
||||
// Display state is updated directly — we cannot use NotifyModelChanged
|
||||
// (prog.Send) from inside Update() without deadlocking BubbleTea.
|
||||
func (m *AppModel) switchModel(modelString string) {
|
||||
if m.setModel == nil {
|
||||
m.printSystemMessage("Model switching is not available.")
|
||||
return
|
||||
}
|
||||
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
@@ -4235,32 +4209,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
if err := m.setModel(args); err != nil {
|
||||
if err := m.setModel(modelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Update display state directly (cannot use prog.Send from Update).
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
prev := previousModel
|
||||
newModel := args
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString))
|
||||
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(args) }()
|
||||
go func() { _ = prefs.SaveModelPreference(modelString) }()
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
|
||||
return nil
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
go emit(modelString, previousModel, "user")
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -4827,61 +4795,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
return r
|
||||
}, name)
|
||||
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
tmpPath, err := buildShareFile(name, data, sysPromptJSON)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err))
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err))
|
||||
return nil
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the header.
|
||||
// The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
|
||||
if len(lines) > 0 {
|
||||
// Write header (first line)
|
||||
if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write system prompt entry
|
||||
if _, err := tmpFile.Write(sysPromptJSON); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
|
||||
return nil
|
||||
}
|
||||
if _, err := tmpFile.WriteString("\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write remaining lines
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = tmpFile.Close()
|
||||
|
||||
m.printSystemMessage("Uploading session to GitHub Gist...")
|
||||
|
||||
@@ -4907,6 +4825,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// buildShareFile assembles a temp JSONL file containing the session data
|
||||
// with the system-prompt entry inserted after the header line. On success
|
||||
// the caller owns the returned file and must remove it when done; on error
|
||||
// any partially-written temp file has already been cleaned up.
|
||||
func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) {
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath = tmpFile.Name()
|
||||
defer func() {
|
||||
_ = tmpFile.Close()
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the
|
||||
// header. The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.Write(sysPromptJSON); err != nil {
|
||||
return "", fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.WriteString("\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// handleImportCommand imports a session from a JSONL file.
|
||||
// Usage: /import path.jsonl
|
||||
func (m *AppModel) handleImportCommand(args string) tea.Cmd {
|
||||
|
||||
+2
-2
@@ -11,12 +11,12 @@ import (
|
||||
// treeManagerAdapter adapts TreeManager to SessionManager interface.
|
||||
// This is unexported - users don't interact with it directly.
|
||||
type treeManagerAdapter struct {
|
||||
inner *session.TreeManager
|
||||
inner *TreeManager
|
||||
}
|
||||
|
||||
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
|
||||
// This is used by the SDK when no custom SessionManager is provided.
|
||||
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
|
||||
func NewTreeManagerAdapter(tm *TreeManager) SessionManager {
|
||||
return &treeManagerAdapter{inner: tm}
|
||||
}
|
||||
|
||||
|
||||
+11
-22
@@ -3,6 +3,8 @@ package kit
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -103,34 +105,21 @@ type Event interface {
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
//
|
||||
// These constants re-export the canonical classification used by extension
|
||||
// events, so SDK events and extension events always agree.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
ToolKindExecute = extensions.ToolKindExecute // Shell execution (bash)
|
||||
ToolKindEdit = extensions.ToolKindEdit // File modification (edit, write)
|
||||
ToolKindRead = extensions.ToolKindRead // File reading (read, ls)
|
||||
ToolKindSearch = extensions.ToolKindSearch // Content/file search (grep, find)
|
||||
ToolKindSubagent = extensions.ToolKindSubagent // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind. MCP and extension
|
||||
// tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools.
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
return extensions.ToolKindFor(toolName)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON-encoded tool args string into a map.
|
||||
|
||||
@@ -20,3 +20,9 @@ func (m *Kit) ConfigFloatForTest(key string) float64 { return m.v.GetFloat64(key
|
||||
// ConfigBoolForTest returns the bool value of key from this Kit's isolated
|
||||
// configuration store.
|
||||
func (m *Kit) ConfigBoolForTest(key string) bool { return m.v.GetBool(key) }
|
||||
|
||||
// ConfigStringSliceForTest returns the string slice value of key from this
|
||||
// Kit's isolated configuration store.
|
||||
func (m *Kit) ConfigStringSliceForTest(key string) []string {
|
||||
return m.v.GetStringSlice(key)
|
||||
}
|
||||
|
||||
@@ -578,18 +578,13 @@ func llmUsageMeta(m *Kit, usage LLMUsage) (provider, modelID string, cost float6
|
||||
}
|
||||
|
||||
// isAnthropicOAuth reports whether the current Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token).
|
||||
// Mirrors the OAuth detection in cmd/extension_context.go's usage tracker
|
||||
// update path so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
|
||||
// to a stored OAuth token (in which case the user is not billed per-token),
|
||||
// so OnLLMUsage cost reporting agrees with ctx.GetSessionUsage().
|
||||
func isAnthropicOAuth(m *Kit, provider string) bool {
|
||||
if m == nil || provider != "anthropic" {
|
||||
return false
|
||||
}
|
||||
_, source, err := auth.GetAnthropicAPIKey(m.v.GetString("provider-api-key"))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.HasPrefix(source, "stored OAuth")
|
||||
return auth.IsAnthropicOAuth(m.v.GetString("provider-api-key"))
|
||||
}
|
||||
|
||||
// llmToContextMessages converts a slice of LLM messages to extension
|
||||
|
||||
+19
-3
@@ -1160,7 +1160,7 @@ type CLIOptions struct {
|
||||
// - Continue: resume most recent session for SessionDir (or cwd)
|
||||
// - SessionPath: open a specific JSONL session file
|
||||
// - default: create a new tree session for SessionDir (or cwd)
|
||||
func InitTreeSession(opts *Options) (*session.TreeManager, error) {
|
||||
func InitTreeSession(opts *Options) (*TreeManager, error) {
|
||||
if opts == nil {
|
||||
opts = &Options{}
|
||||
}
|
||||
@@ -1330,9 +1330,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
if !opts.NoSkills {
|
||||
// Merge viper config with opts: CLI flag / config file values are
|
||||
// already bound to viper by cmd/root.go, so v.GetBool("no-skills"),
|
||||
// v.GetStringSlice("skill"), and v.GetString("skills-dir") capture
|
||||
// both --flag and .kit.yml keys transparently.
|
||||
noSkills := opts.NoSkills || v.GetBool("no-skills")
|
||||
skillPaths := opts.Skills
|
||||
if len(skillPaths) == 0 {
|
||||
skillPaths = v.GetStringSlice("skill")
|
||||
}
|
||||
skillsDir := opts.SkillsDir
|
||||
if skillsDir == "" {
|
||||
skillsDir = v.GetString("skills-dir")
|
||||
}
|
||||
if !noSkills {
|
||||
mergedOpts := *opts
|
||||
mergedOpts.Skills = skillPaths
|
||||
mergedOpts.SkillsDir = skillsDir
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
loadedSkills, err = loadSkills(&mergedOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
|
||||
@@ -365,6 +365,81 @@ func TestNewSystemPromptFilePath(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewWithSkillsOptions verifies that the three skills-related Options
|
||||
// fields (NoSkills, Skills, SkillsDir) are wired correctly into kit.New().
|
||||
func TestNewWithSkillsOptions(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("NoSkills disables skill loading", func(t *testing.T) {
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
NoSkills: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := host.GetSkills(); len(got) != 0 {
|
||||
t.Errorf("NoSkills=true: expected 0 skills, got %d", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SkillsDir propagates", func(t *testing.T) {
|
||||
// Use a non-existent dir — no skills will load but the option must be
|
||||
// accepted without error and result in zero skills.
|
||||
dir := t.TempDir()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
SkillsDir: dir,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Empty dir → no skills; the important thing is no error.
|
||||
_ = host.GetSkills()
|
||||
})
|
||||
|
||||
t.Run("explicit Skills paths load correctly", func(t *testing.T) {
|
||||
// Write a minimal skill file to a temp dir.
|
||||
dir := t.TempDir()
|
||||
skillFile := dir + "/my-skill.md"
|
||||
content := "---\nname: test-skill\ndescription: A test skill\n---\nDo the thing.\n"
|
||||
if err := os.WriteFile(skillFile, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("failed to write skill file: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
Skills: []string{skillFile},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
skills := host.GetSkills()
|
||||
if len(skills) != 1 {
|
||||
t.Fatalf("expected 1 skill, got %d", len(skills))
|
||||
}
|
||||
if skills[0].Name != "test-skill" {
|
||||
t.Errorf("skill name = %q; want %q", skills[0].Name, "test-skill")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewSystemPromptInline confirms that inline system-prompt strings still
|
||||
// flow through unchanged after the file-path resolution change.
|
||||
func TestNewSystemPromptInline(t *testing.T) {
|
||||
|
||||
+27
-71
@@ -7,45 +7,36 @@ import (
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/prompts"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template Parsing Bridge for Extensions (Phase 3)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// varRegex matches {{variable}} placeholders in templates.
|
||||
var varRegex = regexp.MustCompile(`\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\}\}`)
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
// ParseTemplate extracts {{variables}} from template content. The template
|
||||
// grammar is shared with skill prompt templates, so a template parses
|
||||
// identically regardless of which API loads it.
|
||||
func ParseTemplate(name, content string) extensions.PromptTemplate {
|
||||
matches := varRegex.FindAllStringSubmatch(content, -1)
|
||||
vars := make([]string, 0, len(matches))
|
||||
seen := make(map[string]bool)
|
||||
for _, m := range matches {
|
||||
if len(m) > 1 && !seen[m[1]] {
|
||||
seen[m[1]] = true
|
||||
vars = append(vars, m[1])
|
||||
}
|
||||
tpl := skills.NewPromptTemplate(name, content)
|
||||
vars := tpl.Variables
|
||||
if vars == nil {
|
||||
vars = []string{}
|
||||
}
|
||||
return extensions.PromptTemplate{
|
||||
Name: name,
|
||||
Content: content,
|
||||
Name: tpl.Name,
|
||||
Content: tpl.Content,
|
||||
Variables: vars,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders.
|
||||
// Handles {{name}} and {{ name }} (any whitespace) placeholders; missing
|
||||
// variables are left as-is.
|
||||
func RenderTemplate(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return varRegex.ReplaceAllStringFunc(tpl.Content, func(m string) string {
|
||||
sub := varRegex.FindStringSubmatch(m)
|
||||
if len(sub) > 1 {
|
||||
if v, ok := vars[sub[1]]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return m
|
||||
})
|
||||
t := skills.PromptTemplate{Content: tpl.Content}
|
||||
return t.Expand(vars)
|
||||
}
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
@@ -183,44 +174,12 @@ func SimpleParseArguments(input string, count int) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// parseFields splits input respecting quoted strings.
|
||||
// parseFields splits input into arguments respecting quoted strings and
|
||||
// backslash escaping. It delegates to the canonical tokenizer in
|
||||
// internal/prompts so extension argument parsing and builtin prompt-template
|
||||
// parsing agree on grammar.
|
||||
func parseFields(input string) []string {
|
||||
var fields []string
|
||||
var current strings.Builder
|
||||
inQuote := false
|
||||
quoteChar := rune(0)
|
||||
|
||||
for _, r := range input {
|
||||
switch r {
|
||||
case '"', '\'':
|
||||
if !inQuote {
|
||||
inQuote = true
|
||||
quoteChar = r
|
||||
} else if r == quoteChar {
|
||||
inQuote = false
|
||||
quoteChar = 0
|
||||
} else {
|
||||
current.WriteRune(r)
|
||||
}
|
||||
case ' ', '\t':
|
||||
if inQuote {
|
||||
current.WriteRune(r)
|
||||
} else {
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
current.Reset()
|
||||
}
|
||||
}
|
||||
default:
|
||||
current.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
fields = append(fields, current.String())
|
||||
}
|
||||
|
||||
return fields
|
||||
return prompts.ParseCommandArgs(input)
|
||||
}
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
@@ -417,21 +376,18 @@ func MatchModelGlob(model, pattern string) bool {
|
||||
}
|
||||
|
||||
// ExtractProviderFromPath extracts provider from a path-like model string.
|
||||
//
|
||||
// Deprecated: Use GetCurrentProvider instead.
|
||||
func ExtractProviderFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
return GetCurrentProvider(model)
|
||||
}
|
||||
|
||||
// ExtractModelFromPath extracts model ID from a path-like model string.
|
||||
//
|
||||
// Deprecated: Use RemoveProviderFromModel instead, which correctly handles
|
||||
// model IDs containing "/" (e.g. "openrouter/meta/llama").
|
||||
func ExtractModelFromPath(model string) string {
|
||||
parts := strings.Split(model, "/")
|
||||
if len(parts) >= 2 {
|
||||
return parts[1]
|
||||
}
|
||||
return model
|
||||
return RemoveProviderFromModel(model)
|
||||
}
|
||||
|
||||
// IsBareModelID checks if a string is a bare model ID (no provider).
|
||||
|
||||
@@ -205,6 +205,131 @@ func TestNewZeroOptionsKeepsStreamingDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestSkillsViperKeys verifies that the three skills config keys (no-skills,
|
||||
// skill, skills-dir) flow through viper when set via a config file, matching
|
||||
// the pattern used by no-extensions and no-core-tools. This test does not
|
||||
// require an API key because it only exercises Options struct plumbing.
|
||||
func TestSkillsViperKeys(t *testing.T) {
|
||||
t.Run("NoSkills option disables skill loading", func(t *testing.T) {
|
||||
o := &kit.Options{}
|
||||
o.NoSkills = true
|
||||
if !o.NoSkills {
|
||||
t.Error("Options.NoSkills = true not reflected on struct")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Skills paths set on Options", func(t *testing.T) {
|
||||
o := &kit.Options{
|
||||
Skills: []string{"/a/skill.md", "/b/skill.md"},
|
||||
}
|
||||
if len(o.Skills) != 2 {
|
||||
t.Errorf("Options.Skills: got %d paths, want 2", len(o.Skills))
|
||||
}
|
||||
if o.Skills[0] != "/a/skill.md" {
|
||||
t.Errorf("Options.Skills[0] = %q; want %q", o.Skills[0], "/a/skill.md")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SkillsDir set on Options", func(t *testing.T) {
|
||||
o := &kit.Options{
|
||||
SkillsDir: "/custom/skills",
|
||||
}
|
||||
if o.SkillsDir != "/custom/skills" {
|
||||
t.Errorf("Options.SkillsDir = %q; want %q", o.SkillsDir, "/custom/skills")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSkillsConfigFileKeys verifies that no-skills, skill, and skills-dir
|
||||
// config file keys are read via viper and applied correctly. Requires an API
|
||||
// key because kit.New() is called to exercise the full config-load path.
|
||||
func TestSkillsConfigFileKeys(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("no-skills config key disables skill loading", func(t *testing.T) {
|
||||
// Write a config file with no-skills: true.
|
||||
cfgFile := t.TempDir() + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte("no-skills: true\n"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if got := host.GetSkills(); len(got) != 0 {
|
||||
t.Errorf("no-skills:true in config: expected 0 skills, got %d", len(got))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skill config key loads explicit skill files", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
skillFile := dir + "/cfg-skill.md"
|
||||
if err := os.WriteFile(skillFile, []byte("---\nname: cfg-skill\ndescription: from config\n---\nContent.\n"), 0o644); err != nil {
|
||||
t.Fatalf("failed to write skill file: %v", err)
|
||||
}
|
||||
|
||||
cfgContent := "skill:\n - " + skillFile + "\n"
|
||||
cfgFile := dir + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
skills := host.GetSkills()
|
||||
if len(skills) != 1 {
|
||||
t.Fatalf("expected 1 skill from config, got %d", len(skills))
|
||||
}
|
||||
if skills[0].Name != "cfg-skill" {
|
||||
t.Errorf("skill name = %q; want %q", skills[0].Name, "cfg-skill")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("skills-dir config key overrides auto-discovery root", func(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
cfgContent := "skills-dir: " + dir + "\n"
|
||||
cfgFile := dir + "/.kit.yml"
|
||||
if err := os.WriteFile(cfgFile, []byte(cfgContent), 0o644); err != nil {
|
||||
t.Fatalf("failed to write config: %v", err)
|
||||
}
|
||||
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
ConfigFile: cfgFile,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("kit.New failed: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Empty dir → 0 skills; the key point is no error during init.
|
||||
_ = host.GetSkills()
|
||||
})
|
||||
}
|
||||
|
||||
// TestNewStreamingExplicitOptOut verifies that a raw Options can still disable
|
||||
// streaming by setting Streaming to a pointer to false.
|
||||
func TestNewStreamingExplicitOptOut(t *testing.T) {
|
||||
|
||||
@@ -56,6 +56,26 @@ kit install --all # Install all extensions without prompting
|
||||
kit skill # Install the Kit extensions skill via skills.sh
|
||||
```
|
||||
|
||||
### Skills CLI flags
|
||||
|
||||
Control which skills are loaded at startup:
|
||||
|
||||
```bash
|
||||
# Load a specific skill file
|
||||
kit --skill path/to/skill.md "prompt"
|
||||
|
||||
# Load multiple skill files or directories (flag is repeatable)
|
||||
kit --skill ./skill1.md --skill ./skill2.md "prompt"
|
||||
|
||||
# Load all skills from a custom directory instead of the default locations
|
||||
kit --skills-dir /path/to/skills "prompt"
|
||||
|
||||
# Disable all skill loading (auto-discovery and explicit)
|
||||
kit --no-skills "prompt"
|
||||
```
|
||||
|
||||
Skills are auto-discovered from `~/.config/kit/skills/`, `.kit/skills/`, and `.agents/skills/` by default. Use `--skills-dir` to override the project-local search root, or `--skill` to load files explicitly (which disables auto-discovery). `--no-skills` suppresses all skill loading regardless of other flags.
|
||||
|
||||
## Interactive slash commands
|
||||
|
||||
These commands are available inside the Kit TUI during an interactive session:
|
||||
|
||||
@@ -48,6 +48,14 @@ These flags control Kit's behavior. When a prompt is passed as a positional argu
|
||||
| `--prompt-template` | — | — | Load a specific prompt template by name |
|
||||
| `--no-prompt-templates` | — | `false` | Disable prompt template loading |
|
||||
|
||||
## Skills
|
||||
|
||||
| Flag | Short | Default | Description |
|
||||
|------|-------|---------|-------------|
|
||||
| `--skill` | — | — | Load skill file or directory (repeatable) |
|
||||
| `--skills-dir` | — | — | Override the project-local skills directory for auto-discovery |
|
||||
| `--no-skills` | — | `false` | Disable skill loading (auto-discovery and explicit) |
|
||||
|
||||
## Generation parameters
|
||||
|
||||
| Flag | Short | Default | Description |
|
||||
|
||||
@@ -47,6 +47,9 @@ stream: true
|
||||
| `theme` | object or string | — | UI theme ([inline overrides or file path](/themes)) |
|
||||
| `prompt-templates` | bool | `true` | Enable prompt template loading |
|
||||
| `prompt-template` | string | — | Specific template to load by name |
|
||||
| `no-skills` | bool | `false` | Disable skill loading (auto-discovery and explicit) |
|
||||
| `skill` | list | — | Explicit skill files or directories to load (disables auto-discovery) |
|
||||
| `skills-dir` | string | — | Override the project-local directory used for skill auto-discovery |
|
||||
|
||||
## Environment variables
|
||||
|
||||
|
||||
Reference in New Issue
Block a user