mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aecce001ee | |||
| 32d73171fd | |||
| 265fd2ec0c | |||
| efebf2eba6 | |||
| f7b655ae33 | |||
| 35982b41ad | |||
| 788e3b71fd | |||
| 3496bc2684 | |||
| 997c7d15ff | |||
| 83246e47d5 | |||
| 50e7b78c33 | |||
| b937af3056 | |||
| a5e995c750 | |||
| e95e08a699 | |||
| bcaf92f62a | |||
| ead4afbfe6 | |||
| 685aaf207f | |||
| 76ff6c9639 | |||
| 1cf24ee5de | |||
| c9637090fa | |||
| 0ff0ff42ab | |||
| a4fb32ff2b | |||
| 7d2f078111 | |||
| b0b66941ab | |||
| cbb7387a72 |
@@ -28,11 +28,15 @@ type lintResult struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint on Go file edits")
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
@@ -43,30 +47,72 @@ func Init(api ext.API) {
|
||||
return nil
|
||||
}
|
||||
|
||||
report := runGoDiagnostics(ctx.CWD, absPath)
|
||||
|
||||
// Check if there are issues and add explicit prompt for the LLM to react
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
var enhanced string
|
||||
if hasIssues {
|
||||
enhanced = e.Content + "\n\n" + report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review the issues above and fix them before proceeding."
|
||||
} else {
|
||||
enhanced = e.Content + "\n\n" + report
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Show TUI message block for diagnostics visibility (only if there are issues)
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("File: %s", filepath.Base(absPath)))
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
msgLines = append(msgLines, "", "⚠️ Please fix these issues before proceeding.")
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
@@ -78,9 +124,16 @@ func Init(api ext.API) {
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
|
||||
return &ext.ToolResultResult{Content: &enhanced}
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -106,18 +159,6 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGoDiagnostics(cwd, absPath string) string {
|
||||
gopls := runGopls(cwd, absPath)
|
||||
lint := runGolangCILint(cwd, "./...")
|
||||
|
||||
return fmt.Sprintf(
|
||||
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
filepath.Base(absPath),
|
||||
formatToolResult(gopls, "No diagnostics."),
|
||||
formatToolResult(lint, "No lint issues."),
|
||||
)
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
@@ -178,7 +219,9 @@ func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
lines = append(lines, emptyFallback)
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
@@ -197,17 +240,15 @@ func truncate(s string, max int) string {
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
// Extract gopls section
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
// Count non-empty lines excluding the header and "No diagnostics." message
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." {
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
@@ -215,7 +256,6 @@ func countIssues(report string) (goplsCount, lintCount int) {
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
// Count non-empty lines excluding the header and "No lint issues." message
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
|
||||
+118
-13
@@ -154,6 +154,9 @@ func InitConfig() {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
// Rebuild the model registry now that viper has the config loaded,
|
||||
// so customModels defined in the config file are picked up.
|
||||
models.ReloadGlobalRegistry()
|
||||
}
|
||||
|
||||
// LoadConfigWithEnvSubstitution loads a config file with environment variable
|
||||
@@ -714,13 +717,20 @@ func runNormalMode(ctx context.Context) error {
|
||||
|
||||
// Build Kit options from CLI flags and create the SDK instance.
|
||||
// kit.New() handles: config → skills → agent → session → extension bridge.
|
||||
authHandler, authErr := kit.NewCLIMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote MCP servers.
|
||||
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
|
||||
}
|
||||
|
||||
kitOpts := &kit.Options{
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
Quiet: quietFlag,
|
||||
Debug: debugMode,
|
||||
NoSession: noSessionFlag,
|
||||
Continue: continueFlag,
|
||||
SessionPath: sessionPath,
|
||||
AutoCompact: autoCompactFlag,
|
||||
MCPAuthHandler: authHandler,
|
||||
CLI: &kit.CLIOptions{
|
||||
MCPConfig: mcpConfig,
|
||||
ShowSpinner: true,
|
||||
@@ -793,6 +803,13 @@ func runNormalMode(ctx context.Context) error {
|
||||
appInstance := app.New(appOpts, messages)
|
||||
defer appInstance.Close()
|
||||
|
||||
// Wire OAuth handler to route messages through the TUI once it's running.
|
||||
if authHandler != nil {
|
||||
authHandler.NotifyFunc = func(serverName, message string) {
|
||||
appInstance.PrintFromExtension("info", message)
|
||||
}
|
||||
}
|
||||
|
||||
// Buffer for extension messages during startup (printed after startup banner).
|
||||
var startupExtensionMessages []string
|
||||
|
||||
@@ -816,7 +833,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
@@ -1237,7 +1284,37 @@ func runNormalMode(ctx context.Context) error {
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
@@ -1605,9 +1682,36 @@ func runNormalMode(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build extension reload callback for the /reload-ext command.
|
||||
reloadExtensionsForUI := func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start file watcher for automatic extension hot-reload.
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
watchDirs := extensions.WatchedDirs(extraPaths)
|
||||
if len(watchDirs) > 0 {
|
||||
extWatcher, watchErr := extensions.NewWatcher(watchDirs, func() {
|
||||
if err := reloadExtensionsForUI(); err != nil {
|
||||
log.Printf("auto-reload extensions failed: %v", err)
|
||||
}
|
||||
})
|
||||
if watchErr != nil {
|
||||
log.Printf("extension file watcher not started: %v", watchErr)
|
||||
} else {
|
||||
go extWatcher.Start(ctx)
|
||||
defer func() { _ = extWatcher.Close() }()
|
||||
}
|
||||
}
|
||||
|
||||
// Check if running in non-interactive mode
|
||||
if positionalPrompt != "" {
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
|
||||
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
|
||||
}
|
||||
|
||||
// Quiet mode is not allowed in interactive mode
|
||||
@@ -1615,7 +1719,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
return fmt.Errorf("--quiet requires a prompt")
|
||||
}
|
||||
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
|
||||
}
|
||||
|
||||
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
|
||||
@@ -1628,7 +1732,7 @@ func runNormalMode(ctx context.Context) error {
|
||||
//
|
||||
// When --no-exit is set, after the prompt completes the interactive BubbleTea
|
||||
// TUI is started so the user can continue the conversation.
|
||||
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
|
||||
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
|
||||
// Expand @file references in the prompt before sending to the agent.
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
prompt = ui.ProcessFileAttachments(prompt, cwd)
|
||||
@@ -1671,7 +1775,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
|
||||
|
||||
// If --no-exit was requested, hand off to the interactive TUI.
|
||||
if noExit {
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
|
||||
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1769,7 +1873,7 @@ func writeJSONError(err error) {
|
||||
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
|
||||
//
|
||||
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
|
||||
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
|
||||
// Determine terminal size; fall back gracefully.
|
||||
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
|
||||
if err != nil || termWidth == 0 {
|
||||
@@ -1813,6 +1917,7 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
IsReasoningModel: isReasoningModel,
|
||||
SetThinkingLevel: setThinkingLevel,
|
||||
SwitchSession: switchSession,
|
||||
ReloadExtensions: reloadExtensions,
|
||||
ShowSessionPicker: resumeFlag,
|
||||
})
|
||||
|
||||
|
||||
@@ -7,10 +7,12 @@
|
||||
// development: edit your extension source, then type /reload to pick up
|
||||
// changes immediately.
|
||||
//
|
||||
// Event handlers, slash commands, tool renderers, message renderers, and
|
||||
// keyboard shortcuts update immediately. Extension-defined tools are NOT
|
||||
// updated (they are baked into the agent at creation time and require a
|
||||
// restart).
|
||||
// Note: Extensions in autoloaded directories (~/.config/kit/extensions/
|
||||
// and .kit/extensions/) are automatically reloaded on save. The /reload
|
||||
// command is useful for extensions loaded via -e from other locations.
|
||||
//
|
||||
// Event handlers, slash commands, tool definitions, tool renderers,
|
||||
// message renderers, and keyboard shortcuts all update immediately.
|
||||
//
|
||||
// Commands:
|
||||
// /reload — hot-reload all extensions from disk
|
||||
|
||||
@@ -168,6 +168,10 @@ var (
|
||||
// Test
|
||||
pendingTest *PendingTest
|
||||
|
||||
// Typing indicator
|
||||
typingTicker *time.Ticker
|
||||
typingStop chan struct{}
|
||||
|
||||
// Latest context for background goroutines
|
||||
latestCtx ext.Context
|
||||
latestCtxSet bool
|
||||
@@ -203,8 +207,23 @@ func configDir() string {
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func configPath() string {
|
||||
return filepath.Join(configDir(), "kit-telegram.json")
|
||||
// Prefer project-local config, fall back to global config.
|
||||
local := filepath.Join(configDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
return local
|
||||
}
|
||||
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(global); err == nil {
|
||||
return global
|
||||
}
|
||||
// Neither exists — return local path (will be created on connect).
|
||||
return local
|
||||
}
|
||||
|
||||
func failureLogDir() string {
|
||||
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func tgSendChatAction(token string, chatID int64, action string) error {
|
||||
_, err := telegramRequest(token, "sendChatAction", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"action": action,
|
||||
}, 15)
|
||||
return err
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Error classification
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -637,6 +664,48 @@ func clearHealthTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Typing indicator
|
||||
// ──────────────────────────────────────────────
|
||||
|
||||
func startTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
return
|
||||
}
|
||||
cfg := config
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return
|
||||
}
|
||||
token := cfg.BotToken
|
||||
chatID := cfg.ChatID
|
||||
typingTicker = time.NewTicker(4 * time.Second)
|
||||
typingStop = make(chan struct{})
|
||||
// Send immediately, then every 4 seconds.
|
||||
go func() {
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
for {
|
||||
select {
|
||||
case <-typingTicker.C:
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
case <-typingStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
typingTicker.Stop()
|
||||
close(typingStop)
|
||||
typingTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Polling lifecycle
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
sendShutdownDisconnectedMessage()
|
||||
stopTypingLoop()
|
||||
stopPolling()
|
||||
clearHealthTimer()
|
||||
clearFooter()
|
||||
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
report("run.start", fmt.Sprintf("runId=%d", run.ID))
|
||||
startTypingLoop()
|
||||
ensureProgressMessage()
|
||||
updateProgressMessage()
|
||||
})
|
||||
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
|
||||
run := activeRun
|
||||
mu.Unlock()
|
||||
|
||||
stopTypingLoop()
|
||||
|
||||
if run != nil {
|
||||
// Capture final response from event
|
||||
if e.Response != "" {
|
||||
|
||||
+66
-4
@@ -25,6 +25,11 @@ type AgentConfig struct {
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
@@ -63,6 +68,10 @@ type ToolCallContentHandler func(content string)
|
||||
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
|
||||
type ReasoningDeltaHandler func(delta string)
|
||||
|
||||
// ReasoningCompleteHandler is a function type for handling reasoning/thinking completion.
|
||||
// Called when the last reasoning token has been processed, before text streaming starts.
|
||||
type ReasoningCompleteHandler func()
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
@@ -135,6 +144,10 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
toolManager = tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
@@ -231,7 +244,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
@@ -242,6 +255,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
@@ -295,6 +309,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning/thinking complete callback
|
||||
OnReasoningEnd: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text streaming callback
|
||||
OnTextDelta: func(id, text string) error {
|
||||
if ctx.Err() != nil {
|
||||
@@ -381,7 +406,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
@@ -398,9 +423,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
@@ -623,6 +648,43 @@ func (a *Agent) GetExtensionToolCount() int {
|
||||
return len(a.extraTools)
|
||||
}
|
||||
|
||||
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
|
||||
// tools) and rebuilds the internal agent with the updated tool list. The
|
||||
// model, system prompt, and all other configuration are preserved.
|
||||
func (a *Agent) SetExtraTools(tools []fantasy.AgentTool) {
|
||||
a.extraTools = tools
|
||||
|
||||
// Rebuild tool list (same as NewAgent / SetModel).
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild agent options with the existing model.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
|
||||
}
|
||||
if a.maxSteps > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithStopConditions(
|
||||
fantasy.StepCountIs(a.maxSteps),
|
||||
))
|
||||
}
|
||||
|
||||
// Swap the fantasy agent (model and provider are unchanged).
|
||||
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
|
||||
@@ -36,6 +36,8 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
@@ -56,6 +58,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
|
||||
+15
-4
@@ -1,6 +1,17 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// SteerMessage carries a steering prompt and optional file attachments
|
||||
// (e.g. clipboard images) through the steer channel.
|
||||
type SteerMessage struct {
|
||||
Text string
|
||||
Files []fantasy.FilePart
|
||||
}
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
@@ -11,7 +22,7 @@ type steerConsumedKey struct{}
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan SteerMessage) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
@@ -23,8 +34,8 @@ func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.C
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
func steerChFromContext(ctx context.Context) <-chan SteerMessage {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan SteerMessage)
|
||||
return ch
|
||||
}
|
||||
|
||||
|
||||
+106
-5
@@ -162,6 +162,24 @@ func (a *App) CancelCurrentStep() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// IsBusy returns true when the agent is currently processing a turn.
|
||||
func (a *App) IsBusy() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.busy
|
||||
}
|
||||
|
||||
// Abort cancels the current agent step (if running) and clears the queue.
|
||||
// Unlike InterruptAndSend, no new message is injected — the agent simply
|
||||
// stops. Safe to call when idle (no-op).
|
||||
func (a *App) Abort() {
|
||||
a.mu.Lock()
|
||||
a.queue = a.queue[:0]
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// QueueLength returns the number of prompts currently waiting in the queue.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
@@ -187,6 +205,15 @@ func (a *App) QueueLength() int {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
return a.SteerWithFiles(prompt, nil)
|
||||
}
|
||||
|
||||
// SteerWithFiles injects a steering message with optional file attachments
|
||||
// (e.g. pasted images) into the currently running agent turn. Behaves like
|
||||
// Steer but includes file parts alongside the text.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -195,8 +222,8 @@ func (a *App) Steer(prompt string) int {
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
// Not busy — start immediately, same as RunWithFiles().
|
||||
item := queueItem{Prompt: prompt, Files: files}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
@@ -211,7 +238,7 @@ func (a *App) Steer(prompt string) int {
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
a.opts.Kit.InjectSteerWithFiles(prompt, files)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -390,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("cannot compact while the agent is working")
|
||||
}
|
||||
if a.opts.Kit == nil {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("SDK instance not available")
|
||||
}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
if onError != nil {
|
||||
onError(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
|
||||
if onError != nil {
|
||||
onError("nothing to compact")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -530,8 +629,8 @@ func (a *App) drainQueue(first queueItem) {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
@@ -788,6 +887,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
sendFn(StreamChunkEvent{Content: ev.Chunk})
|
||||
case kit.ReasoningDeltaEvent:
|
||||
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
|
||||
case kit.ReasoningCompleteEvent:
|
||||
sendFn(ReasoningCompleteEvent{})
|
||||
case kit.ToolOutputEvent:
|
||||
sendFn(ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
|
||||
@@ -16,6 +16,11 @@ type ReasoningChunkEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
// ReasoningCompleteEvent is sent when reasoning/thinking is finished, after
|
||||
// the last reasoning token has been processed. The TUI uses this to freeze
|
||||
// the reasoning duration counter.
|
||||
type ReasoningCompleteEvent struct{}
|
||||
|
||||
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
|
||||
// It carries the tool name and its arguments for display purposes.
|
||||
type ToolCallStartedEvent struct {
|
||||
|
||||
@@ -162,6 +162,8 @@ type Theme struct {
|
||||
// and merged into the custom provider in the model registry.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
|
||||
+1
-20
@@ -67,7 +67,7 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return readDirectory(absPath)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("'%s' is a directory, not a file. Use the ls tool to list directory contents.", args.Path)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
@@ -116,25 +116,6 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
func readDirectory(absPath string) (fantasy.ToolResponse, error) {
|
||||
entries, err := os.ReadDir(absPath)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to read directory: %v", err)), nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() {
|
||||
name += "/"
|
||||
}
|
||||
result.WriteString(name + "\n")
|
||||
}
|
||||
|
||||
tr := truncateHead(result.String(), 500, defaultMaxBytes)
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
// resolvePathWithWorkDir resolves a path to an absolute path relative to the
|
||||
// given workDir. If workDir is empty, os.Getwd() is used.
|
||||
func resolvePathWithWorkDir(path, workDir string) (string, error) {
|
||||
|
||||
+26
-33
@@ -130,13 +130,22 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
), fmt.Errorf("no subagent spawner in context")
|
||||
}
|
||||
|
||||
// Detach from the parent's deadline so the subagent gets its own
|
||||
// independent timeout (applied downstream in Kit.Subagent). The parent
|
||||
// context may carry a tight deadline from the LLM generation loop or
|
||||
// other tool timeouts that would prematurely kill the subagent.
|
||||
// We preserve context values (spawner, etc.) and propagate parent
|
||||
// cancellation (e.g. user hits Ctrl-C) without inheriting the deadline.
|
||||
spawnCtx := detachedWithCancel(ctx)
|
||||
// Build a clean context for the subagent that inherits values (e.g. the
|
||||
// spawner callback) but is completely detached from the parent's
|
||||
// deadline AND cancellation. The subagent gets its own independent
|
||||
// timeout (applied downstream in Kit.Subagent).
|
||||
//
|
||||
// Why full detachment instead of propagating parent cancellation?
|
||||
// The parent context may already be done (deadline exceeded or
|
||||
// cancelled) by the time this tool handler executes — for example when
|
||||
// the generation loop context carries a deadline, when the user
|
||||
// double-ESC cancels mid-turn, or when parallel tool execution
|
||||
// encounters a race between stream completion and tool dispatch. Using
|
||||
// context.WithoutCancel (Go 1.21+) ensures the subagent always starts
|
||||
// cleanly with a fresh timeout, following the pattern used by crush for
|
||||
// shutdown-resilient child work. The subagent's own timeout
|
||||
// (defaultSubagentTimeout / user-specified) provides the safety net.
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: ctx})
|
||||
|
||||
// Spawn in-process subagent.
|
||||
result, err := spawner(spawnCtx, call.ID, args.Task, args.Model, args.SystemPrompt, timeout)
|
||||
@@ -173,37 +182,21 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context detachment
|
||||
// Context helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// detachedContext wraps a parent context, preserving its values but removing
|
||||
// its deadline and cancellation. This allows the subagent to have its own
|
||||
// independent timeout while still accessing context-stored values (e.g. the
|
||||
// subagent spawner function).
|
||||
type detachedContext struct {
|
||||
// valuesContext preserves a parent context's values (e.g. the subagent
|
||||
// spawner callback) while stripping its deadline and cancellation. Combined
|
||||
// with context.WithoutCancel() this gives the subagent a completely clean
|
||||
// context that only inherits value-based dependencies.
|
||||
type valuesContext struct {
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (d detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (d detachedContext) Done() <-chan struct{} { return nil }
|
||||
func (d detachedContext) Err() error { return nil }
|
||||
func (d detachedContext) Value(key any) any { return d.parent.Value(key) }
|
||||
|
||||
// detachedWithCancel creates a new context that inherits values from the
|
||||
// parent but has no deadline. Cancellation of the parent is propagated: when
|
||||
// the parent is cancelled the returned context is also cancelled, but the
|
||||
// parent's deadline does not apply to the child.
|
||||
func detachedWithCancel(parent context.Context) context.Context {
|
||||
child, cancel := context.WithCancel(detachedContext{parent: parent})
|
||||
go func() {
|
||||
select {
|
||||
case <-parent.Done():
|
||||
cancel()
|
||||
case <-child.Done():
|
||||
}
|
||||
}()
|
||||
return child
|
||||
}
|
||||
func (v valuesContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (v valuesContext) Done() <-chan struct{} { return nil }
|
||||
func (v valuesContext) Err() error { return nil }
|
||||
func (v valuesContext) Value(key any) any { return v.parent.Value(key) }
|
||||
|
||||
// truncateResponse limits the response length to avoid overwhelming context windows.
|
||||
func truncateResponse(s string, maxLen int) string {
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValuesContext_StripsDeadlineAndCancellation(t *testing.T) {
|
||||
// Parent with a tight deadline.
|
||||
parent, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(5 * time.Millisecond) // Let deadline expire.
|
||||
|
||||
if parent.Err() == nil {
|
||||
t.Fatal("expected parent to be expired")
|
||||
}
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
if _, ok := vc.Deadline(); ok {
|
||||
t.Error("valuesContext should report no deadline")
|
||||
}
|
||||
if vc.Done() != nil {
|
||||
t.Error("valuesContext.Done() should return nil")
|
||||
}
|
||||
if vc.Err() != nil {
|
||||
t.Errorf("valuesContext.Err() should be nil, got %v", vc.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesContext_PreservesValues(t *testing.T) {
|
||||
type testKey struct{}
|
||||
parent := context.WithValue(context.Background(), testKey{}, "hello")
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
got, ok := vc.Value(testKey{}).(string)
|
||||
if !ok || got != "hello" {
|
||||
t.Errorf("expected value 'hello', got %q (ok=%v)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesCancelledParent(t *testing.T) {
|
||||
// Simulate the exact scenario from the bug: the parent generation
|
||||
// context is already cancelled when the subagent tool handler runs.
|
||||
parent, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancelled before detach.
|
||||
|
||||
// This is what executeSubagent now does:
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
// The spawn context must be alive.
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive, got err: %v", spawnCtx.Err())
|
||||
}
|
||||
|
||||
// Adding a timeout should produce a working context.
|
||||
tCtx, tCancel := context.WithTimeout(spawnCtx, 5*time.Second)
|
||||
defer tCancel()
|
||||
|
||||
if tCtx.Err() != nil {
|
||||
t.Fatalf("timeout context should be alive, got err: %v", tCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesDeadlineExceededParent(t *testing.T) {
|
||||
// Simulate: parent had a deadline that already expired.
|
||||
parent, pCancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer pCancel()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if parent.Err() != context.DeadlineExceeded {
|
||||
t.Fatalf("expected parent deadline exceeded, got: %v", parent.Err())
|
||||
}
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive after deadline-exceeded parent, got: %v", spawnCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_PreservesSpawnerValue(t *testing.T) {
|
||||
// Verify the subagent spawner callback survives context detachment.
|
||||
called := false
|
||||
spawner := SubagentSpawnFunc(func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error) {
|
||||
called = true
|
||||
return &SubagentSpawnResult{Response: "ok"}, nil
|
||||
})
|
||||
|
||||
parent := WithSubagentSpawner(context.Background(), spawner)
|
||||
// Cancel the parent.
|
||||
parentCtx, cancel := context.WithCancel(parent)
|
||||
cancel()
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parentCtx})
|
||||
|
||||
// Should be able to retrieve the spawner from the detached context.
|
||||
recovered := getSubagentSpawner(spawnCtx)
|
||||
if recovered == nil {
|
||||
t.Fatal("spawner should be recoverable from detached context")
|
||||
}
|
||||
|
||||
result, err := recovered(spawnCtx, "tc1", "test task", "", "", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("spawner call failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("spawner was not called")
|
||||
}
|
||||
if result.Response != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", result.Response)
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,64 @@ type Context struct {
|
||||
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
|
||||
CancelAndSend func(string)
|
||||
|
||||
// Abort cancels the current agent turn (if running) and clears the
|
||||
// message queue. Unlike CancelAndSend, no new message is injected —
|
||||
// the agent simply stops. Safe to call when idle (no-op).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.Abort() // stop whatever the agent is doing
|
||||
Abort func()
|
||||
|
||||
// IsIdle returns true when the agent is not processing a turn.
|
||||
// Extensions can use this to decide whether to dispatch immediately
|
||||
// or queue work for later.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if ctx.IsIdle() {
|
||||
// ctx.SendMessage("start new task")
|
||||
// }
|
||||
IsIdle func() bool
|
||||
|
||||
// Compact triggers context compaction, summarising older messages to
|
||||
// free context window space. Returns an error if compaction cannot
|
||||
// start (e.g. agent is busy or app is closed). The actual compaction
|
||||
// runs asynchronously; use OnComplete/OnError callbacks in
|
||||
// CompactConfig to observe the result.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.Compact(ext.CompactConfig{
|
||||
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
|
||||
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
|
||||
// })
|
||||
Compact func(CompactConfig) error
|
||||
|
||||
// SendMultimodalMessage injects a message with file attachments (images,
|
||||
// documents) into the conversation and triggers a new agent turn. Files
|
||||
// are described by FilePart structs containing the raw bytes, filename,
|
||||
// and MIME type. If the agent is busy the message is queued.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := os.ReadFile("photo.jpg")
|
||||
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
|
||||
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
|
||||
// })
|
||||
SendMultimodalMessage func(text string, files []FilePart)
|
||||
|
||||
// GetSessionUsage returns aggregated token usage and cost statistics
|
||||
// for the current session. This includes total input/output tokens,
|
||||
// cache read/write tokens, total cost, and request count.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// usage := ctx.GetSessionUsage()
|
||||
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
|
||||
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
|
||||
GetSessionUsage func() SessionUsage
|
||||
|
||||
// SetWidget places or updates a persistent widget in the TUI. Widgets
|
||||
// remain visible across agent turns until explicitly removed. The
|
||||
// widget is identified by WidgetConfig.ID; calling SetWidget with the
|
||||
@@ -937,6 +995,48 @@ type StatusBarEntry struct {
|
||||
Priority int
|
||||
}
|
||||
|
||||
// CompactConfig configures a programmatic context compaction request.
|
||||
type CompactConfig struct {
|
||||
// CustomInstructions is optional text appended to the summary prompt
|
||||
// (e.g. "Focus on the API design decisions"). Empty uses the default.
|
||||
CustomInstructions string
|
||||
// OnComplete is called when compaction finishes successfully.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnComplete func()
|
||||
// OnError is called when compaction fails. The argument is the error message.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnError func(errMsg string)
|
||||
}
|
||||
|
||||
// FilePart describes a file attachment for multimodal messages. Extensions
|
||||
// use this with SendMultimodalMessage to attach images or documents.
|
||||
type FilePart struct {
|
||||
// Filename is the name of the file (e.g. "photo.jpg").
|
||||
Filename string
|
||||
// Data is the raw file content.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// SessionUsage contains aggregated token usage and cost statistics for
|
||||
// the current session. Extensions use this with GetSessionUsage() to
|
||||
// report usage information.
|
||||
type SessionUsage struct {
|
||||
// TotalInputTokens is the sum of input tokens across all requests.
|
||||
TotalInputTokens int
|
||||
// TotalOutputTokens is the sum of output tokens across all requests.
|
||||
TotalOutputTokens int
|
||||
// TotalCacheReadTokens is the sum of cache read tokens.
|
||||
TotalCacheReadTokens int
|
||||
// TotalCacheWriteTokens is the sum of cache write tokens.
|
||||
TotalCacheWriteTokens int
|
||||
// TotalCost is the total cost in USD across all requests.
|
||||
TotalCost float64
|
||||
// RequestCount is the number of LLM requests made in this session.
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
|
||||
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
return i.install(source, scope, nil)
|
||||
}
|
||||
|
||||
// install is the internal implementation that supports optional include paths.
|
||||
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Include: includePaths,
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
// Update manifest timestamp, preserving existing fields like Include
|
||||
existing, _ := i.loadManifest(scope)
|
||||
var include []string
|
||||
var installed time.Time
|
||||
if existing != nil {
|
||||
for _, p := range existing.Packages {
|
||||
if p.Host+"/"+p.Path == source.Identity() {
|
||||
include = p.Include
|
||||
installed = p.Installed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if installed.IsZero() {
|
||||
installed = time.Now()
|
||||
}
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Installed: installed,
|
||||
Updated: time.Now(),
|
||||
Include: include,
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return i.install(source, scope, includePaths)
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
|
||||
@@ -133,7 +133,7 @@ func findExtensionsInDir(dir string) []string {
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
@@ -190,9 +190,13 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
@@ -202,13 +206,6 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -227,7 +224,7 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -86,6 +86,21 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.Abort == nil {
|
||||
ctx.Abort = func() {}
|
||||
}
|
||||
if ctx.IsIdle == nil {
|
||||
ctx.IsIdle = func() bool { return true }
|
||||
}
|
||||
if ctx.Compact == nil {
|
||||
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
|
||||
}
|
||||
if ctx.SendMultimodalMessage == nil {
|
||||
ctx.SendMultimodalMessage = func(string, []FilePart) {}
|
||||
}
|
||||
if ctx.GetSessionUsage == nil {
|
||||
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
|
||||
// Session types
|
||||
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
|
||||
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
|
||||
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
|
||||
|
||||
// Option types
|
||||
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
|
||||
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
|
||||
// LLM completion types
|
||||
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
|
||||
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
|
||||
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
|
||||
"FilePart": reflect.ValueOf((*FilePart)(nil)),
|
||||
|
||||
// Status bar types
|
||||
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Debug("watcher: skipping directory", "dir", dir, "err", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Debug("watcher: skipping subdirectory", "dir", subdir, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("watcher: file changed", "file", event.Name, "op", event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Debug("watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Warn("watcher: error", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
// This includes the global extensions directory and the project-local
|
||||
// .kit/extensions/ directory (if they exist). Explicit -e paths that
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatcher_ReloadsOnGoFileChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write an initial extension file.
|
||||
extFile := filepath.Join(dir, "test.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Modify the file.
|
||||
time.Sleep(50 * time.Millisecond) // let watcher settle
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for debounce (300ms) + margin.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_IgnoresNonGoFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Write a non-.go file.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
txtFile := filepath.Join(dir, "notes.txt")
|
||||
if err := os.WriteFile(txtFile, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait past the debounce window.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 0 {
|
||||
t.Errorf("expected 0 reloads for .txt file, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_Debounces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
extFile := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Rapid-fire writes (simulating editor save: write temp, rename, etc.).
|
||||
for range 5 {
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for debounce to fire.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 debounced reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_Deduplicates(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dirs := WatchedDirs([]string{dir, dir})
|
||||
|
||||
count := 0
|
||||
for _, d := range dirs {
|
||||
abs, _ := filepath.Abs(dir)
|
||||
if d == abs {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected directory to appear once, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_FileParent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(file, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dirs := WatchedDirs([]string{file})
|
||||
|
||||
abs, _ := filepath.Abs(dir)
|
||||
found := false
|
||||
for _, d := range dirs {
|
||||
if d == abs {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected parent dir %s in watched dirs %v", abs, dirs)
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,27 @@ type AgentSetupOptions struct {
|
||||
// wrapping. Used by the SDK hook system. Both wrappers compose:
|
||||
// extension wrapper runs first (inner), then this wrapper (outer).
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
Debug bool
|
||||
// NoExtensions skips extension loading. When false, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
NoExtensions bool
|
||||
// MaxSteps overrides the agent step limit. 0 means use viper value.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
|
||||
// is also set.
|
||||
StreamingEnabled bool
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -88,15 +109,36 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
// SetupAgent creates an agent from the current viper state + the provided
|
||||
// options. It wraps BuildProviderConfig and agent.CreateAgent.
|
||||
func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, error) {
|
||||
modelConfig, systemPrompt, err := BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
var bufferedLogger *tools.BufferedDebugLogger
|
||||
if viper.GetBool("debug") {
|
||||
if debugEnabled {
|
||||
if opts.UseBufferedLogger {
|
||||
bufferedLogger = tools.NewBufferedDebugLogger(true)
|
||||
debugLogger = bufferedLogger
|
||||
@@ -108,7 +150,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
// Load extensions unless --no-extensions is set.
|
||||
var extRunner *extensions.Runner
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !viper.GetBool("no-extensions") {
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
if extErr != nil {
|
||||
@@ -140,12 +182,13 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streamingEnabled,
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
|
||||
@@ -37,6 +37,8 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Attachment: cfg.Attachment,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
Cost: Cost{
|
||||
Input: cfg.Cost.Input,
|
||||
Output: cfg.Cost.Output,
|
||||
@@ -52,6 +54,8 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
|
||||
+21
-134
@@ -10,7 +10,6 @@ import (
|
||||
"maps"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -525,13 +524,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
|
||||
return new(openai.ReasoningEffortLow)
|
||||
case ThinkingMedium:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
|
||||
return new(openai.ReasoningEffortMedium)
|
||||
case ThinkingHigh:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
|
||||
return new(openai.ReasoningEffortHigh)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -1000,139 +999,29 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// thinkTagRegex matches <think>...</think> tags for extracting reasoning content
|
||||
// from models that wrap thinking in XML-like tags (e.g., Qwen, DeepSeek).
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)<think>(.*?)</think>`)
|
||||
|
||||
// customExtraContentFunc extracts reasoning from <think> tags in the content field.
|
||||
// This handles models like Qwen and DeepSeek that return reasoning wrapped in XML tags
|
||||
// rather than using a separate reasoning_content field.
|
||||
func customExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
|
||||
var content []fantasy.Content
|
||||
if choice.Message.Content == "" {
|
||||
return content
|
||||
}
|
||||
|
||||
// Check for <think> tags in the content
|
||||
matches := thinkTagRegex.FindStringSubmatch(choice.Message.Content)
|
||||
if len(matches) > 1 {
|
||||
// Found reasoning content in <think> tags
|
||||
reasoning := strings.TrimSpace(matches[1])
|
||||
if reasoning != "" {
|
||||
content = append(content, fantasy.ReasoningContent{
|
||||
Text: reasoning,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return content
|
||||
}
|
||||
|
||||
// customStreamExtraFunc handles streaming responses with <think> tags.
|
||||
// It extracts reasoning content and emits proper reasoning events.
|
||||
func customStreamExtraFunc(
|
||||
chunk openaisdk.ChatCompletionChunk,
|
||||
yield func(fantasy.StreamPart) bool,
|
||||
ctx map[string]any,
|
||||
) (map[string]any, bool) {
|
||||
if len(chunk.Choices) == 0 {
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
const reasoningStartedKey = "reasoning_started"
|
||||
const reasoningBufferKey = "reasoning_buffer"
|
||||
const inThinkTagKey = "in_think_tag"
|
||||
|
||||
reasoningStarted, _ := ctx[reasoningStartedKey].(bool)
|
||||
inThinkTag, _ := ctx[inThinkTagKey].(bool)
|
||||
reasoningBuffer, _ := ctx[reasoningBufferKey].(string)
|
||||
|
||||
for i, choice := range chunk.Choices {
|
||||
content := choice.Delta.Content
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for <think> tag start
|
||||
if strings.Contains(content, "<think>") {
|
||||
inThinkTag = true
|
||||
ctx[inThinkTagKey] = true
|
||||
|
||||
// Emit reasoning start event
|
||||
if !reasoningStarted {
|
||||
reasoningStarted = true
|
||||
ctx[reasoningStartedKey] = true
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningStart,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
}
|
||||
|
||||
// Extract content after <think>
|
||||
parts := strings.SplitN(content, "<think>", 2)
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
reasoningBuffer += parts[1]
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for </think> tag end
|
||||
if strings.Contains(content, "</think>") {
|
||||
inThinkTag = false
|
||||
ctx[inThinkTagKey] = false
|
||||
|
||||
// Extract content before </think>
|
||||
parts := strings.SplitN(content, "</think>", 2)
|
||||
if len(parts) > 0 {
|
||||
reasoningBuffer += parts[0]
|
||||
}
|
||||
|
||||
// Emit the accumulated reasoning
|
||||
if reasoningBuffer != "" {
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningDelta,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
Delta: reasoningBuffer,
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
ctx[reasoningBufferKey] = ""
|
||||
}
|
||||
|
||||
// Emit reasoning end
|
||||
if !yield(fantasy.StreamPart{
|
||||
Type: fantasy.StreamPartTypeReasoningEnd,
|
||||
ID: fmt.Sprintf("%d", i),
|
||||
}) {
|
||||
return ctx, false
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Accumulate reasoning content while in think tag
|
||||
if inThinkTag {
|
||||
reasoningBuffer += content
|
||||
ctx[reasoningBufferKey] = reasoningBuffer
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, true
|
||||
}
|
||||
|
||||
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
|
||||
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
|
||||
return openai.DefaultToPrompt(prompt, systemPrompt, user)
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
// Resolve base URL: per-model override > global provider-url flag/config
|
||||
registry := GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel("custom", modelName)
|
||||
|
||||
baseURL := config.ProviderURL
|
||||
if modelInfo != nil && modelInfo.BaseURL != "" {
|
||||
baseURL = modelInfo.BaseURL
|
||||
}
|
||||
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url or a baseUrl in the model config")
|
||||
}
|
||||
|
||||
apiKey := config.ProviderAPIKey
|
||||
if modelInfo != nil && modelInfo.APIKey != "" {
|
||||
apiKey = modelInfo.APIKey
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("CUSTOM_API_KEY")
|
||||
}
|
||||
@@ -1141,15 +1030,13 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
// Use the openai provider directly with custom hooks to handle <think> tags
|
||||
// from models like Qwen and DeepSeek that wrap reasoning in XML tags.
|
||||
// <think> tag extraction is handled transparently at the agent layer,
|
||||
// so no provider-level hooks are needed here.
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithName("custom"))
|
||||
opts = append(opts, openai.WithLanguageModelOptions(
|
||||
openai.WithLanguageModelExtraContentFunc(customExtraContentFunc),
|
||||
openai.WithLanguageModelStreamExtraFunc(customStreamExtraFunc),
|
||||
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
|
||||
))
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ type ModelInfo struct {
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
@@ -367,8 +369,8 @@ func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
// Ollama and custom are always supported (model names are user-defined).
|
||||
if providerID == "ollama" || providerID == "custom" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -400,6 +402,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
return &info
|
||||
}
|
||||
|
||||
// ValidateModelString checks whether a model string is well-formed and refers
|
||||
// to a known provider. It returns a user-friendly error with suggestions when
|
||||
// the model or provider is unrecognised. Passing validation does not guarantee
|
||||
// that API authentication will succeed — it only catches obvious mistakes
|
||||
// (typos, missing provider prefix, non-existent provider names) early so that
|
||||
// callers such as subagent spawning can return fast feedback.
|
||||
//
|
||||
// Unknown models under a known provider are allowed (the provider API is the
|
||||
// authority), but a completely unknown provider is rejected.
|
||||
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ollama and custom are always valid — model names are user-defined.
|
||||
if provider == "ollama" || provider == "custom" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the provider exists in the registry.
|
||||
providerInfo := r.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
known := r.GetSupportedProviders()
|
||||
return fmt.Errorf(
|
||||
"unknown provider %q in model string %q. Known providers: %s",
|
||||
provider, modelString, strings.Join(known, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// Provider exists — check if the model is known. An unknown model is
|
||||
// only a warning (the provider API decides), but we surface suggestions
|
||||
// so the caller can self-correct.
|
||||
if r.LookupModel(provider, modelName) == nil {
|
||||
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
return fmt.Errorf(
|
||||
"model %q not found for provider %s. Did you mean one of: %s",
|
||||
modelName, provider, strings.Join(suggestions, ", "),
|
||||
)
|
||||
}
|
||||
// No suggestions — let it through; the provider API is the authority.
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateModelString(t *testing.T) {
|
||||
registry := GetGlobalRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantErr bool
|
||||
errSubstr string // expected substring in error message (empty = don't check)
|
||||
}{
|
||||
{
|
||||
name: "valid anthropic model",
|
||||
model: "anthropic/claude-sonnet-4-6",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider prefix",
|
||||
model: "claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown provider",
|
||||
model: "fakeprovider/some-model",
|
||||
wantErr: true,
|
||||
errSubstr: "unknown provider",
|
||||
},
|
||||
{
|
||||
name: "ollama always valid",
|
||||
model: "ollama/llama3",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "custom always valid",
|
||||
model: "custom/my-fine-tune",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty provider",
|
||||
model: "/claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
model: "anthropic/",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown model under known provider (no suggestions)",
|
||||
model: "anthropic/totally-unknown-xyz-999",
|
||||
wantErr: false, // no suggestions → passes through
|
||||
},
|
||||
{
|
||||
name: "typo model under known provider with suggestions",
|
||||
model: "anthropic/claude-sonet", // misspelled "sonnet"
|
||||
wantErr: true,
|
||||
errSubstr: "Did you mean",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := registry.ValidateModelString(tt.model)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("ValidateModelString(%q) = nil, want error", tt.model)
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err)
|
||||
}
|
||||
if tt.errSubstr != "" && err != nil {
|
||||
if !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("ValidateModelString(%q) error = %q, want substring %q",
|
||||
tt.model, err.Error(), tt.errSubstr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -68,6 +68,7 @@ type MCPConnectionPool struct {
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
@@ -75,7 +76,7 @@ type MCPConnectionPool struct {
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
@@ -90,6 +91,10 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
|
||||
debug: debug,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
|
||||
}
|
||||
|
||||
go pool.startHealthCheck()
|
||||
return pool
|
||||
}
|
||||
@@ -103,6 +108,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
|
||||
// When set, the pool can trigger OAuth re-authorization when a tool call fails
|
||||
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oauthFlow = flow
|
||||
}
|
||||
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
@@ -230,18 +244,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
client, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry after successful auth
|
||||
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
@@ -323,13 +362,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sseClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start SSE client: %w", err)
|
||||
}
|
||||
|
||||
return sseClient, nil
|
||||
@@ -354,13 +409,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := streamableClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
|
||||
}
|
||||
|
||||
return streamableClient, nil
|
||||
@@ -381,7 +452,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
@@ -539,6 +610,9 @@ func (p *MCPConnectionPool) Close() error {
|
||||
|
||||
// isConnectionError checks if the error is connection-related
|
||||
func isConnectionError(err error) bool {
|
||||
if IsOAuthError(err) {
|
||||
return false // OAuth errors are recoverable, not connection failures
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "Connection not found") ||
|
||||
strings.Contains(errStr, "transport error") ||
|
||||
|
||||
@@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
|
||||
+10
-1
@@ -22,6 +22,7 @@ type MCPToolManager struct {
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
@@ -53,6 +54,14 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
|
||||
m.model = model
|
||||
}
|
||||
|
||||
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
|
||||
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
|
||||
// support, enabling automatic authorization flows when servers require authentication.
|
||||
// This method should be called before LoadTools.
|
||||
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
||||
m.authHandler = handler
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
@@ -76,7 +85,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) e
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(config.Debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug, m.authHandler)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
|
||||
var loadErrors []string
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
)
|
||||
|
||||
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
|
||||
// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go
|
||||
// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK.
|
||||
type MCPAuthHandler interface {
|
||||
// RedirectURI returns the OAuth redirect URI for transport setup.
|
||||
RedirectURI() string
|
||||
// HandleAuth is called when a server requires OAuth authorization.
|
||||
// It receives the server name and the authorization URL the user must visit.
|
||||
// It returns the full callback URL (containing code and state query params)
|
||||
// after the user completes authorization.
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
|
||||
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
|
||||
// registration, PKCE generation, user authorization (via MCPAuthHandler),
|
||||
// and token exchange.
|
||||
type OAuthFlowRunner struct {
|
||||
handler MCPAuthHandler
|
||||
}
|
||||
|
||||
// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler.
|
||||
func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner {
|
||||
return &OAuthFlowRunner{handler: handler}
|
||||
}
|
||||
|
||||
// RunAuthFlow executes the OAuth authorization flow for the given server.
|
||||
// It extracts the OAuthHandler from the error, performs dynamic client registration
|
||||
// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user
|
||||
// interaction, and exchanges the authorization code for a token.
|
||||
func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error {
|
||||
// Extract the OAuthHandler from the authorization-required error.
|
||||
oauthHandler := client.GetOAuthHandler(authErr)
|
||||
if oauthHandler == nil {
|
||||
return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr)
|
||||
}
|
||||
|
||||
// Perform dynamic client registration if no client ID is configured yet.
|
||||
if oauthHandler.GetClientID() == "" {
|
||||
if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil {
|
||||
return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge.
|
||||
codeVerifier, err := client.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err)
|
||||
}
|
||||
codeChallenge := client.GenerateCodeChallenge(codeVerifier)
|
||||
|
||||
// Generate a random state parameter for CSRF protection.
|
||||
state, err := client.GenerateState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to generate state: %w", err)
|
||||
}
|
||||
|
||||
// Build the authorization URL the user needs to visit.
|
||||
authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err)
|
||||
}
|
||||
|
||||
// Delegate to the MCPAuthHandler for user-facing authorization (e.g. open
|
||||
// browser, wait for redirect). It returns the full callback URL containing
|
||||
// the authorization code and state.
|
||||
callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: user authorization failed: %w", err)
|
||||
}
|
||||
|
||||
// Parse the callback URL to extract the authorization code and state.
|
||||
parsed, err := url.Parse(callbackURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err)
|
||||
}
|
||||
|
||||
code := parsed.Query().Get("code")
|
||||
returnedState := parsed.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'code' parameter")
|
||||
}
|
||||
if returnedState == "" {
|
||||
return fmt.Errorf("oauth flow: callback URL missing 'state' parameter")
|
||||
}
|
||||
|
||||
// Exchange the authorization code for an access token.
|
||||
if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil {
|
||||
return fmt.Errorf("oauth flow: token exchange failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError.
|
||||
func IsOAuthError(err error) bool {
|
||||
return client.IsOAuthAuthorizationRequiredError(err)
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// Compile-time check that FileTokenStore implements transport.TokenStore.
|
||||
var _ transport.TokenStore = (*FileTokenStore)(nil)
|
||||
|
||||
// FileTokenStore is a file-backed implementation of transport.TokenStore that
|
||||
// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file
|
||||
// keyed by server URL, allowing multiple MCP servers to maintain independent tokens.
|
||||
//
|
||||
// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back
|
||||
// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set.
|
||||
//
|
||||
// FileTokenStore is safe for concurrent use.
|
||||
type FileTokenStore struct {
|
||||
serverKey string
|
||||
filePath string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFileTokenStore creates a new FileTokenStore for the given server URL.
|
||||
// The serverKey is used as the map key in the shared token file, and should
|
||||
// typically be the MCP server's base URL.
|
||||
//
|
||||
// Returns an error if the token file path cannot be resolved.
|
||||
func NewFileTokenStore(serverKey string) (*FileTokenStore, error) {
|
||||
filePath, err := resolveTokenFilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolving token file path: %w", err)
|
||||
}
|
||||
|
||||
return &FileTokenStore{
|
||||
serverKey: serverKey,
|
||||
filePath: filePath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetToken returns the stored token for this store's server key.
|
||||
// Returns transport.ErrNoToken if no token exists for the server key or if
|
||||
// the token file does not yet exist.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
return nil, fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
|
||||
token, ok := tokens[s.serverKey]
|
||||
if !ok {
|
||||
return nil, transport.ErrNoToken
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// SaveToken persists the given token for this store's server key.
|
||||
// If the token file or its parent directories do not exist, they are created.
|
||||
// Existing tokens for other server keys are preserved.
|
||||
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
|
||||
func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
tokens, err := readTokenFile(s.filePath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("reading token file: %w", err)
|
||||
}
|
||||
if tokens == nil {
|
||||
tokens = make(map[string]*transport.Token)
|
||||
}
|
||||
|
||||
tokens[s.serverKey] = token
|
||||
|
||||
if err := writeTokenFile(s.filePath, tokens); err != nil {
|
||||
return fmt.Errorf("writing token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveTokenFilePath determines the path to the token file using
|
||||
// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/.
|
||||
func resolveTokenFilePath() (string, error) {
|
||||
configDir := os.Getenv("XDG_CONFIG_HOME")
|
||||
if configDir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determining user home directory: %w", err)
|
||||
}
|
||||
configDir = filepath.Join(home, ".config")
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil
|
||||
}
|
||||
|
||||
// readTokenFile reads and unmarshals the token file into a server-keyed map.
|
||||
// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist.
|
||||
func readTokenFile(path string) (map[string]*transport.Token, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tokens map[string]*transport.Token
|
||||
if err := json.Unmarshal(data, &tokens); err != nil {
|
||||
return nil, fmt.Errorf("unmarshaling token file: %w", err)
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// writeTokenFile marshals the token map and writes it to disk, creating
|
||||
// parent directories as needed. The file is written with 0600 permissions
|
||||
// to protect sensitive token data.
|
||||
func writeTokenFile(path string, tokens map[string]*transport.Token) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("creating token directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(tokens, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling tokens: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return fmt.Errorf("writing token file %s: %w", path, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -119,6 +119,12 @@ var SlashCommands = []SlashCommand{
|
||||
return matches
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "/reload-ext",
|
||||
Description: "Hot-reload all extensions from disk",
|
||||
Category: "System",
|
||||
Aliases: []string{"/re"},
|
||||
},
|
||||
{
|
||||
Name: "/quit",
|
||||
Description: "Exit the application",
|
||||
|
||||
@@ -334,6 +334,12 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
body = r.ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Wrap all tool errors in a herald Caution alert so the error text
|
||||
// renders inside a contained block instead of spilling into the layout.
|
||||
if isError && strings.TrimSpace(body) != "" {
|
||||
body = r.ty.Alert(herald.AlertCaution, body)
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := r.ty.Compose(
|
||||
headerLine,
|
||||
|
||||
+88
-8
@@ -118,6 +118,10 @@ type AppController interface {
|
||||
// message starts executing immediately. Returns 0 if started
|
||||
// immediately, >0 if injected/pending.
|
||||
Steer(prompt string) int
|
||||
// SteerWithFiles injects a steering message with optional file
|
||||
// attachments (e.g. pasted images) into the currently running agent
|
||||
// turn. Behaves like Steer but includes file parts alongside the text.
|
||||
SteerWithFiles(prompt string, files []kit.LLMFilePart) int
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
@@ -388,6 +392,11 @@ type AppModelOptions struct {
|
||||
// initialization. They are displayed in the ScrollList at startup.
|
||||
StartupExtensionMessages []string
|
||||
|
||||
// ReloadExtensions hot-reloads all extensions from disk. Called by
|
||||
// the /reload-ext command and the automatic file watcher. May be nil
|
||||
// if no extensions are loaded.
|
||||
ReloadExtensions func() error
|
||||
|
||||
// ThinkingLevel is the initial thinking level (e.g. "off", "medium").
|
||||
ThinkingLevel string
|
||||
// IsReasoningModel is true when the current model supports reasoning.
|
||||
@@ -563,6 +572,9 @@ type AppModel struct {
|
||||
// sessionSelector is the session picker overlay, active in stateSessionSelector.
|
||||
sessionSelector *SessionSelectorComponent
|
||||
|
||||
// reloadExtensions hot-reloads all extensions from disk. May be nil.
|
||||
reloadExtensions func() error
|
||||
|
||||
// switchSession opens a session by JSONL path, replacing the active session.
|
||||
// Wired from cmd/root.go.
|
||||
switchSession func(path string) error
|
||||
@@ -728,6 +740,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
m.isReasoningModel = opts.IsReasoningModel
|
||||
m.setThinkingLevel = opts.SetThinkingLevel
|
||||
m.switchSession = opts.SwitchSession
|
||||
m.reloadExtensions = opts.ReloadExtensions
|
||||
|
||||
// Store context/skills metadata and tool counts for startup display.
|
||||
m.contextPaths = opts.ContextPaths
|
||||
@@ -1246,10 +1259,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
text = strings.TrimSpace(ic.textarea.Value())
|
||||
}
|
||||
if text != "" {
|
||||
// Clear the input and push to history.
|
||||
// Clear the input, collect pending images, and push to history.
|
||||
var images []uicore.ImageAttachment
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.pushHistory(text)
|
||||
ic.textarea.SetValue("")
|
||||
images = ic.ClearPendingImages()
|
||||
}
|
||||
|
||||
// Preprocess @file references.
|
||||
@@ -1258,14 +1273,29 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
|
||||
}
|
||||
|
||||
// Convert image attachments to kit.LLMFilePart for the app layer.
|
||||
var fileParts []kit.LLMFilePart
|
||||
for _, img := range images {
|
||||
fileParts = append(fileParts, kit.LLMFilePart{
|
||||
Data: img.Data,
|
||||
MediaType: img.MediaType,
|
||||
})
|
||||
}
|
||||
|
||||
// Build display text (include image count if any).
|
||||
displayText := text
|
||||
if len(images) > 0 {
|
||||
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
|
||||
}
|
||||
|
||||
// Inject the steer message.
|
||||
sLen := m.appCtrl.Steer(processedText)
|
||||
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
|
||||
if sLen > 0 {
|
||||
m.steeringMessages = append(m.steeringMessages, text)
|
||||
m.steeringMessages = append(m.steeringMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
// Started immediately (agent was idle).
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, text)
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
@@ -1471,6 +1501,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Also update/create StreamingMessageItem in ScrollList for live display
|
||||
m.appendStreamingChunk("reasoning", msg.Delta)
|
||||
|
||||
case app.ReasoningCompleteEvent:
|
||||
// Forward to stream component to freeze reasoning duration
|
||||
if m.stream != nil {
|
||||
updated, cmd := m.stream.Update(msg)
|
||||
m.stream, _ = updated.(streamComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
// Mark the reasoning StreamingMessageItem as complete to freeze its counter
|
||||
if len(m.messages) > 0 {
|
||||
if streamMsg, ok := m.messages[len(m.messages)-1].(*StreamingMessageItem); ok && streamMsg.role == "reasoning" {
|
||||
streamMsg.MarkComplete()
|
||||
}
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
// Forward to stream component for display rendering
|
||||
if m.stream != nil {
|
||||
@@ -1826,6 +1871,13 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.printSystemMessage(msg.output)
|
||||
}
|
||||
|
||||
case extReloadResultMsg:
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
} else {
|
||||
m.printSystemMessage("Extensions reloaded.")
|
||||
}
|
||||
|
||||
case beforeSessionSwitchResultMsg:
|
||||
// Async before-session-switch hook completed. Proceed with the
|
||||
// session reset if the hook did not cancel.
|
||||
@@ -1904,10 +1956,7 @@ func (m *AppModel) View() tea.View {
|
||||
return m.treeSelector.View()
|
||||
}
|
||||
|
||||
// Model selector overlay replaces the normal layout.
|
||||
if m.state == stateModelSelector && m.modelSelector != nil {
|
||||
return m.modelSelector.View()
|
||||
}
|
||||
// Model selector is rendered as a centered overlay later (see below).
|
||||
|
||||
// Session selector overlay replaces the normal layout.
|
||||
if m.state == stateSessionSelector && m.sessionSelector != nil {
|
||||
@@ -2031,6 +2080,12 @@ func (m *AppModel) View() tea.View {
|
||||
}
|
||||
}
|
||||
|
||||
// Render model selector as centered overlay if active
|
||||
if m.state == stateModelSelector && m.modelSelector != nil {
|
||||
popupContent := m.modelSelector.RenderOverlay(m.width, m.height)
|
||||
finalContent = overlayContent(finalContent, popupContent, m.width, m.height)
|
||||
}
|
||||
|
||||
v := tea.NewView(finalContent)
|
||||
v.AltScreen = true
|
||||
v.MouseMode = tea.MouseModeCellMotion
|
||||
@@ -2472,6 +2527,8 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
|
||||
return m.handleThinkingCommand(args)
|
||||
case "/compact":
|
||||
return m.handleCompactCommand(args)
|
||||
case "/reload-ext":
|
||||
return m.handleReloadExtCommand()
|
||||
case "/clear":
|
||||
if m.appCtrl != nil {
|
||||
m.appCtrl.ClearMessages()
|
||||
@@ -2757,6 +2814,22 @@ func (m *AppModel) printResetUsage() {
|
||||
// the app controller rejects the request (busy, closed) it prints an error
|
||||
// instead. customInstructions is optional text appended to the summary
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
// handleReloadExtCommand reloads all extensions from disk asynchronously.
|
||||
// It returns a tea.Cmd to avoid calling prog.Send() from inside Update()
|
||||
// which would deadlock if any extension handler calls ctx.Print() during
|
||||
// SessionShutdown or SessionStart events.
|
||||
func (m *AppModel) handleReloadExtCommand() tea.Cmd {
|
||||
if m.reloadExtensions == nil {
|
||||
m.printSystemMessage("No extensions loaded.")
|
||||
return nil
|
||||
}
|
||||
reload := m.reloadExtensions
|
||||
return func() tea.Msg {
|
||||
err := reload()
|
||||
return extReloadResultMsg{err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AppModel) handleCompactCommand(customInstructions string) tea.Cmd {
|
||||
if m.appCtrl == nil {
|
||||
m.printSystemMessage("Compaction is not available.")
|
||||
@@ -3694,6 +3767,13 @@ type shareResultMsg struct {
|
||||
viewerURL string
|
||||
}
|
||||
|
||||
// extReloadResultMsg carries the result of an asynchronously executed
|
||||
// /reload-ext command. The reload runs async to avoid deadlocking the
|
||||
// TUI event loop (extension handlers may call prog.Send via ctx.Print).
|
||||
type extReloadResultMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// extensionCmdResultMsg carries the result of an asynchronously executed
|
||||
// extension slash command. Extension commands run async (via tea.Cmd) so they
|
||||
// can safely call blocking operations like ctx.PromptSelect().
|
||||
|
||||
+98
-287
@@ -5,12 +5,9 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// ModelEntry holds display metadata for a single model in the selector.
|
||||
@@ -30,16 +27,14 @@ type ModelSelectedMsg struct {
|
||||
// ModelSelectorCancelledMsg is sent when the user cancels the selector.
|
||||
type ModelSelectorCancelledMsg struct{}
|
||||
|
||||
// ModelSelectorComponent is a full-screen Bubble Tea component that displays
|
||||
// a filterable list of available models. It follows the same pattern as
|
||||
// TreeSelectorComponent: inline text search, scrolling list, and custom
|
||||
// messages for result delivery.
|
||||
// ModelSelectorComponent is a Bubble Tea component that displays a filterable
|
||||
// list of available models as a centered overlay popup. It delegates rendering
|
||||
// and keyboard navigation to PopupList and converts results into the
|
||||
// ModelSelectedMsg / ModelSelectorCancelledMsg messages expected by AppModel.
|
||||
type ModelSelectorComponent struct {
|
||||
allModels []ModelEntry // all available models (pre-sorted)
|
||||
filtered []ModelEntry // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
currentModel string // "provider/model" of the active model (for checkmark)
|
||||
popup *PopupList
|
||||
allModels []ModelEntry // kept for the custom filter callback
|
||||
currentModel string // "provider/model" of the active model
|
||||
width int
|
||||
height int
|
||||
active bool
|
||||
@@ -62,7 +57,22 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
continue
|
||||
}
|
||||
|
||||
// For the custom provider, skip the built-in "custom" stub when
|
||||
// user-defined models are present — the stub is a fallback for
|
||||
// --provider-url usage and would just clutter the list.
|
||||
userDefinedCustomModels := 0
|
||||
if providerID == "custom" {
|
||||
for modelID := range modelsMap {
|
||||
if modelID != "custom" {
|
||||
userDefinedCustomModels++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for modelID, info := range modelsMap {
|
||||
if providerID == "custom" && modelID == "custom" && userDefinedCustomModels > 0 {
|
||||
continue
|
||||
}
|
||||
allModels = append(allModels, ModelEntry{
|
||||
Provider: providerID,
|
||||
ModelID: modelID,
|
||||
@@ -81,24 +91,31 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
|
||||
return allModels[i].ModelID < allModels[j].ModelID
|
||||
})
|
||||
|
||||
ms := &ModelSelectorComponent{
|
||||
// Build PopupItems from model entries.
|
||||
items := make([]PopupItem, len(allModels))
|
||||
for i, m := range allModels {
|
||||
items[i] = PopupItem{
|
||||
Label: m.ModelID,
|
||||
Description: fmt.Sprintf("[%s]", m.Provider),
|
||||
Active: m.Provider+"/"+m.ModelID == currentModel,
|
||||
Meta: m,
|
||||
}
|
||||
}
|
||||
|
||||
popup := NewPopupList("Model Selector", items, width, height)
|
||||
popup.Subtitle = "Only showing models with configured API keys"
|
||||
popup.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
|
||||
return filterModels(query, allItems)
|
||||
}
|
||||
|
||||
return &ModelSelectorComponent{
|
||||
popup: popup,
|
||||
allModels: allModels,
|
||||
filtered: allModels,
|
||||
currentModel: currentModel,
|
||||
width: width,
|
||||
height: height,
|
||||
active: true,
|
||||
}
|
||||
|
||||
// Position cursor on the current model if found.
|
||||
for i, m := range ms.filtered {
|
||||
if m.Provider+"/"+m.ModelID == currentModel {
|
||||
ms.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return ms
|
||||
}
|
||||
|
||||
// Init implements tea.Model.
|
||||
@@ -112,236 +129,94 @@ func (ms *ModelSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
ms.width = msg.Width
|
||||
ms.height = msg.Height
|
||||
ms.popup.SetSize(msg.Width, msg.Height)
|
||||
return ms, nil
|
||||
|
||||
case tea.KeyPressMsg:
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
|
||||
if ms.cursor > 0 {
|
||||
ms.cursor--
|
||||
}
|
||||
result := ms.popup.HandleKey(msg.String(), msg.Text)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
|
||||
if ms.cursor < len(ms.filtered)-1 {
|
||||
ms.cursor++
|
||||
if result.Selected != nil {
|
||||
ms.active = false
|
||||
entry := result.Selected.Meta.(ModelEntry)
|
||||
modelStr := entry.Provider + "/" + entry.ModelID
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectedMsg{ModelString: modelStr}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
|
||||
ms.cursor -= ms.visibleHeight()
|
||||
if ms.cursor < 0 {
|
||||
ms.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
|
||||
ms.cursor += ms.visibleHeight()
|
||||
if ms.cursor >= len(ms.filtered) {
|
||||
ms.cursor = len(ms.filtered) - 1
|
||||
}
|
||||
if ms.cursor < 0 {
|
||||
ms.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
|
||||
ms.cursor = 0
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
|
||||
ms.cursor = max(len(ms.filtered)-1, 0)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if ms.cursor < len(ms.filtered) {
|
||||
entry := ms.filtered[ms.cursor]
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectedMsg{
|
||||
ModelString: entry.Provider + "/" + entry.ModelID,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
if ms.search != "" {
|
||||
ms.search = ""
|
||||
ms.rebuildFiltered()
|
||||
} else {
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Inline text search.
|
||||
if msg.Text != "" && len(msg.Text) == 1 {
|
||||
ch := msg.Text[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
ms.search += string(ch)
|
||||
ms.rebuildFiltered()
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ms.search) > 0 {
|
||||
ms.search = ms.search[:len(ms.search)-1]
|
||||
ms.rebuildFiltered()
|
||||
}
|
||||
if result.Cancelled {
|
||||
ms.active = false
|
||||
return ms, func() tea.Msg {
|
||||
return ModelSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// View implements tea.Model.
|
||||
// View implements tea.Model — not used for overlay rendering.
|
||||
// Use RenderOverlay for the centered overlay approach.
|
||||
func (ms *ModelSelectorComponent) View() tea.View {
|
||||
theme := style.GetTheme()
|
||||
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
PaddingLeft(2)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(2)
|
||||
|
||||
infoStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
PaddingLeft(2)
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Header.
|
||||
b.WriteString(headerStyle.Render("Model Selector"))
|
||||
b.WriteString("\n")
|
||||
// Adapt help text to terminal width.
|
||||
if ms.width >= 56 {
|
||||
b.WriteString(helpStyle.Render("↑/↓: move enter: select esc: cancel type to filter"))
|
||||
} else if ms.width >= 35 {
|
||||
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc type"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("↑↓ ↵ esc"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
if ms.width >= 48 {
|
||||
b.WriteString(infoStyle.Render("Only showing models with configured API keys"))
|
||||
} else {
|
||||
b.WriteString(infoStyle.Render("Models with API keys"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Search input.
|
||||
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(2)
|
||||
if ms.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ms.search)))
|
||||
} else {
|
||||
b.WriteString(searchStyle.Render("> "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
|
||||
b.WriteString("\n")
|
||||
|
||||
if len(ms.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
if ms.search != "" {
|
||||
b.WriteString(emptyStyle.Render("No models matching \"" + ms.search + "\""))
|
||||
} else {
|
||||
b.WriteString(emptyStyle.Render("No models available (check API keys)"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
// Visible window.
|
||||
visH := ms.visibleHeight()
|
||||
startIdx := 0
|
||||
if ms.cursor >= visH {
|
||||
startIdx = ms.cursor - visH + 1
|
||||
}
|
||||
endIdx := min(startIdx+visH, len(ms.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := ms.filtered[i]
|
||||
line := ms.renderEntry(entry, i == ms.cursor)
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer.
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
|
||||
b.WriteString("\n")
|
||||
|
||||
footerParts := []string{
|
||||
fmt.Sprintf("(%d/%d)", ms.cursor+1, len(ms.filtered)),
|
||||
}
|
||||
if ms.cursor < len(ms.filtered) {
|
||||
entry := ms.filtered[ms.cursor]
|
||||
if entry.Name != "" {
|
||||
footerParts = append(footerParts, fmt.Sprintf("Model Name: %s", entry.Name))
|
||||
}
|
||||
if entry.ContextLimit > 0 {
|
||||
footerParts = append(footerParts, fmt.Sprintf("Context: %dK", entry.ContextLimit/1000))
|
||||
}
|
||||
}
|
||||
|
||||
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(footerStyle.Render(strings.Join(footerParts, " ")))
|
||||
|
||||
v := tea.NewView(b.String())
|
||||
// Fallback full-screen rendering (unused when rendered as overlay).
|
||||
v := tea.NewView(ms.popup.RenderCentered(ms.width, ms.height))
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
|
||||
// RenderOverlay returns the popup as a centered overlay string, ready to be
|
||||
// composited on top of the main content via overlayContent().
|
||||
func (ms *ModelSelectorComponent) RenderOverlay(termWidth, termHeight int) string {
|
||||
return ms.popup.RenderCentered(termWidth, termHeight)
|
||||
}
|
||||
|
||||
// IsActive returns whether the selector is still accepting input.
|
||||
func (ms *ModelSelectorComponent) IsActive() bool {
|
||||
return ms.active
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
// --- Model-specific fuzzy filter ---
|
||||
|
||||
func (ms *ModelSelectorComponent) visibleHeight() int {
|
||||
// Reserve: header(1) + help(1) + info(1) + search(1) + separator(1) + footer(2) = 7.
|
||||
// Minimum 3 entries so the selector is still usable on short terminals.
|
||||
return max(ms.height-7, 3)
|
||||
}
|
||||
// filterModels scores and filters PopupItems whose Meta is a ModelEntry.
|
||||
func filterModels(query string, items []PopupItem) []PopupItem {
|
||||
if query == "" {
|
||||
return items
|
||||
}
|
||||
q := strings.ToLower(query)
|
||||
|
||||
func (ms *ModelSelectorComponent) rebuildFiltered() {
|
||||
if ms.search == "" {
|
||||
ms.filtered = ms.allModels
|
||||
} else {
|
||||
query := strings.ToLower(ms.search)
|
||||
ms.filtered = ms.filtered[:0]
|
||||
type scored struct {
|
||||
item PopupItem
|
||||
score int
|
||||
}
|
||||
var matches []scored
|
||||
|
||||
type scored struct {
|
||||
entry ModelEntry
|
||||
score int
|
||||
for _, item := range items {
|
||||
entry, ok := item.Meta.(ModelEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
var matches []scored
|
||||
|
||||
for _, entry := range ms.allModels {
|
||||
s := ms.fuzzyScoreModel(query, entry)
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{entry: entry, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending, then alphabetically.
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if matches[i].score != matches[j].score {
|
||||
return matches[i].score > matches[j].score
|
||||
}
|
||||
return matches[i].entry.ModelID < matches[j].entry.ModelID
|
||||
})
|
||||
|
||||
ms.filtered = make([]ModelEntry, len(matches))
|
||||
for i, m := range matches {
|
||||
ms.filtered[i] = m.entry
|
||||
s := fuzzyScoreModelEntry(q, entry)
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{item: item, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Clamp cursor.
|
||||
if ms.cursor >= len(ms.filtered) {
|
||||
ms.cursor = max(len(ms.filtered)-1, 0)
|
||||
sort.Slice(matches, func(i, j int) bool {
|
||||
if matches[i].score != matches[j].score {
|
||||
return matches[i].score > matches[j].score
|
||||
}
|
||||
a := matches[i].item.Meta.(ModelEntry)
|
||||
b := matches[j].item.Meta.(ModelEntry)
|
||||
return a.ModelID < b.ModelID
|
||||
})
|
||||
|
||||
result := make([]PopupItem, len(matches))
|
||||
for i, m := range matches {
|
||||
result[i] = m.item
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// fuzzyScoreModel scores a model entry against the search query.
|
||||
func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry) int {
|
||||
// fuzzyScoreModelEntry scores a model entry against the search query.
|
||||
func fuzzyScoreModelEntry(query string, entry ModelEntry) int {
|
||||
modelID := strings.ToLower(entry.ModelID)
|
||||
provider := strings.ToLower(entry.Provider)
|
||||
name := strings.ToLower(entry.Name)
|
||||
@@ -394,67 +269,3 @@ func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ms *ModelSelectorComponent) renderEntry(entry ModelEntry, isCursor bool) string {
|
||||
theme := style.GetTheme()
|
||||
modelStr := entry.ModelID
|
||||
providerStr := fmt.Sprintf("[%s]", entry.Provider)
|
||||
|
||||
// Cursor indicator.
|
||||
var cursor string
|
||||
if isCursor {
|
||||
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("-> ")
|
||||
} else {
|
||||
cursor = " "
|
||||
}
|
||||
|
||||
// Active model checkmark.
|
||||
var active string
|
||||
activeWidth := 0
|
||||
if entry.Provider+"/"+entry.ModelID == ms.currentModel {
|
||||
active = lipgloss.NewStyle().Foreground(theme.Success).Render(" \u2713")
|
||||
activeWidth = 2 // " ✓"
|
||||
}
|
||||
|
||||
// Truncate model ID and provider tag to fit terminal width.
|
||||
// Layout: cursor(3) + model + " " + provider + active.
|
||||
// Use rune length for display-width accuracy (the "…" suffix is 1 rune / 1 column).
|
||||
const cursorWidth = 3
|
||||
available := max(ms.width-cursorWidth-activeWidth-1, 10) // 1 for space between model and provider
|
||||
provDisplayLen := len([]rune(providerStr))
|
||||
modelDisplayLen := len([]rune(modelStr))
|
||||
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
// Prioritize model name — truncate it, but keep provider visible.
|
||||
maxModel := max(available-provDisplayLen-1, 6)
|
||||
if maxModel < modelDisplayLen {
|
||||
if maxModel > 3 {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel-1]) + "…"
|
||||
} else {
|
||||
runes := []rune(modelStr)
|
||||
modelStr = string(runes[:maxModel])
|
||||
}
|
||||
}
|
||||
// If provider itself is too long, drop it.
|
||||
modelDisplayLen = len([]rune(modelStr))
|
||||
if modelDisplayLen+1+provDisplayLen > available {
|
||||
providerStr = ""
|
||||
}
|
||||
}
|
||||
|
||||
// Style the model ID.
|
||||
modelStyle := lipgloss.NewStyle().Foreground(theme.Text)
|
||||
if isCursor {
|
||||
modelStyle = modelStyle.Bold(true).Foreground(theme.Accent)
|
||||
}
|
||||
|
||||
// Style the provider tag.
|
||||
providerStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
|
||||
result := cursor + modelStyle.Render(modelStr)
|
||||
if providerStr != "" {
|
||||
result += " " + providerStyle.Render(providerStr)
|
||||
}
|
||||
return result + active
|
||||
}
|
||||
|
||||
@@ -81,6 +81,11 @@ func (s *stubAppController) Steer(prompt string) int {
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) int {
|
||||
s.runCalls = append(s.runCalls, prompt)
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -668,6 +673,7 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
|
||||
}
|
||||
if bashItem == nil {
|
||||
t.Fatal("expected StreamingBashOutputItem in messages after ToolOutputEvent")
|
||||
return
|
||||
}
|
||||
if len(bashItem.stdoutLines) != 1 || bashItem.stdoutLines[0] != "line one\n" {
|
||||
t.Fatalf("expected stdout=['line one\\n'], got %v", bashItem.stdoutLines)
|
||||
|
||||
@@ -0,0 +1,501 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// PopupItem represents a single entry in a PopupList. The component renders
|
||||
// Label as the primary text and Description as secondary text to its right.
|
||||
// The Active flag renders a checkmark to indicate the currently-active item
|
||||
// (e.g. the current model). Meta is opaque caller data returned on selection.
|
||||
type PopupItem struct {
|
||||
Label string // primary display text
|
||||
Description string // secondary text (shown right of label)
|
||||
Active bool // true → render checkmark indicator
|
||||
Meta any // opaque data returned on selection
|
||||
}
|
||||
|
||||
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
|
||||
// rendered as a centered overlay on top of the normal TUI layout and can be
|
||||
// reused by any feature that needs a selection popup (slash commands, model
|
||||
// selector, session picker, extension-provided lists, etc.).
|
||||
//
|
||||
// The caller is responsible for:
|
||||
// - Building the initial item list
|
||||
// - Providing a fuzzy-filter callback (or nil for substring matching)
|
||||
// - Handling the result when the user selects or cancels
|
||||
//
|
||||
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
|
||||
type PopupList struct {
|
||||
// Title shown at the top of the popup.
|
||||
Title string
|
||||
// Subtitle shown below the title (dimmed).
|
||||
Subtitle string
|
||||
// FooterHint overrides the default keyboard-hint footer.
|
||||
FooterHint string
|
||||
|
||||
allItems []PopupItem // full unfiltered list
|
||||
filtered []PopupItem // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
|
||||
// FilterFunc is called with (query, allItems) and should return the
|
||||
// filtered+scored subset. When nil, a default substring match is used.
|
||||
FilterFunc func(query string, items []PopupItem) []PopupItem
|
||||
|
||||
width int
|
||||
height int
|
||||
maxVisible int // max items visible at once (0 = auto from height)
|
||||
showSearch bool
|
||||
}
|
||||
|
||||
// PopupResult is returned by HandleKey to tell the caller what happened.
|
||||
type PopupResult struct {
|
||||
// Selected is non-nil when the user pressed Enter on an item.
|
||||
Selected *PopupItem
|
||||
// Cancelled is true when the user pressed Esc with no search text.
|
||||
Cancelled bool
|
||||
// Changed is true when the search or cursor moved (caller should re-render).
|
||||
Changed bool
|
||||
}
|
||||
|
||||
// NewPopupList creates a new popup list with the given items and dimensions.
|
||||
func NewPopupList(title string, items []PopupItem, width, height int) *PopupList {
|
||||
p := &PopupList{
|
||||
Title: title,
|
||||
allItems: items,
|
||||
filtered: items,
|
||||
width: width,
|
||||
height: height,
|
||||
showSearch: true,
|
||||
}
|
||||
// Position cursor on the active item if one exists.
|
||||
for i, item := range p.filtered {
|
||||
if item.Active {
|
||||
p.cursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// SetSize updates the popup dimensions (e.g. on window resize).
|
||||
func (p *PopupList) SetSize(width, height int) {
|
||||
p.width = width
|
||||
p.height = height
|
||||
}
|
||||
|
||||
// visibleCount returns the number of items visible at once.
|
||||
func (p *PopupList) visibleCount() int {
|
||||
if p.maxVisible > 0 {
|
||||
return p.maxVisible
|
||||
}
|
||||
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
|
||||
overhead := 8
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.showSearch {
|
||||
overhead += 2 // search line + separator
|
||||
}
|
||||
return max(p.height/2-overhead, 3)
|
||||
}
|
||||
|
||||
// HandleKey processes a single key event and returns the result. The caller
|
||||
// should inspect PopupResult to decide whether to re-render, close the popup,
|
||||
// or act on a selection.
|
||||
//
|
||||
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
|
||||
// keyText is the printable text for character keys (e.g. "a", "1").
|
||||
func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
|
||||
switch keyName {
|
||||
case "up":
|
||||
if p.cursor > 0 {
|
||||
p.cursor--
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "down":
|
||||
if p.cursor < len(p.filtered)-1 {
|
||||
p.cursor++
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "pgup":
|
||||
p.cursor -= p.visibleCount()
|
||||
if p.cursor < 0 {
|
||||
p.cursor = 0
|
||||
}
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "pgdown":
|
||||
p.cursor += p.visibleCount()
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "home":
|
||||
p.cursor = 0
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "end":
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
return PopupResult{Changed: true}
|
||||
|
||||
case "enter":
|
||||
if p.cursor < len(p.filtered) {
|
||||
item := p.filtered[p.cursor]
|
||||
return PopupResult{Selected: &item}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
case "esc":
|
||||
if p.search != "" {
|
||||
p.search = ""
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{Cancelled: true}
|
||||
|
||||
case "backspace":
|
||||
if len(p.search) > 0 {
|
||||
p.search = p.search[:len(p.search)-1]
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
return PopupResult{}
|
||||
|
||||
default:
|
||||
// Printable character → append to search.
|
||||
if keyText != "" && len(keyText) == 1 {
|
||||
ch := keyText[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
p.search += string(ch)
|
||||
p.rebuildFiltered()
|
||||
return PopupResult{Changed: true}
|
||||
}
|
||||
}
|
||||
return PopupResult{}
|
||||
}
|
||||
}
|
||||
|
||||
// Render returns the styled popup content (bordered box) ready to be placed
|
||||
// as a centered overlay via lipgloss.Place + overlayContent.
|
||||
func (p *PopupList) Render() string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(min(p.width-4, 80), 20)
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginBottom(1)
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
// Title.
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
b.WriteString(titleStyle.Render(p.Title))
|
||||
b.WriteString("\n")
|
||||
|
||||
// Subtitle.
|
||||
if p.Subtitle != "" {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
b.WriteString(subtitleStyle.Render(p.Subtitle))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Search input.
|
||||
if p.showSearch {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
if p.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
|
||||
} else {
|
||||
b.WriteString(searchStyle.Render("> "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
|
||||
// Separator.
|
||||
sepStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg)
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Item list.
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
vis := p.visibleCount()
|
||||
var items []string
|
||||
|
||||
if len(p.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
if p.search != "" {
|
||||
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
|
||||
} else {
|
||||
items = append(items, emptyStyle.Render("No items"))
|
||||
}
|
||||
} else {
|
||||
startIdx := 0
|
||||
if p.cursor >= vis {
|
||||
startIdx = p.cursor - vis + 1
|
||||
}
|
||||
endIdx := min(startIdx+vis, len(p.filtered))
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append(items, scrollStyle.Render(" ↑ more above"))
|
||||
}
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := p.filtered[i]
|
||||
isCursor := i == p.cursor
|
||||
|
||||
itemStyle := normalItemBg
|
||||
if isCursor {
|
||||
itemStyle = selectedItemBg
|
||||
}
|
||||
|
||||
// Build indicator.
|
||||
var indicator string
|
||||
if isCursor {
|
||||
indicator = "> "
|
||||
} else {
|
||||
indicator = " "
|
||||
}
|
||||
|
||||
// Build content: indicator + label + description + active checkmark.
|
||||
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
if endIdx < len(p.filtered) {
|
||||
items = append(items, scrollStyle.Render(" ↓ more below"))
|
||||
}
|
||||
}
|
||||
|
||||
content := b.String() + strings.Join(items, "\n")
|
||||
|
||||
// Footer with count and keyboard hints.
|
||||
var footerParts []string
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
|
||||
footerHint := p.FooterHint
|
||||
if footerHint == "" {
|
||||
if innerWidth >= 50 {
|
||||
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
|
||||
} else if innerWidth >= 30 {
|
||||
footerHint = "↑↓ nav • ↵ select • esc"
|
||||
} else {
|
||||
footerHint = "↑↓ ↵ esc"
|
||||
}
|
||||
}
|
||||
footerParts = append(footerParts, footerHint)
|
||||
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true).
|
||||
Render(strings.Join(footerParts, " "))
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
}
|
||||
|
||||
// RenderCentered returns the popup placed at the center of a termWidth×termHeight
|
||||
// canvas, ready to be composed with overlayContent().
|
||||
func (p *PopupList) RenderCentered(termWidth, termHeight int) string {
|
||||
popupContent := p.Render()
|
||||
return lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
}
|
||||
|
||||
// IsSearching returns true when the search input is non-empty.
|
||||
func (p *PopupList) IsSearching() bool {
|
||||
return p.search != ""
|
||||
}
|
||||
|
||||
// SelectedItem returns the item under the cursor, or nil if the list is empty.
|
||||
func (p *PopupList) SelectedItem() *PopupItem {
|
||||
if p.cursor < len(p.filtered) {
|
||||
item := p.filtered[p.cursor]
|
||||
return &item
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (p *PopupList) rebuildFiltered() {
|
||||
if p.FilterFunc != nil {
|
||||
p.filtered = p.FilterFunc(p.search, p.allItems)
|
||||
} else {
|
||||
p.filtered = defaultFilter(p.search, p.allItems)
|
||||
}
|
||||
// Clamp cursor.
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// defaultFilter is a simple case-insensitive substring + fuzzy character match.
|
||||
func defaultFilter(query string, items []PopupItem) []PopupItem {
|
||||
if query == "" {
|
||||
return items
|
||||
}
|
||||
q := strings.ToLower(query)
|
||||
type scored struct {
|
||||
item PopupItem
|
||||
score int
|
||||
}
|
||||
var matches []scored
|
||||
for _, item := range items {
|
||||
label := strings.ToLower(item.Label)
|
||||
desc := strings.ToLower(item.Description)
|
||||
|
||||
var s int
|
||||
switch {
|
||||
case label == q:
|
||||
s = 1000
|
||||
case strings.HasPrefix(label, q):
|
||||
s = 800 - len(label) + len(q)
|
||||
case strings.Contains(label, q):
|
||||
s = 600
|
||||
case strings.Contains(desc, q):
|
||||
s = 400
|
||||
default:
|
||||
s = fuzzyCharacterMatch(q, label)
|
||||
}
|
||||
if s > 0 {
|
||||
matches = append(matches, scored{item: item, score: s})
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score descending, then alphabetically by label.
|
||||
for i := 0; i < len(matches)-1; i++ {
|
||||
for j := i + 1; j < len(matches); j++ {
|
||||
if matches[j].score > matches[i].score ||
|
||||
(matches[j].score == matches[i].score && matches[j].item.Label < matches[i].item.Label) {
|
||||
matches[i], matches[j] = matches[j], matches[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]PopupItem, len(matches))
|
||||
for i, m := range matches {
|
||||
result[i] = m.item
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// renderItemContent builds the display string for a single item row.
|
||||
func (p *PopupList) renderItemContent(indicator string, entry PopupItem, innerWidth int, isCursor bool) string {
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Reserve space: indicator(2) + potential checkmark(2)
|
||||
activeWidth := 0
|
||||
if entry.Active {
|
||||
activeWidth = 2
|
||||
}
|
||||
available := max(innerWidth-2-activeWidth, 6) // 2 for indicator, already included
|
||||
|
||||
label := entry.Label
|
||||
desc := entry.Description
|
||||
|
||||
if desc != "" {
|
||||
// Two-column layout: label + description.
|
||||
descWidth := len([]rune(desc)) + 1 // 1 space gap
|
||||
labelMax := max(available-descWidth, available*2/3)
|
||||
if len([]rune(label)) > labelMax && labelMax > 3 {
|
||||
runes := []rune(label)
|
||||
label = string(runes[:labelMax-1]) + "…"
|
||||
}
|
||||
labelDisplayLen := len([]rune(label))
|
||||
|
||||
// If label + desc don't fit, truncate or drop desc.
|
||||
if labelDisplayLen+1+len([]rune(desc)) > available {
|
||||
remaining := available - labelDisplayLen - 1
|
||||
if remaining >= 4 {
|
||||
runes := []rune(desc)
|
||||
if len(runes) > remaining {
|
||||
desc = string(runes[:remaining-1]) + "…"
|
||||
}
|
||||
} else {
|
||||
desc = ""
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single column: just the label.
|
||||
if len([]rune(label)) > available && available > 3 {
|
||||
runes := []rune(label)
|
||||
label = string(runes[:available-1]) + "…"
|
||||
}
|
||||
}
|
||||
|
||||
result := indicator + label
|
||||
if desc != "" {
|
||||
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
if isCursor {
|
||||
// When selected, use a dimmer foreground that still contrasts with Primary bg.
|
||||
descStyle = lipgloss.NewStyle().Foreground(theme.Background)
|
||||
}
|
||||
result += " " + descStyle.Render(desc)
|
||||
}
|
||||
if entry.Active {
|
||||
checkStyle := lipgloss.NewStyle().Foreground(theme.Success)
|
||||
if isCursor {
|
||||
checkStyle = lipgloss.NewStyle().Foreground(theme.Background)
|
||||
}
|
||||
result += checkStyle.Render(" ✓")
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPopupList_NewPositionsCursorOnActiveItem(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma", Active: true},
|
||||
{Label: "delta"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("expected cursor on active item (index 2), got %d", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_Navigation(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Initial cursor at 0.
|
||||
if p.cursor != 0 {
|
||||
t.Fatalf("expected cursor 0, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Down → 1.
|
||||
res := p.HandleKey("down", "")
|
||||
if !res.Changed || p.cursor != 1 {
|
||||
t.Errorf("down: changed=%v cursor=%d", res.Changed, p.cursor)
|
||||
}
|
||||
|
||||
// Down → 2.
|
||||
p.HandleKey("down", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Down at end → stays at 2.
|
||||
res = p.HandleKey("down", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("down at end: expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Up → 1.
|
||||
res = p.HandleKey("up", "")
|
||||
if !res.Changed || p.cursor != 1 {
|
||||
t.Errorf("up: changed=%v cursor=%d", res.Changed, p.cursor)
|
||||
}
|
||||
|
||||
// Home → 0.
|
||||
p.HandleKey("home", "")
|
||||
if p.cursor != 0 {
|
||||
t.Errorf("home: expected cursor 0, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// End → 2.
|
||||
p.HandleKey("end", "")
|
||||
if p.cursor != 2 {
|
||||
t.Errorf("end: expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_Search(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "apple"},
|
||||
{Label: "banana"},
|
||||
{Label: "cherry"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Type "an" → should filter to banana.
|
||||
p.HandleKey("a", "a")
|
||||
p.HandleKey("n", "n")
|
||||
|
||||
if !p.IsSearching() {
|
||||
t.Error("expected IsSearching() to be true")
|
||||
}
|
||||
if len(p.filtered) == 0 {
|
||||
t.Fatal("expected at least one filtered result")
|
||||
}
|
||||
// banana should match (contains "an").
|
||||
found := false
|
||||
for _, item := range p.filtered {
|
||||
if item.Label == "banana" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("expected 'banana' in filtered results")
|
||||
}
|
||||
|
||||
// Backspace removes last char.
|
||||
p.HandleKey("backspace", "")
|
||||
if p.search != "a" {
|
||||
t.Errorf("expected search 'a' after backspace, got %q", p.search)
|
||||
}
|
||||
|
||||
// Esc clears search.
|
||||
res := p.HandleKey("esc", "")
|
||||
if res.Cancelled {
|
||||
t.Error("esc with search should clear search, not cancel")
|
||||
}
|
||||
if p.search != "" {
|
||||
t.Errorf("expected empty search after esc, got %q", p.search)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_HandleKey_SelectAndCancel(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha", Meta: "first"},
|
||||
{Label: "beta", Meta: "second"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Select first item.
|
||||
res := p.HandleKey("enter", "")
|
||||
if res.Selected == nil {
|
||||
t.Fatal("expected a selection on enter")
|
||||
}
|
||||
if res.Selected.Label != "alpha" {
|
||||
t.Errorf("expected 'alpha', got %q", res.Selected.Label)
|
||||
}
|
||||
if res.Selected.Meta != "first" {
|
||||
t.Errorf("expected meta 'first', got %v", res.Selected.Meta)
|
||||
}
|
||||
|
||||
// Cancel with esc (no search text).
|
||||
p2 := NewPopupList("Test", items, 80, 40)
|
||||
res = p2.HandleKey("esc", "")
|
||||
if !res.Cancelled {
|
||||
t.Error("expected Cancelled on esc with no search")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_DefaultFilter(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "foo-bar"},
|
||||
{Label: "baz-qux"},
|
||||
{Label: "foobar"},
|
||||
}
|
||||
|
||||
// Exact prefix.
|
||||
result := defaultFilter("foo", items)
|
||||
if len(result) < 2 {
|
||||
t.Fatalf("expected at least 2 matches for 'foo', got %d", len(result))
|
||||
}
|
||||
// "foobar" should rank higher (shorter match) or equal to "foo-bar".
|
||||
if result[0].Label != "foobar" && result[1].Label != "foobar" {
|
||||
t.Error("expected 'foobar' in top results")
|
||||
}
|
||||
|
||||
// No match.
|
||||
result = defaultFilter("zzz", items)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected 0 matches for 'zzz', got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_CustomFilterFunc(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
p.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
|
||||
// Custom: only return items whose label starts with query.
|
||||
var result []PopupItem
|
||||
for _, item := range allItems {
|
||||
if strings.HasPrefix(item.Label, query) {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
p.HandleKey("b", "b")
|
||||
if len(p.filtered) != 1 || p.filtered[0].Label != "beta" {
|
||||
t.Errorf("expected ['beta'], got %v", p.filtered)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_Render(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha", Description: "[test]"},
|
||||
{Label: "beta", Description: "[test]", Active: true},
|
||||
}
|
||||
p := NewPopupList("My List", items, 80, 40)
|
||||
p.Subtitle = "Some subtitle"
|
||||
|
||||
rendered := p.Render()
|
||||
if rendered == "" {
|
||||
t.Fatal("expected non-empty rendered output")
|
||||
}
|
||||
|
||||
// Strip ANSI escape sequences for content checking.
|
||||
plain := stripAnsi(rendered)
|
||||
if !strings.Contains(plain, "My List") {
|
||||
t.Error("expected title 'My List' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "alpha") {
|
||||
t.Error("expected 'alpha' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "beta") {
|
||||
t.Error("expected 'beta' in rendered output")
|
||||
}
|
||||
if !strings.Contains(plain, "✓") {
|
||||
t.Error("expected checkmark for active item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_RenderCentered(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "item1"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
centered := p.RenderCentered(80, 40)
|
||||
if centered == "" {
|
||||
t.Fatal("expected non-empty centered output")
|
||||
}
|
||||
// Should contain newlines for vertical centering.
|
||||
lines := strings.Split(centered, "\n")
|
||||
if len(lines) < 10 {
|
||||
t.Errorf("expected centered output to have many lines, got %d", len(lines))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_EmptyItems(t *testing.T) {
|
||||
p := NewPopupList("Empty", nil, 80, 40)
|
||||
|
||||
rendered := p.Render()
|
||||
if !strings.Contains(rendered, "No items") {
|
||||
t.Error("expected 'No items' for empty list")
|
||||
}
|
||||
|
||||
// Navigate on empty list shouldn't panic.
|
||||
p.HandleKey("down", "")
|
||||
p.HandleKey("up", "")
|
||||
res := p.HandleKey("enter", "")
|
||||
if res.Selected != nil {
|
||||
t.Error("enter on empty list should not select")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_SearchNoResults(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Type something that doesn't match.
|
||||
p.HandleKey("z", "z")
|
||||
p.HandleKey("z", "z")
|
||||
p.HandleKey("z", "z")
|
||||
|
||||
rendered := p.Render()
|
||||
if !strings.Contains(rendered, "No matches") {
|
||||
t.Error("expected 'No matches' message for empty search results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPopupList_CursorClamping(t *testing.T) {
|
||||
items := []PopupItem{
|
||||
{Label: "alpha"},
|
||||
{Label: "beta"},
|
||||
{Label: "gamma"},
|
||||
}
|
||||
p := NewPopupList("Test", items, 80, 40)
|
||||
|
||||
// Move to last item.
|
||||
p.HandleKey("end", "")
|
||||
if p.cursor != 2 {
|
||||
t.Fatalf("expected cursor 2, got %d", p.cursor)
|
||||
}
|
||||
|
||||
// Search that reduces list to 1 item → cursor should clamp.
|
||||
p.HandleKey("a", "a")
|
||||
p.HandleKey("l", "l")
|
||||
// Only "alpha" should match.
|
||||
if p.cursor >= len(p.filtered) {
|
||||
t.Errorf("cursor %d should be < filtered count %d", p.cursor, len(p.filtered))
|
||||
}
|
||||
}
|
||||
|
||||
// stripAnsi is defined in usage_tracker_render_test.go
|
||||
@@ -47,7 +47,6 @@ func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme
|
||||
contentRendered := mutedStyle.Render(ty.Italic(contentStr))
|
||||
|
||||
// Build label based on duration
|
||||
var labelText string
|
||||
if duration > 0 {
|
||||
var durationStr string
|
||||
if duration < 1000 {
|
||||
@@ -55,12 +54,14 @@ func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", float64(duration)/1000)
|
||||
}
|
||||
labelText = "Thought for " + durationStr
|
||||
} else {
|
||||
labelText = "Thought"
|
||||
labelPart := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationPart := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
label := labelPart + durationPart
|
||||
rendered := contentRendered + "\n" + label
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(labelText)
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought")
|
||||
rendered := contentRendered + "\n" + label
|
||||
|
||||
return styleMarginBottom(theme, rendered)
|
||||
|
||||
+129
-54
@@ -252,57 +252,107 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// View implements tea.Model.
|
||||
func (ss *SessionSelectorComponent) View() tea.View {
|
||||
theme := style.GetTheme()
|
||||
w := ss.width
|
||||
var b strings.Builder
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ss.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ss.width
|
||||
}
|
||||
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ss.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// ── Header: title + scope badges ─────────────────────────────
|
||||
titleStyle := lipgloss.NewStyle().Bold(true).Foreground(theme.Accent).PaddingLeft(1)
|
||||
b.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
b.WriteString("\n")
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Help / keybindings ───────────────────────────────────────
|
||||
helpStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(1)
|
||||
if w >= 75 {
|
||||
b.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if w >= 50 {
|
||||
b.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if innerWidth >= 75 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if innerWidth >= 50 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("tab N D esc"))
|
||||
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Search (only shown when active) ──────────────────────────
|
||||
if ss.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(1)
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
b.WriteString("\n")
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
// Separator line
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Delete confirmation ──────────────────────────────────────
|
||||
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
|
||||
warnStyle := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).PaddingLeft(1)
|
||||
warnStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Background(theme.Background)
|
||||
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
|
||||
b.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// ── Session list ─────────────────────────────────────────────
|
||||
if len(ss.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if ss.search != "" {
|
||||
b.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
} else if ss.filter == SessionFilterNamed {
|
||||
b.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
} else if ss.scope == SessionScopeCwd {
|
||||
b.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
} else {
|
||||
b.WriteString(emptyStyle.Render("No sessions found"))
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
visH := ss.visibleHeight()
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ss.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
// Center the cursor in the visible window.
|
||||
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
|
||||
@@ -313,20 +363,40 @@ func (ss *SessionSelectorComponent) View() tea.View {
|
||||
isCursor := i == ss.cursor
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := i == ss.confirmDelete
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, w)
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Scroll position indicator.
|
||||
if len(ss.filtered) > visH {
|
||||
posStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
b.WriteString("\n")
|
||||
posStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
v := tea.NewView(b.String())
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Footer with filter info
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
|
||||
|
||||
// Apply the bordered container
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
@@ -411,7 +481,7 @@ func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCurs
|
||||
// ── Cursor indicator (2 chars) ───────────────────────────────
|
||||
cursorStr := " "
|
||||
if isCursor {
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("› ")
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
}
|
||||
const cursorW = 2
|
||||
|
||||
@@ -439,45 +509,50 @@ func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCurs
|
||||
msgW := utf8.RuneCountInString(displayText)
|
||||
|
||||
// ── Style the message ────────────────────────────────────────
|
||||
msgStyle := lipgloss.NewStyle()
|
||||
var msgStyle lipgloss.Style
|
||||
switch {
|
||||
case isDeleting:
|
||||
msgStyle = msgStyle.Foreground(theme.Error)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
case isCurrent:
|
||||
msgStyle = msgStyle.Foreground(theme.Accent)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
case info.Name != "":
|
||||
msgStyle = msgStyle.Foreground(theme.Warning)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
default:
|
||||
msgStyle = msgStyle.Foreground(theme.Text)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
}
|
||||
if isCursor {
|
||||
msgStyle = msgStyle.Bold(true)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
|
||||
// ── Style the right part ─────────────────────────────────────
|
||||
rightColor := theme.Muted
|
||||
if isDeleting {
|
||||
rightColor = theme.Error
|
||||
}
|
||||
styledRight := lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
var styledRight string
|
||||
|
||||
// ── Assemble with spacing ────────────────────────────────────
|
||||
spacing := max(width-cursorW-msgW-rightW, 1)
|
||||
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
// ── Background highlight for selected row ────────────────────
|
||||
// If selected, use inverted colors like PopupList
|
||||
if isCursor {
|
||||
// Use a subtle background highlight. We apply it by wrapping the
|
||||
// full line in a style with a background color.
|
||||
bgStyle := lipgloss.NewStyle().
|
||||
Background(theme.Highlight).
|
||||
Width(width)
|
||||
line = bgStyle.Render(line)
|
||||
// Inverted colors for selected item
|
||||
msgStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
styledRight = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(rightColor).
|
||||
Render(rightPart)
|
||||
cursorStr = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Accent).
|
||||
Render("> ")
|
||||
} else {
|
||||
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
|
||||
+15
-56
@@ -2,7 +2,6 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -14,17 +13,6 @@ import (
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from streaming text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// thinkTagOpen and thinkTagClose are the opening and closing think tag strings.
|
||||
const (
|
||||
thinkTagOpen = "<think>"
|
||||
thinkTagClose = "</think>"
|
||||
)
|
||||
|
||||
// knightRiderFrames generates a KITT-style scanning animation where a bright
|
||||
// light bounces back and forth across a row of dots with a trailing glow.
|
||||
// Colors are derived from the active theme. Used by StreamComponent (TUI
|
||||
@@ -207,10 +195,6 @@ type StreamComponent struct {
|
||||
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
|
||||
reasoningDuration time.Duration
|
||||
|
||||
// inThinkTag tracks whether we're currently inside a section
|
||||
// from models that wrap reasoning in XML-like tags (Qwen, DeepSeek).
|
||||
inThinkTag bool
|
||||
|
||||
// renderer renders streaming assistant text.
|
||||
renderer Renderer
|
||||
|
||||
@@ -319,9 +303,7 @@ func (s *StreamComponent) GetRenderedContent() string {
|
||||
// Called before reading content for output or on flush tick.
|
||||
func (s *StreamComponent) commitPending() {
|
||||
if s.pendingStream.Len() > 0 {
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(s.pendingStream.String(), "")
|
||||
s.streamContent.WriteString(cleanedText)
|
||||
s.streamContent.WriteString(s.pendingStream.String())
|
||||
s.pendingStream.Reset()
|
||||
}
|
||||
if s.pendingReasoning.Len() > 0 {
|
||||
@@ -401,6 +383,17 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, streamFlushTickCmd(s.flushGeneration)
|
||||
}
|
||||
|
||||
case app.ReasoningCompleteEvent:
|
||||
// Freeze reasoning duration when reasoning finishes (before text streaming starts).
|
||||
if s.reasoningDuration == 0 && !s.reasoningStartTime.IsZero() {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
// Flush any remaining pending reasoning content.
|
||||
if s.pendingReasoning.Len() > 0 {
|
||||
s.reasoningContent.WriteString(s.pendingReasoning.String())
|
||||
s.pendingReasoning.Reset()
|
||||
}
|
||||
|
||||
case app.StreamChunkEvent:
|
||||
s.phase = streamPhaseActive
|
||||
if s.timestamp.IsZero() {
|
||||
@@ -411,43 +404,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.reasoningDuration = time.Since(s.reasoningStartTime)
|
||||
}
|
||||
|
||||
// Handle models that wrap reasoning in tags (Qwen, DeepSeek)
|
||||
// Filter out all content between and tags
|
||||
content := msg.Content
|
||||
|
||||
// Check for opening tag
|
||||
if strings.Contains(content, thinkTagOpen) {
|
||||
parts := strings.SplitN(content, thinkTagOpen, 2)
|
||||
// Content before the tag can be written
|
||||
if !s.inThinkTag && parts[0] != "" {
|
||||
s.pendingStream.WriteString(parts[0])
|
||||
}
|
||||
s.inThinkTag = true
|
||||
// Content after the opening tag is reasoning - don't write it
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
// Check if the same chunk contains the closing tag
|
||||
if strings.Contains(parts[1], thinkTagClose) {
|
||||
innerParts := strings.SplitN(parts[1], thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(innerParts) > 1 && innerParts[1] != "" {
|
||||
s.pendingStream.WriteString(innerParts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(content, thinkTagClose) {
|
||||
// Closing tag found
|
||||
parts := strings.SplitN(content, thinkTagClose, 2)
|
||||
s.inThinkTag = false
|
||||
// Content after closing tag can be written
|
||||
if len(parts) > 1 && parts[1] != "" {
|
||||
s.pendingStream.WriteString(parts[1])
|
||||
}
|
||||
} else if !s.inThinkTag {
|
||||
// Normal content, not inside think tags
|
||||
s.pendingStream.WriteString(content)
|
||||
}
|
||||
// else: inside think tag, don't write this content
|
||||
// <think> tag filtering is handled at the agent layer — chunks here
|
||||
// are already clean text.
|
||||
s.pendingStream.WriteString(msg.Content)
|
||||
|
||||
if !s.flushPending && s.pendingStream.Len() > 0 {
|
||||
s.flushPending = true
|
||||
|
||||
@@ -301,13 +301,13 @@ func KitBanner() string {
|
||||
kittDark := lipgloss.Color("#8B0000")
|
||||
kittBright := lipgloss.Color("#FF2200")
|
||||
lines := []string{
|
||||
" ██╗ ██╗ ██╗ ████████╗",
|
||||
" ██║ ██╔╝ ██║ ╚══██╔══╝",
|
||||
" █████╔╝ ██║ ██║",
|
||||
" ██╔═██╗ ██║ ██║",
|
||||
" ██║ ██╗ ██║ ██║",
|
||||
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
|
||||
" ░░░░░░▒▒▒▒▒▓▓▓▓███████████████▓▓▓▓▒▒▒▒▒░░░░░░",
|
||||
" ██╗ ██╗ ██╗ ████████╗",
|
||||
" ██║ ██╔╝ ██║ ╚══██╔══╝",
|
||||
" █████╔╝ ██║ ██║",
|
||||
" ██╔═██╗ ██║ ██║",
|
||||
" ██║ ██╗ ██║ ██║",
|
||||
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
|
||||
"░░ ░░ ░░ ▒▒ ▒▒ ▓▓ ▓▓ ████ ▓▓ ▓▓ ▒▒ ▒▒ ░░ ░░ ░░",
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
|
||||
+277
-32
@@ -28,10 +28,10 @@ const (
|
||||
maxLsLines = 20 // lines for Ls directory listings
|
||||
)
|
||||
|
||||
// isShellTool reports if the tool name matches a shell-like tool (bash, grep, find, or
|
||||
// isShellTool reports if the tool name matches a shell-like tool (bash or
|
||||
// tools with "shell"/"command" in the name). Used by renderToolBody.
|
||||
func isShellTool(toolName string) bool {
|
||||
return toolName == "bash" || toolName == "grep" || toolName == "find" ||
|
||||
return toolName == "bash" ||
|
||||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command")
|
||||
}
|
||||
|
||||
@@ -55,8 +55,16 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
if body := renderWriteBody(toolArgs, toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "find":
|
||||
if body := renderFindBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "grep":
|
||||
if body := renderGrepBody(toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case isShellTool(toolName):
|
||||
if body := renderBashBody(toolResult, width); body != "" {
|
||||
if body := renderBashBody(toolArgs, toolResult, width); body != "" {
|
||||
return body
|
||||
}
|
||||
case toolName == "subagent":
|
||||
@@ -337,6 +345,148 @@ func renderDiffBlock(before, after string, startLine int, width int) string {
|
||||
// Ls tool — simple list without gutter
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderFindBody renders find output as a plain list with code background.
|
||||
// Similar to ls but with results-specific caption.
|
||||
func renderFindBody(toolResult string, width int) string {
|
||||
content := strings.TrimSpace(toolResult)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
totalResults := len(lines)
|
||||
|
||||
// Truncate to maxLsLines for display
|
||||
var hiddenCount int
|
||||
if len(lines) > maxLsLines {
|
||||
hiddenCount = len(lines) - maxLsLines
|
||||
lines = lines[:maxLsLines]
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
codeWidth := max(width-len(lineIndent), 20)
|
||||
|
||||
theme := GetTheme()
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var rendered []string
|
||||
for _, line := range lines {
|
||||
// Truncate before styling to prevent wrapping.
|
||||
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
|
||||
styled := codeStyle.Width(codeWidth).Render(line)
|
||||
rendered = append(rendered, styled)
|
||||
}
|
||||
|
||||
content = strings.Join(rendered, "\n")
|
||||
|
||||
// Build caption with results info
|
||||
var captionParts []string
|
||||
if totalResults == 1 {
|
||||
captionParts = append(captionParts, "1 result")
|
||||
} else {
|
||||
captionParts = append(captionParts, fmt.Sprintf("%d results", totalResults))
|
||||
}
|
||||
if hiddenCount > 0 {
|
||||
captionParts = append(captionParts, fmt.Sprintf("%d more", hiddenCount))
|
||||
}
|
||||
|
||||
if len(captionParts) > 1 || hiddenCount > 0 {
|
||||
ty := herald.New(herald.WithTheme(herald.Theme{
|
||||
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
FigureCaptionPosition: herald.CaptionBottom,
|
||||
}))
|
||||
caption := strings.Join(captionParts, " • ")
|
||||
result := ty.Figure(content, caption)
|
||||
|
||||
// Indent entire block (content + caption) to match other tools
|
||||
const blockIndent = " "
|
||||
resultLines := strings.Split(result, "\n")
|
||||
for i, line := range resultLines {
|
||||
resultLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(resultLines, "\n")
|
||||
}
|
||||
|
||||
// Single result with no truncation - just return indented content
|
||||
const blockIndent = " "
|
||||
contentLines := strings.Split(content, "\n")
|
||||
for i, line := range contentLines {
|
||||
contentLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// renderGrepBody renders grep output as a plain list with code background.
|
||||
// Similar to find but with match-specific caption terminology.
|
||||
func renderGrepBody(toolResult string, width int) string {
|
||||
content := strings.TrimSpace(toolResult)
|
||||
if content == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lines := strings.Split(content, "\n")
|
||||
totalMatches := len(lines)
|
||||
|
||||
// Truncate to maxLsLines for display
|
||||
var hiddenCount int
|
||||
if len(lines) > maxLsLines {
|
||||
hiddenCount = len(lines) - maxLsLines
|
||||
lines = lines[:maxLsLines]
|
||||
}
|
||||
|
||||
const lineIndent = " "
|
||||
codeWidth := max(width-len(lineIndent), 20)
|
||||
|
||||
theme := GetTheme()
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var rendered []string
|
||||
for _, line := range lines {
|
||||
// Truncate before styling to prevent wrapping.
|
||||
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
|
||||
styled := codeStyle.Width(codeWidth).Render(line)
|
||||
rendered = append(rendered, styled)
|
||||
}
|
||||
|
||||
content = strings.Join(rendered, "\n")
|
||||
|
||||
// Build caption with match info
|
||||
var captionParts []string
|
||||
if totalMatches == 1 {
|
||||
captionParts = append(captionParts, "1 match")
|
||||
} else {
|
||||
captionParts = append(captionParts, fmt.Sprintf("%d matches", totalMatches))
|
||||
}
|
||||
if hiddenCount > 0 {
|
||||
captionParts = append(captionParts, fmt.Sprintf("%d more", hiddenCount))
|
||||
}
|
||||
|
||||
if len(captionParts) > 1 || hiddenCount > 0 {
|
||||
ty := herald.New(herald.WithTheme(herald.Theme{
|
||||
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
FigureCaptionPosition: herald.CaptionBottom,
|
||||
}))
|
||||
caption := strings.Join(captionParts, " • ")
|
||||
result := ty.Figure(content, caption)
|
||||
|
||||
// Indent entire block (content + caption) to match other tools
|
||||
const blockIndent = " "
|
||||
resultLines := strings.Split(result, "\n")
|
||||
for i, line := range resultLines {
|
||||
resultLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(resultLines, "\n")
|
||||
}
|
||||
|
||||
// Single match with no truncation - just return indented content
|
||||
const blockIndent = " "
|
||||
contentLines := strings.Split(content, "\n")
|
||||
for i, line := range contentLines {
|
||||
contentLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// renderLsBody renders ls output as a plain list with code background and no
|
||||
// line-number gutter.
|
||||
func renderLsBody(toolResult string, width int) string {
|
||||
@@ -354,28 +504,47 @@ func renderLsBody(toolResult string, width int) string {
|
||||
lines = lines[:maxLsLines]
|
||||
}
|
||||
|
||||
const indent = " "
|
||||
codeWidth := max(width-len(indent), 20)
|
||||
const lineIndent = " "
|
||||
codeWidth := max(width-len(lineIndent), 20)
|
||||
|
||||
theme := GetTheme()
|
||||
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
|
||||
|
||||
var result []string
|
||||
var rendered []string
|
||||
for _, line := range lines {
|
||||
// Truncate before styling to prevent wrapping.
|
||||
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
|
||||
styled := codeStyle.Width(codeWidth).Render(line)
|
||||
result = append(result, indent+styled)
|
||||
rendered = append(rendered, styled)
|
||||
}
|
||||
|
||||
content = strings.Join(rendered, "\n")
|
||||
|
||||
// Build caption with hidden entries info
|
||||
if hiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more entries)", hiddenCount)
|
||||
hintContent := codeStyle.Width(codeWidth).
|
||||
Foreground(theme.Muted).Italic(true).Render(hint)
|
||||
result = append(result, indent+hintContent)
|
||||
ty := herald.New(herald.WithTheme(herald.Theme{
|
||||
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
FigureCaptionPosition: herald.CaptionBottom,
|
||||
}))
|
||||
caption := fmt.Sprintf("%d more entries", hiddenCount)
|
||||
result := ty.Figure(content, caption)
|
||||
|
||||
// Indent entire block (content + caption) to match other tools
|
||||
const blockIndent = " "
|
||||
resultLines := strings.Split(result, "\n")
|
||||
for i, line := range resultLines {
|
||||
resultLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(resultLines, "\n")
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
// No caption - just return indented content
|
||||
const blockIndent = " "
|
||||
contentLines := strings.Split(content, "\n")
|
||||
for i, line := range contentLines {
|
||||
contentLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -461,19 +630,50 @@ func renderReadBody(toolArgs, toolResult string, width int) string {
|
||||
)
|
||||
|
||||
// Render the code block
|
||||
result := ty.CodeBlock(codeContent, lang)
|
||||
codeBlock := ty.CodeBlock(codeContent, lang)
|
||||
|
||||
// Add truncation hint if needed
|
||||
// Herald's codeBlockWithLineNumbers() hardcodes PaddingTop(1) and
|
||||
// PaddingBottom(1), adding invisible blank lines with background color
|
||||
// above and below the code. These interfere with mouse selection
|
||||
// (off-by-one) because the padding line looks blank but occupies a
|
||||
// line index in the rendered item. Strip them since the Compose
|
||||
// separator above and Figure caption below already provide spacing.
|
||||
codeBlock = stripCodeBlockPadding(codeBlock)
|
||||
|
||||
// Parse total lines from footer if available (e.g., "[showing lines 1-100 of 407 total...]")
|
||||
totalLines := totalCodeLines
|
||||
for _, footer := range footerLines {
|
||||
if matches := regexp.MustCompile(`of (\d+) total`).FindStringSubmatch(footer); len(matches) > 1 {
|
||||
if t, _ := strconv.Atoi(matches[1]); t > totalLines {
|
||||
totalLines = t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build caption with file metadata
|
||||
var captionParts []string
|
||||
if fileName != "" {
|
||||
captionParts = append(captionParts, filepath.Base(fileName))
|
||||
}
|
||||
if len(codeLines) > 0 {
|
||||
endLine := offset + len(codeLines) - 1
|
||||
captionParts = append(captionParts, fmt.Sprintf("lines %d-%d of %d", offset, endLine, totalLines))
|
||||
}
|
||||
if codeHiddenCount > 0 {
|
||||
hint := fmt.Sprintf("...(%d more lines)", codeHiddenCount)
|
||||
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Italic(true).Render(hint)
|
||||
nextOffset := offset + len(codeLines)
|
||||
captionParts = append(captionParts, fmt.Sprintf("offset=%d to continue", nextOffset))
|
||||
}
|
||||
|
||||
// Add any footer lines
|
||||
if len(footerLines) > 0 {
|
||||
footer := strings.Join(footerLines, "\n")
|
||||
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Render(footer)
|
||||
caption := strings.Join(captionParts, " • ")
|
||||
|
||||
// Use Figure with caption below content (default behavior)
|
||||
// Apply theme to ensure caption is positioned below
|
||||
figTheme := herald.Theme{
|
||||
FigureCaption: lipgloss.NewStyle().Foreground(GetTheme().Muted),
|
||||
FigureCaptionPosition: herald.CaptionBottom,
|
||||
}
|
||||
tyFig := herald.New(herald.WithTheme(figTheme))
|
||||
result := tyFig.Figure(codeBlock, caption)
|
||||
|
||||
// Indent entire block to match Write/Edit tools (2 spaces)
|
||||
const blockIndent = " "
|
||||
@@ -582,7 +782,7 @@ func renderWriteBlock(content, fileName string, width int) string {
|
||||
|
||||
// renderBashBody renders bash output with per-line background and stderr
|
||||
// in error color.
|
||||
func renderBashBody(toolResult string, width int) string {
|
||||
func renderBashBody(toolArgs, toolResult string, width int) string {
|
||||
if strings.TrimSpace(toolResult) == "" {
|
||||
return ""
|
||||
}
|
||||
@@ -609,6 +809,7 @@ func renderBashBody(toolResult string, width int) string {
|
||||
maxLineChars := lineWidth - 1
|
||||
|
||||
var rendered []string
|
||||
exitCode := -1 // -1 means not found
|
||||
inStderr := false
|
||||
for _, line := range lines {
|
||||
line = truncateLine(line, maxLineChars)
|
||||
@@ -617,30 +818,55 @@ func renderBashBody(toolResult string, width int) string {
|
||||
inStderr = true
|
||||
continue
|
||||
}
|
||||
// Exit code line
|
||||
// Exit code line - extract it for caption
|
||||
if strings.HasPrefix(line, "Exit code:") {
|
||||
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
|
||||
rendered = append(rendered, lineIndent+styled)
|
||||
continue
|
||||
_, _ = fmt.Sscanf(line, "Exit code: %d", &exitCode)
|
||||
continue // Don't render exit code inline, it goes in caption
|
||||
}
|
||||
|
||||
if inStderr {
|
||||
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
|
||||
rendered = append(rendered, lineIndent+styled)
|
||||
rendered = append(rendered, styled)
|
||||
} else {
|
||||
styled := outputStyle.Width(width - len(lineIndent)).Render(line)
|
||||
rendered = append(rendered, lineIndent+styled)
|
||||
rendered = append(rendered, styled)
|
||||
}
|
||||
}
|
||||
|
||||
// Build caption with status info
|
||||
var captionParts []string
|
||||
if hiddenCount > 0 {
|
||||
truncMsg := fmt.Sprintf("...(%d more lines)", hiddenCount)
|
||||
hint := outputStyle.Width(width - len(lineIndent)).
|
||||
Foreground(theme.Muted).Italic(true).Render(truncMsg)
|
||||
rendered = append(rendered, lineIndent+hint)
|
||||
captionParts = append(captionParts, fmt.Sprintf("%d more lines", hiddenCount))
|
||||
}
|
||||
if exitCode >= 0 {
|
||||
captionParts = append(captionParts, fmt.Sprintf("exit code %d", exitCode))
|
||||
}
|
||||
|
||||
return strings.Join(rendered, "\n")
|
||||
content := strings.Join(rendered, "\n")
|
||||
if len(captionParts) > 0 {
|
||||
ty := herald.New(herald.WithTheme(herald.Theme{
|
||||
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
FigureCaptionPosition: herald.CaptionBottom,
|
||||
}))
|
||||
caption := strings.Join(captionParts, " • ")
|
||||
result := ty.Figure(content, caption)
|
||||
|
||||
// Indent entire block (content + caption) to match other tools
|
||||
const blockIndent = " "
|
||||
lines := strings.Split(result, "\n")
|
||||
for i, line := range lines {
|
||||
lines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// No caption - just return indented content
|
||||
const blockIndent = " "
|
||||
contentLines := strings.Split(content, "\n")
|
||||
for i, line := range contentLines {
|
||||
contentLines[i] = blockIndent + line
|
||||
}
|
||||
return strings.Join(contentLines, "\n")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -724,6 +950,25 @@ func padRight(s string, width int) string {
|
||||
return s + strings.Repeat(" ", width-w)
|
||||
}
|
||||
|
||||
// stripCodeBlockPadding removes the top and bottom padding lines that herald's
|
||||
// codeBlockWithLineNumbers() hardcodes via PaddingTop(1)/PaddingBottom(1).
|
||||
// These padding lines are blank lines with background color that look invisible
|
||||
// but occupy line indices, causing mouse selection to be off by one row.
|
||||
func stripCodeBlockPadding(block string) string {
|
||||
lines := strings.Split(block, "\n")
|
||||
if len(lines) < 3 {
|
||||
return block
|
||||
}
|
||||
// The first and last lines are padding (blank with bg color).
|
||||
// Strip them only if they contain no visible text.
|
||||
first := xansi.Strip(lines[0])
|
||||
last := xansi.Strip(lines[len(lines)-1])
|
||||
if strings.TrimSpace(first) == "" && strings.TrimSpace(last) == "" {
|
||||
return strings.Join(lines[1:len(lines)-1], "\n")
|
||||
}
|
||||
return block
|
||||
}
|
||||
|
||||
// truncateLine truncates a line to maxWidth visual characters, adding "…"
|
||||
// if truncated. This is ANSI-aware: escape codes are preserved and wide
|
||||
// characters are measured correctly.
|
||||
|
||||
+150
-51
@@ -226,46 +226,92 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
func (ts *TreeSelectorComponent) View() tea.View {
|
||||
theme := GetTheme()
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ts.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ts.width
|
||||
}
|
||||
maxHeight := ts.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ts.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
// Header style with background highlight (like PopupList title)
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
PaddingLeft(2)
|
||||
Background(theme.Background)
|
||||
|
||||
// Help text style
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
PaddingLeft(2)
|
||||
Background(theme.Background)
|
||||
|
||||
var b strings.Builder
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// Header.
|
||||
b.WriteString(headerStyle.Render("Session Tree"))
|
||||
b.WriteString("\n")
|
||||
// Adapt help text to terminal width.
|
||||
// Header row with title and help
|
||||
headerRow := headerStyle.Render("Session Tree")
|
||||
contentBuilder.WriteString(headerRow)
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Help text - adapt to terminal width
|
||||
var helpText string
|
||||
if ts.width >= 70 {
|
||||
b.WriteString(helpStyle.Render("↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"))
|
||||
helpText = "↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"
|
||||
} else if ts.width >= 45 {
|
||||
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc cancel ^O filter"))
|
||||
helpText = "↑↓ move ↵ select esc cancel ^O filter"
|
||||
} else {
|
||||
b.WriteString(helpStyle.Render("↑↓ ↵ esc ^O"))
|
||||
helpText = "↑↓ ↵ esc ^O"
|
||||
}
|
||||
b.WriteString("\n")
|
||||
contentBuilder.WriteString(helpStyle.Render(helpText))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Search display (if active)
|
||||
if ts.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(2)
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("Search: %s", ts.search)))
|
||||
b.WriteString("\n")
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ts.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ts.width)))
|
||||
b.WriteString("\n")
|
||||
// Separator line - full width
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Tree content
|
||||
if len(ts.flatNodes) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
b.WriteString(emptyStyle.Render("No entries in session"))
|
||||
b.WriteString("\n")
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(emptyStyle.Render("No entries in session"))
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
// Compute visible window.
|
||||
visH := ts.visibleHeight()
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ts.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
startIdx := 0
|
||||
if ts.cursor >= visH {
|
||||
startIdx = ts.cursor - visH + 1
|
||||
@@ -274,21 +320,32 @@ func (ts *TreeSelectorComponent) View() tea.View {
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
node := ts.flatNodes[i]
|
||||
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID)
|
||||
b.WriteString(line)
|
||||
b.WriteString("\n")
|
||||
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer.
|
||||
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ts.width)))
|
||||
b.WriteString("\n")
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
|
||||
// Footer with count and filter
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
footer := fmt.Sprintf("(%d/%d) [%s]", ts.cursor+1, len(ts.flatNodes), ts.filter)
|
||||
b.WriteString(footerStyle.Render(footer))
|
||||
contentBuilder.WriteString(footerStyle.Render(footer))
|
||||
|
||||
v := tea.NewView(b.String())
|
||||
// Apply the bordered container - full width, no centering
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
@@ -420,21 +477,23 @@ func (ts *TreeSelectorComponent) passesFilter(node *session.TreeNode) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool) string {
|
||||
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool, innerWidth int) string {
|
||||
theme := GetTheme()
|
||||
maxWidth := max(ts.width-4, 10)
|
||||
|
||||
// Cursor indicator.
|
||||
// Cursor indicator - use ">" for selected (like PopupList)
|
||||
var cursor string
|
||||
if isCursor {
|
||||
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("› ")
|
||||
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
} else {
|
||||
cursor = " "
|
||||
}
|
||||
|
||||
// Role-colored content.
|
||||
// Role-colored content with background support for selection
|
||||
text := ts.entryDisplayText(node.Entry)
|
||||
available := maxWidth - len(node.Prefix) - 10
|
||||
|
||||
// Calculate available width accounting for cursor, prefix, and markers
|
||||
prefixLen := len(node.Prefix)
|
||||
available := innerWidth - prefixLen - 4 // 4 for cursor and some padding
|
||||
if available > 3 && len(text) > available {
|
||||
trimLen := max(available-3, 1)
|
||||
if trimLen < len(text) {
|
||||
@@ -442,48 +501,88 @@ func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool
|
||||
}
|
||||
}
|
||||
|
||||
var style lipgloss.Style
|
||||
// Build the full line style
|
||||
var lineStyle lipgloss.Style
|
||||
var textStyle lipgloss.Style
|
||||
|
||||
// Base text color based on role
|
||||
switch e := node.Entry.(type) {
|
||||
case *session.MessageEntry:
|
||||
switch e.Role {
|
||||
case "user":
|
||||
style = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
case "assistant":
|
||||
style = lipgloss.NewStyle().Foreground(theme.Success)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Success)
|
||||
default:
|
||||
style = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
}
|
||||
case *session.BranchSummaryEntry:
|
||||
style = lipgloss.NewStyle().Foreground(theme.Warning).Italic(true)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Warning).Italic(true)
|
||||
case *session.CompactionEntry:
|
||||
style = lipgloss.NewStyle().Foreground(theme.Info).Italic(true)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Info).Italic(true)
|
||||
default:
|
||||
style = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
}
|
||||
|
||||
// Apply selection highlighting (like PopupList)
|
||||
if isCursor {
|
||||
style = style.Bold(true)
|
||||
// Inverted colors for selected item - matches PopupList style
|
||||
lineStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
textStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
content := style.Render(text)
|
||||
// Render components
|
||||
content := textStyle.Render(text)
|
||||
|
||||
// Label badge.
|
||||
var labelBadge string
|
||||
if node.Label != "" {
|
||||
labelBadge = " " + lipgloss.NewStyle().Foreground(theme.Warning).Render("["+node.Label+"]")
|
||||
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
if isCursor {
|
||||
labelStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Warning)
|
||||
}
|
||||
labelBadge = " " + labelStyle.Render("["+node.Label+"]")
|
||||
}
|
||||
|
||||
// Active marker.
|
||||
// Active marker - use Success color for better visibility
|
||||
var activeMarker string
|
||||
if isLeaf {
|
||||
activeMarker = lipgloss.NewStyle().Foreground(theme.Accent).Bold(true).Render(" ← active")
|
||||
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
|
||||
if isCursor {
|
||||
markerStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
activeMarker = markerStyle.Render(" ← active")
|
||||
}
|
||||
|
||||
// Prefix (tree lines).
|
||||
prefixStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
// Prefix (tree lines) - use MutedBorder for subtler appearance
|
||||
prefixStyle := lipgloss.NewStyle().Foreground(theme.MutedBorder)
|
||||
if isCursor {
|
||||
prefixStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.MutedBorder)
|
||||
}
|
||||
renderedPrefix := prefixStyle.Render(node.Prefix)
|
||||
|
||||
return cursor + renderedPrefix + content + labelBadge + activeMarker
|
||||
// Combine all parts
|
||||
line := cursor + renderedPrefix + content + labelBadge + activeMarker
|
||||
|
||||
// If selected, apply the background to the entire line
|
||||
if isCursor {
|
||||
return lineStyle.Render(line)
|
||||
}
|
||||
|
||||
return line
|
||||
}
|
||||
|
||||
func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
|
||||
|
||||
@@ -202,6 +202,7 @@ func Init(api ext.API) {
|
||||
footer := harness.Context().GetFooter()
|
||||
if footer == nil {
|
||||
t.Fatal("expected footer to be set")
|
||||
return
|
||||
}
|
||||
if footer.Content.Text != "Status: OK" {
|
||||
t.Errorf("expected footer text 'Status: OK', got %q", footer.Content.Text)
|
||||
@@ -258,6 +259,7 @@ func Init(api ext.API) {
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Block {
|
||||
|
||||
@@ -39,6 +39,9 @@ const (
|
||||
EventCompaction EventType = "compaction"
|
||||
// EventReasoningDelta fires for each streaming reasoning/thinking chunk.
|
||||
EventReasoningDelta EventType = "reasoning_delta"
|
||||
// EventReasoningComplete fires when reasoning/thinking is finished,
|
||||
// after the last reasoning token has been processed.
|
||||
EventReasoningComplete EventType = "reasoning_complete"
|
||||
// EventToolOutput fires when a tool produces streaming output chunks.
|
||||
EventToolOutput EventType = "tool_output"
|
||||
EventStepUsage EventType = "step_usage"
|
||||
@@ -149,6 +152,13 @@ type ReasoningDeltaEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e ReasoningDeltaEvent) EventType() EventType { return EventReasoningDelta }
|
||||
|
||||
// ReasoningCompleteEvent fires when reasoning/thinking is finished, after the
|
||||
// last reasoning token has been processed.
|
||||
type ReasoningCompleteEvent struct{}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e ReasoningCompleteEvent) EventType() EventType { return EventReasoningComplete }
|
||||
|
||||
// ToolOutputEvent fires when a tool produces streaming output chunks (e.g., bash output).
|
||||
type ToolOutputEvent struct {
|
||||
ToolCallID string
|
||||
|
||||
@@ -177,6 +177,7 @@ func TestEventTypes(t *testing.T) {
|
||||
{ResponseEvent{}, EventResponse},
|
||||
{CompactionEvent{}, EventCompaction},
|
||||
{ReasoningDeltaEvent{}, EventReasoningDelta},
|
||||
{ReasoningCompleteEvent{}, EventReasoningComplete},
|
||||
{ToolOutputEvent{}, EventToolOutput},
|
||||
{StepUsageEvent{}, EventStepUsage},
|
||||
{SteerConsumedEvent{}, EventSteerConsumed},
|
||||
@@ -224,6 +225,7 @@ func TestEventOrdering(t *testing.T) {
|
||||
EventMessageStart,
|
||||
EventMessageUpdate,
|
||||
EventReasoningDelta,
|
||||
EventReasoningComplete,
|
||||
EventToolOutput,
|
||||
EventToolCall,
|
||||
EventToolExecutionStart,
|
||||
@@ -242,6 +244,7 @@ func TestEventOrdering(t *testing.T) {
|
||||
bus.emit(MessageStartEvent{})
|
||||
bus.emit(MessageUpdateEvent{Chunk: "hello"})
|
||||
bus.emit(ReasoningDeltaEvent{Delta: "thinking..."})
|
||||
bus.emit(ReasoningCompleteEvent{})
|
||||
bus.emit(ToolOutputEvent{ToolName: "bash", Chunk: "output"})
|
||||
bus.emit(ToolCallEvent{ToolName: "bash"})
|
||||
bus.emit(ToolExecutionStartEvent{ToolName: "bash"})
|
||||
|
||||
+243
-101
@@ -48,6 +48,7 @@ type Kit struct {
|
||||
skills []*skills.Skill
|
||||
extRunner *extensions.Runner
|
||||
bufferedLogger *tools.BufferedDebugLogger
|
||||
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
@@ -80,8 +81,8 @@ type Kit struct {
|
||||
// the running agent turn via the LLM library's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
steerMu sync.Mutex
|
||||
steerCh chan string
|
||||
leftoverSteer []string // unconsumed steer messages from the last turn
|
||||
steerCh chan agent.SteerMessage
|
||||
leftoverSteer []agent.SteerMessage // unconsumed steer messages from the last turn
|
||||
}
|
||||
|
||||
// Subscribe registers an EventListener that will be called for every lifecycle
|
||||
@@ -268,8 +269,8 @@ func (m *Kit) GetAvailableModels() []extensions.ModelInfoEntry {
|
||||
}
|
||||
|
||||
// ReloadExtensions hot-reloads all extensions from disk. Event handlers,
|
||||
// commands, renderers, and shortcuts update immediately. Extension-defined
|
||||
// tools are NOT updated (they are baked into the agent at creation time).
|
||||
// commands, renderers, shortcuts, and extension-defined tools all update
|
||||
// immediately.
|
||||
func (m *Kit) ReloadExtensions() error {
|
||||
if m.extRunner == nil {
|
||||
return fmt.Errorf("no extensions loaded")
|
||||
@@ -290,6 +291,12 @@ func (m *Kit) ReloadExtensions() error {
|
||||
// Swap extensions on the runner (clears dynamic state).
|
||||
m.extRunner.Reload(loaded)
|
||||
|
||||
// Update extension tools on the agent so the LLM sees changes.
|
||||
if m.agent != nil {
|
||||
extTools := extensions.ExtensionToolsAsFantasy(m.extRunner.RegisteredTools(), m.extRunner)
|
||||
m.agent.SetExtraTools(extTools)
|
||||
}
|
||||
|
||||
// Re-set context and emit SessionStart.
|
||||
ctx := m.extRunner.GetContext()
|
||||
m.extRunner.SetContext(ctx)
|
||||
@@ -433,6 +440,18 @@ type Options struct {
|
||||
// Debug enables debug logging for the SDK.
|
||||
Debug bool
|
||||
|
||||
// MCPAuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports (streamable HTTP, SSE) are configured with
|
||||
// OAuth support. If the server returns a 401, the handler is invoked to
|
||||
// let the user authorize via browser.
|
||||
//
|
||||
// If nil, a [DefaultMCPAuthHandler] is created automatically — opening the
|
||||
// system browser and listening on a local callback server.
|
||||
//
|
||||
// Set to a custom implementation to control the authorization UX (e.g.
|
||||
// display a URL in a custom UI, redirect to a web app, etc.).
|
||||
MCPAuthHandler MCPAuthHandler
|
||||
|
||||
// CLI is optional CLI-specific configuration. SDK users leave this nil.
|
||||
CLI *CLIOptions
|
||||
}
|
||||
@@ -499,85 +518,125 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
opts = &Options{}
|
||||
}
|
||||
|
||||
viperInitMu.Lock()
|
||||
defer viperInitMu.Unlock()
|
||||
// All viper writes (SetSDKDefaults, InitConfig, Set calls, system-prompt
|
||||
// composition) happen under viperInitMu. We also call BuildProviderConfig
|
||||
// here — it's fast (just reads) — so we can capture the full config
|
||||
// snapshot before releasing the lock. The expensive work (MCP loading,
|
||||
// provider creation, session init) then runs outside the lock, allowing
|
||||
// parallel subagent spawns to proceed concurrently.
|
||||
var (
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
cwd string
|
||||
contextFiles []*ContextFile
|
||||
loadedSkills []*Skill
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
)
|
||||
|
||||
// Set CLI-equivalent defaults for viper. When used as an SDK (without
|
||||
// cobra), these defaults are not registered via flag bindings.
|
||||
setSDKDefaults()
|
||||
if err := func() error {
|
||||
viperInitMu.Lock()
|
||||
defer viperInitMu.Unlock()
|
||||
|
||||
// Initialize config (loads config files and env vars).
|
||||
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
|
||||
// Check if model is already set, which indicates config was loaded.
|
||||
if viper.GetString("model") == "" {
|
||||
if err := InitConfig(opts.ConfigFile, false); err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize config: %w", err)
|
||||
}
|
||||
}
|
||||
// Set CLI-equivalent defaults for viper. When used as an SDK (without
|
||||
// cobra), these defaults are not registered via flag bindings.
|
||||
setSDKDefaults()
|
||||
|
||||
// Handle CLI debug mode.
|
||||
if opts.Debug {
|
||||
viper.Set("debug", true)
|
||||
}
|
||||
|
||||
// Override viper settings with options.
|
||||
if opts.Model != "" {
|
||||
viper.Set("model", opts.Model)
|
||||
}
|
||||
if opts.SystemPrompt != "" {
|
||||
viper.Set("system-prompt", opts.SystemPrompt)
|
||||
}
|
||||
if opts.MaxSteps > 0 {
|
||||
viper.Set("max-steps", opts.MaxSteps)
|
||||
}
|
||||
viper.Set("stream", opts.Streaming)
|
||||
|
||||
// Resolve working directory for context/skill discovery.
|
||||
cwd := opts.SessionDir
|
||||
if cwd == "" {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
// Load context files (AGENTS.md) from the project root.
|
||||
contextFiles := loadContextFiles(cwd)
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
loadedSkills, err := loadSkills(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
|
||||
// Always compose the system prompt with runtime context: base prompt +
|
||||
// AGENTS.md context + skills metadata + date/cwd.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
// Initialize config (loads config files and env vars).
|
||||
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
|
||||
// Check if model is already set, which indicates config was loaded.
|
||||
if viper.GetString("model") == "" {
|
||||
if err := InitConfig(opts.ConfigFile, false); err != nil {
|
||||
return fmt.Errorf("failed to initialize config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Inject skills metadata (name + description + location).
|
||||
if len(loadedSkills) > 0 {
|
||||
pb.WithSkills(loadedSkills)
|
||||
// Handle CLI debug mode.
|
||||
if opts.Debug {
|
||||
viper.Set("debug", true)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
pb.WithSection("", fmt.Sprintf(
|
||||
"Current date and time: %s\nCurrent working directory: %s",
|
||||
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
|
||||
))
|
||||
// Override viper settings with options.
|
||||
if opts.Model != "" {
|
||||
viper.Set("model", opts.Model)
|
||||
}
|
||||
if opts.SystemPrompt != "" {
|
||||
viper.Set("system-prompt", opts.SystemPrompt)
|
||||
}
|
||||
if opts.MaxSteps > 0 {
|
||||
viper.Set("max-steps", opts.MaxSteps)
|
||||
}
|
||||
viper.Set("stream", opts.Streaming)
|
||||
|
||||
viper.Set("system-prompt", pb.Build())
|
||||
// Resolve working directory for context/skill discovery.
|
||||
cwd = opts.SessionDir
|
||||
if cwd == "" {
|
||||
cwd, _ = os.Getwd()
|
||||
}
|
||||
|
||||
// Load context files (AGENTS.md) from the project root.
|
||||
contextFiles = loadContextFiles(cwd)
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
|
||||
// Always compose the system prompt with runtime context: base prompt +
|
||||
// AGENTS.md context + skills metadata + date/cwd.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
}
|
||||
|
||||
// Inject skills metadata (name + description + location).
|
||||
if len(loadedSkills) > 0 {
|
||||
pb.WithSkills(loadedSkills)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
pb.WithSection("", fmt.Sprintf(
|
||||
"Current date and time: %s\nCurrent working directory: %s",
|
||||
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
|
||||
))
|
||||
|
||||
viper.Set("system-prompt", pb.Build())
|
||||
}
|
||||
|
||||
// Snapshot all viper-derived values now, while the lock is held.
|
||||
// BuildProviderConfig is fast (pure reads), so we do it here.
|
||||
var pcErr error
|
||||
providerConfig, _, pcErr = kitsetup.BuildProviderConfig()
|
||||
if pcErr != nil {
|
||||
return fmt.Errorf("failed to build provider config: %w", pcErr)
|
||||
}
|
||||
modelString = viper.GetString("model")
|
||||
debug = viper.GetBool("debug")
|
||||
noExtensions = viper.GetBool("no-extensions")
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
streaming = viper.GetBool("stream")
|
||||
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// ---- viperInitMu released — heavy I/O below runs concurrently ----
|
||||
|
||||
// Load MCP configuration. Use pre-loaded config if provided via CLI options.
|
||||
var mcpConfig *config.Config
|
||||
if opts.CLI != nil {
|
||||
if opts.CLI != nil && opts.CLI.MCPConfig != nil {
|
||||
mcpConfig = opts.CLI.MCPConfig
|
||||
}
|
||||
if mcpConfig == nil {
|
||||
var err error
|
||||
mcpConfig, err = config.LoadAndValidateConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load MCP config: %w", err)
|
||||
@@ -595,13 +654,37 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
beforeCompact := newHookRegistry[BeforeCompactHook, BeforeCompactResult]()
|
||||
|
||||
// Build agent setup options, pulling CLI-specific fields when available.
|
||||
// Pass the pre-built ProviderConfig and scalar viper snapshots so
|
||||
// SetupAgent doesn't need to re-read viper (which would require the lock).
|
||||
setupOpts := kitsetup.AgentSetupOptions{
|
||||
MCPConfig: mcpConfig,
|
||||
Quiet: opts.Quiet,
|
||||
CoreTools: opts.Tools,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
|
||||
MCPConfig: mcpConfig,
|
||||
Quiet: opts.Quiet,
|
||||
CoreTools: opts.Tools,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
|
||||
ProviderConfig: providerConfig,
|
||||
Debug: debug,
|
||||
NoExtensions: noExtensions,
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streaming,
|
||||
}
|
||||
|
||||
// Set up OAuth handler for remote MCP servers.
|
||||
// The SDK MCPAuthHandler interface is structurally identical to
|
||||
// tools.MCPAuthHandler, so any implementation satisfies both.
|
||||
if opts.MCPAuthHandler != nil {
|
||||
setupOpts.AuthHandler = opts.MCPAuthHandler
|
||||
} else {
|
||||
// Create a default handler that opens the system browser.
|
||||
defaultHandler, authErr := NewDefaultMCPAuthHandler()
|
||||
if authErr != nil {
|
||||
// Non-fatal: OAuth just won't be available for remote servers.
|
||||
charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr)
|
||||
} else {
|
||||
setupOpts.AuthHandler = defaultHandler
|
||||
}
|
||||
}
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
@@ -624,7 +707,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
k := &Kit{
|
||||
agent: agentResult.Agent,
|
||||
treeSession: treeSession,
|
||||
modelString: viper.GetString("model"),
|
||||
modelString: modelString,
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
@@ -632,6 +715,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
@@ -904,6 +988,16 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
// Pre-flight check: if the incoming context is already dead, don't
|
||||
// waste time attempting init. This catches the case where the parent
|
||||
// generation loop's context was cancelled (e.g. user ESC, step cancel)
|
||||
// between when the LLM requested the subagent tool and when this code
|
||||
// runs. We replace it with a fresh context carrying only the timeout,
|
||||
// since the subagent should be independently bounded.
|
||||
if ctx.Err() != nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -920,6 +1014,17 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
}
|
||||
}
|
||||
|
||||
// Early validation: check model format and provider before doing any
|
||||
// expensive work (MCP init, system prompt composition, etc.). This
|
||||
// gives the calling agent immediate feedback it can act on — e.g.
|
||||
// correcting a typo — instead of waiting for a full Kit.New() cycle
|
||||
// that silently falls back to the parent model.
|
||||
if model != m.modelString {
|
||||
if err := models.GetGlobalRegistry().ValidateModelString(model); err != nil {
|
||||
return nil, fmt.Errorf("invalid subagent model %q: %w", model, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Default system prompt.
|
||||
systemPrompt := cfg.SystemPrompt
|
||||
if systemPrompt == "" {
|
||||
@@ -932,9 +1037,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
tools = SubagentTools()
|
||||
}
|
||||
|
||||
// Create child Kit instance. If the requested model fails (bad name,
|
||||
// unsupported provider, etc.), fall back to the parent's model so the
|
||||
// agent gets a useful error message instead of a hard failure.
|
||||
// Create child Kit instance.
|
||||
childOpts := &Options{
|
||||
Model: model,
|
||||
SystemPrompt: systemPrompt,
|
||||
@@ -943,20 +1046,8 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
Quiet: true,
|
||||
}
|
||||
child, err := New(ctx, childOpts)
|
||||
if err != nil && model != m.modelString {
|
||||
// Model-specific failure — retry with parent's model.
|
||||
childOpts.Model = m.modelString
|
||||
child, err = New(ctx, childOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create subagent: %w", err)
|
||||
}
|
||||
// Prepend a note so the agent knows which model is actually running.
|
||||
cfg.Prompt = fmt.Sprintf(
|
||||
"[Note: requested model %q was not available, using %s instead.]\n\n%s",
|
||||
model, m.modelString, cfg.Prompt,
|
||||
)
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to create subagent: %w", err)
|
||||
if err != nil {
|
||||
return &SubagentResult{Elapsed: time.Since(start)}, fmt.Errorf("failed to create subagent: %w", err)
|
||||
}
|
||||
defer func() { _ = child.Close() }()
|
||||
|
||||
@@ -970,7 +1061,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return &SubagentResult{Elapsed: elapsed}, err
|
||||
}
|
||||
|
||||
subResult := &SubagentResult{
|
||||
@@ -996,14 +1087,14 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
|
||||
// Create a per-turn steer channel and attach it to the context so the
|
||||
// agent's PrepareStep can inject steering messages between steps.
|
||||
steerCh := make(chan string, 16)
|
||||
steerCh := make(chan agent.SteerMessage, 16)
|
||||
m.steerMu.Lock()
|
||||
m.steerCh = steerCh
|
||||
m.steerMu.Unlock()
|
||||
defer func() {
|
||||
// Drain any unconsumed steer messages before nilling the channel.
|
||||
// These are stored in leftoverSteer so DrainSteer() can return them.
|
||||
var leftover []string
|
||||
var leftover []agent.SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
@@ -1093,12 +1184,52 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
func(content string) {
|
||||
m.events.emit(ToolCallContentEvent{Content: content})
|
||||
},
|
||||
func(chunk string) {
|
||||
m.events.emit(MessageUpdateEvent{Chunk: chunk})
|
||||
},
|
||||
// <think> tag filtering: models like Qwen/DeepSeek wrap reasoning inside
|
||||
// <think>...</think> tags in the regular text stream. We intercept those
|
||||
// spans here and re-route them as ReasoningDeltaEvent/ReasoningCompleteEvent
|
||||
// so callers always receive clean, tag-free text and structured reasoning.
|
||||
func() func(chunk string) {
|
||||
const (
|
||||
thinkOpen = "<think>"
|
||||
thinkClose = "</think>"
|
||||
)
|
||||
var inThinkTag bool
|
||||
return func(chunk string) {
|
||||
remaining := chunk
|
||||
for remaining != "" {
|
||||
if inThinkTag {
|
||||
i := strings.Index(remaining, thinkClose)
|
||||
if i == -1 {
|
||||
m.events.emit(ReasoningDeltaEvent{Delta: remaining})
|
||||
return
|
||||
}
|
||||
if i > 0 {
|
||||
m.events.emit(ReasoningDeltaEvent{Delta: remaining[:i]})
|
||||
}
|
||||
inThinkTag = false
|
||||
m.events.emit(ReasoningCompleteEvent{})
|
||||
remaining = remaining[i+len(thinkClose):]
|
||||
} else {
|
||||
i := strings.Index(remaining, thinkOpen)
|
||||
if i == -1 {
|
||||
m.events.emit(MessageUpdateEvent{Chunk: remaining})
|
||||
return
|
||||
}
|
||||
if i > 0 {
|
||||
m.events.emit(MessageUpdateEvent{Chunk: remaining[:i]})
|
||||
}
|
||||
inThinkTag = true
|
||||
remaining = remaining[i+len(thinkOpen):]
|
||||
}
|
||||
}
|
||||
}
|
||||
}(),
|
||||
func(delta string) {
|
||||
m.events.emit(ReasoningDeltaEvent{Delta: delta})
|
||||
},
|
||||
func() {
|
||||
m.events.emit(ReasoningCompleteEvent{})
|
||||
},
|
||||
func(toolCallID, toolName, chunk string, isStderr bool) {
|
||||
// Emit tool output chunk event for streaming bash output
|
||||
m.events.emit(ToolOutputEvent{
|
||||
@@ -1344,6 +1475,13 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
|
||||
// This is the preferred way to redirect an agent mid-turn without cancelling
|
||||
// in-progress tool execution.
|
||||
func (m *Kit) InjectSteer(message string) {
|
||||
m.InjectSteerWithFiles(message, nil)
|
||||
}
|
||||
|
||||
// InjectSteerWithFiles sends a steering message with optional file attachments
|
||||
// (e.g. pasted images) into the currently active agent turn. Behaves like
|
||||
// InjectSteer but includes file parts in the injected user message.
|
||||
func (m *Kit) InjectSteerWithFiles(message string, files []LLMFilePart) {
|
||||
m.steerMu.Lock()
|
||||
ch := m.steerCh
|
||||
m.steerMu.Unlock()
|
||||
@@ -1351,7 +1489,7 @@ func (m *Kit) InjectSteer(message string) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- message:
|
||||
case ch <- agent.SteerMessage{Text: message, Files: files}:
|
||||
default:
|
||||
// Channel full — extremely unlikely with buffer of 16, but don't block.
|
||||
}
|
||||
@@ -1369,7 +1507,7 @@ func (m *Kit) IsGenerating() bool {
|
||||
// a turn completes so the app layer can process any steer messages that
|
||||
// arrived after the last PrepareStep fired (e.g. during a text-only response
|
||||
// with no tool calls, or after the agent finished its last step).
|
||||
func (m *Kit) DrainSteer() []string {
|
||||
func (m *Kit) DrainSteer() []agent.SteerMessage {
|
||||
m.steerMu.Lock()
|
||||
defer m.steerMu.Unlock()
|
||||
|
||||
@@ -1382,7 +1520,7 @@ func (m *Kit) DrainSteer() []string {
|
||||
|
||||
// If a turn is still active, drain from the live channel.
|
||||
if m.steerCh != nil {
|
||||
var msgs []string
|
||||
var msgs []agent.SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-m.steerCh:
|
||||
@@ -1538,5 +1676,9 @@ func (m *Kit) Close() error {
|
||||
if m.treeSession != nil {
|
||||
_ = m.treeSession.Close()
|
||||
}
|
||||
// Release the OAuth callback port if we own the handler.
|
||||
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
|
||||
_ = closer.Close()
|
||||
}
|
||||
return m.agent.Close()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,265 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MCPAuthHandler handles OAuth authorization for MCP servers.
|
||||
// Implementations control the user experience — opening a browser, showing a
|
||||
// prompt, displaying a URL, etc.
|
||||
//
|
||||
// The default implementation ([DefaultMCPAuthHandler]) opens the system browser
|
||||
// and starts a local HTTP callback server to receive the authorization code.
|
||||
type MCPAuthHandler interface {
|
||||
// RedirectURI returns the OAuth redirect URI that the callback server
|
||||
// will listen on. This is called during MCP transport setup — before any
|
||||
// OAuth errors occur — so the redirect URI can be registered with the
|
||||
// authorization server.
|
||||
RedirectURI() string
|
||||
|
||||
// HandleAuth is called when an MCP server requires OAuth authorization.
|
||||
// It receives the server name and an authorization URL that the user must
|
||||
// visit. The handler must:
|
||||
// 1. Direct the user to authURL (e.g. open browser, display URL)
|
||||
// 2. Listen for the OAuth callback on the redirect URI
|
||||
// 3. Return the full callback URL (with code and state query params)
|
||||
//
|
||||
// Return an error to abort the connection to this MCP server.
|
||||
// The context controls the overall timeout; implementations should
|
||||
// respect ctx.Done().
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// DefaultMCPAuthHandler opens the system browser and starts a local HTTP
|
||||
// callback server to receive the OAuth authorization code. It eagerly reserves
|
||||
// a TCP port on construction so [RedirectURI] is stable for the lifetime of
|
||||
// the handler.
|
||||
//
|
||||
// Create instances with [NewDefaultMCPAuthHandler] (random port) or
|
||||
// [NewDefaultMCPAuthHandlerWithPort] (explicit port).
|
||||
type DefaultMCPAuthHandler struct {
|
||||
listener net.Listener
|
||||
port int
|
||||
mu sync.Mutex // guards listener lifecycle
|
||||
}
|
||||
|
||||
// NewDefaultMCPAuthHandler creates a handler that listens on a random
|
||||
// available port on localhost. The port is reserved immediately so
|
||||
// [RedirectURI] returns a stable value. Call [DefaultMCPAuthHandler.Close]
|
||||
// when the handler is no longer needed to release the port.
|
||||
func NewDefaultMCPAuthHandler() (*DefaultMCPAuthHandler, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen for OAuth callback: %w", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
return &DefaultMCPAuthHandler{listener: listener, port: port}, nil
|
||||
}
|
||||
|
||||
// NewDefaultMCPAuthHandlerWithPort creates a handler that listens on the
|
||||
// specified port on localhost. The port is reserved immediately. Pass 0 to
|
||||
// let the OS pick a free port (equivalent to [NewDefaultMCPAuthHandler]).
|
||||
// Call [DefaultMCPAuthHandler.Close] when the handler is no longer needed.
|
||||
func NewDefaultMCPAuthHandlerWithPort(port int) (*DefaultMCPAuthHandler, error) {
|
||||
addr := fmt.Sprintf("localhost:%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to listen on %s for OAuth callback: %w", addr, err)
|
||||
}
|
||||
actualPort := listener.Addr().(*net.TCPAddr).Port
|
||||
return &DefaultMCPAuthHandler{listener: listener, port: actualPort}, nil
|
||||
}
|
||||
|
||||
// RedirectURI returns the OAuth redirect URI pointing to the local callback
|
||||
// server. This value is stable for the lifetime of the handler.
|
||||
func (h *DefaultMCPAuthHandler) RedirectURI() string {
|
||||
return fmt.Sprintf("http://localhost:%d/oauth/callback", h.port)
|
||||
}
|
||||
|
||||
// Port returns the TCP port the callback server is bound to.
|
||||
func (h *DefaultMCPAuthHandler) Port() int {
|
||||
return h.port
|
||||
}
|
||||
|
||||
// HandleAuth opens the system browser to authURL and waits for the OAuth
|
||||
// callback on the local server. It returns the full callback URL including
|
||||
// query parameters (code, state, etc.).
|
||||
//
|
||||
// If the context has no deadline, a default 2-minute timeout is applied.
|
||||
// The callback server is started for each HandleAuth call and shut down
|
||||
// before returning.
|
||||
func (h *DefaultMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
|
||||
h.mu.Lock()
|
||||
listener := h.listener
|
||||
h.mu.Unlock()
|
||||
|
||||
if listener == nil {
|
||||
return "", fmt.Errorf("OAuth callback handler is closed")
|
||||
}
|
||||
|
||||
// Apply default timeout if the context has no deadline.
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Channel receives the full callback URL from the HTTP handler.
|
||||
callbackCh := make(chan string, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Reconstruct the full callback URL as the caller expects it.
|
||||
fullURL := fmt.Sprintf("http://localhost:%d%s", h.port, r.RequestURI)
|
||||
|
||||
// Send the callback URL to the waiting goroutine (non-blocking).
|
||||
select {
|
||||
case callbackCh <- fullURL:
|
||||
default:
|
||||
}
|
||||
|
||||
// Respond with a friendly HTML page so the user knows they can
|
||||
// close the browser tab.
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprint(w, oauthSuccessHTML)
|
||||
})
|
||||
|
||||
server := &http.Server{
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// Start serving on the pre-reserved listener. We need to create a new
|
||||
// listener on the same port because http.Server.Serve takes ownership
|
||||
// and closes the listener when done. The original listener is kept open
|
||||
// to reserve the port; we create a second listener via SO_REUSEADDR
|
||||
// semantics (Go's default on most platforms) or, more reliably, we
|
||||
// temporarily release and re-acquire.
|
||||
//
|
||||
// Strategy: use the held listener directly for Serve. After Serve
|
||||
// returns (due to Shutdown), re-acquire the listener to keep the port
|
||||
// reserved for future HandleAuth calls.
|
||||
h.mu.Lock()
|
||||
serveListener := h.listener
|
||||
h.listener = nil // Serve will close it
|
||||
h.mu.Unlock()
|
||||
|
||||
if serveListener == nil {
|
||||
return "", fmt.Errorf("OAuth callback handler is closed")
|
||||
}
|
||||
|
||||
// Start the HTTP server in a background goroutine.
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
err := server.Serve(serveListener)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
serverErrCh <- err
|
||||
}
|
||||
close(serverErrCh)
|
||||
}()
|
||||
|
||||
// Re-acquire the listener after Serve completes (deferred).
|
||||
defer func() {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
|
||||
// Re-reserve the port for future HandleAuth calls.
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.listener == nil {
|
||||
newListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", h.port))
|
||||
if err == nil {
|
||||
h.listener = newListener
|
||||
}
|
||||
// If re-listen fails, the handler degrades gracefully — the
|
||||
// next HandleAuth call will return an error.
|
||||
}
|
||||
}()
|
||||
|
||||
// Open the system browser.
|
||||
if err := openBrowser(authURL); err != nil {
|
||||
// Browser open is best-effort; the user can still navigate manually.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Wait for the callback, a server error, or context cancellation.
|
||||
select {
|
||||
case url := <-callbackCh:
|
||||
return url, nil
|
||||
case err := <-serverErrCh:
|
||||
return "", fmt.Errorf("OAuth callback server error for %q: %w", serverName, err)
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("OAuth authorization timed out for %q: %w", serverName, ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases the reserved port and shuts down the handler. After Close,
|
||||
// HandleAuth will return an error. Close is safe to call multiple times.
|
||||
func (h *DefaultMCPAuthHandler) Close() error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if h.listener != nil {
|
||||
err := h.listener.Close()
|
||||
h.listener = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openBrowser opens the default system browser to the given URL. This is a
|
||||
// best-effort operation — errors are returned but callers typically ignore
|
||||
// them since the user can navigate manually.
|
||||
func openBrowser(url string) error {
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
return exec.Command("xdg-open", url).Start()
|
||||
case "windows":
|
||||
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
|
||||
case "darwin":
|
||||
return exec.Command("open", url).Start()
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// oauthSuccessHTML is the HTML page returned to the browser after a
|
||||
// successful OAuth callback.
|
||||
const oauthSuccessHTML = `<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<title>Authorization Successful</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background: #f8f9fa;
|
||||
color: #333;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
}
|
||||
h1 { color: #22863a; }
|
||||
p { color: #586069; margin-top: 0.5rem; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>✓ Authorization Successful</h1>
|
||||
<p>You can close this tab and return to the terminal.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
@@ -0,0 +1,68 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// CLIMCPAuthHandler wraps a [DefaultMCPAuthHandler] and prints status messages
|
||||
// to a writer (typically stderr) so the user knows what's happening during
|
||||
// OAuth authorization. This is the handler used by the CLI/TUI binary.
|
||||
//
|
||||
// For TUI integration, set NotifyFunc to route messages through the TUI's
|
||||
// event system instead of (or in addition to) the writer.
|
||||
type CLIMCPAuthHandler struct {
|
||||
inner *DefaultMCPAuthHandler
|
||||
w io.Writer
|
||||
|
||||
// NotifyFunc, when set, is called with status messages instead of writing
|
||||
// to the writer. This allows the TUI to display system messages in the
|
||||
// chat stream. If nil, messages are written to w.
|
||||
NotifyFunc func(serverName, message string)
|
||||
}
|
||||
|
||||
// NewCLIMCPAuthHandler creates a CLI auth handler that prints status messages
|
||||
// to stderr and delegates the actual OAuth flow to a [DefaultMCPAuthHandler].
|
||||
func NewCLIMCPAuthHandler() (*CLIMCPAuthHandler, error) {
|
||||
inner, err := NewDefaultMCPAuthHandler()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CLIMCPAuthHandler{inner: inner, w: os.Stderr}, nil
|
||||
}
|
||||
|
||||
// RedirectURI returns the OAuth redirect URI from the inner handler.
|
||||
func (h *CLIMCPAuthHandler) RedirectURI() string {
|
||||
return h.inner.RedirectURI()
|
||||
}
|
||||
|
||||
// HandleAuth prints status messages and delegates to the inner handler.
|
||||
func (h *CLIMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
|
||||
h.notify(serverName, fmt.Sprintf("🔐 MCP server %q requires authentication. Opening browser...", serverName))
|
||||
h.notify(serverName, fmt.Sprintf(" If the browser doesn't open, visit:\n %s", authURL))
|
||||
|
||||
callbackURL, err := h.inner.HandleAuth(ctx, serverName, authURL)
|
||||
if err != nil {
|
||||
h.notify(serverName, fmt.Sprintf("✗ Authentication failed for %q: %v", serverName, err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
h.notify(serverName, fmt.Sprintf("✓ Authenticated with %q", serverName))
|
||||
return callbackURL, nil
|
||||
}
|
||||
|
||||
// Close releases the inner handler's resources.
|
||||
func (h *CLIMCPAuthHandler) Close() error {
|
||||
return h.inner.Close()
|
||||
}
|
||||
|
||||
// notify sends a message through NotifyFunc if set, otherwise writes to w.
|
||||
func (h *CLIMCPAuthHandler) notify(serverName, message string) {
|
||||
if h.NotifyFunc != nil {
|
||||
h.NotifyFunc(serverName, message)
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintln(h.w, message)
|
||||
}
|
||||
@@ -91,7 +91,11 @@ api.OnAgentStart(func(e ext.AgentStartEvent, ctx ext.Context) {
|
||||
// Agent finished responding.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
// e.Response string
|
||||
// e.StopReason string — "completed", "cancelled", "error"
|
||||
// e.StopReason string — "error" (on failure), "completed" (when LLM returns
|
||||
// empty stop reason), or the raw LLM provider value passed through
|
||||
// (e.g. "stop", "end_turn", "max_tokens", "tool_use").
|
||||
// To detect errors, check e.StopReason == "error".
|
||||
// Do NOT compare against "completed" for success — instead check != "error".
|
||||
})
|
||||
```
|
||||
|
||||
|
||||
@@ -104,6 +104,8 @@ Define custom models in your `.kit.yml` for use with the `custom` provider. This
|
||||
customModels:
|
||||
my-model:
|
||||
name: "My Custom Model"
|
||||
baseUrl: "http://localhost:8080/v1"
|
||||
apiKey: "my-secret-key"
|
||||
reasoning: true
|
||||
temperature: true
|
||||
cost:
|
||||
@@ -119,6 +121,8 @@ customModels:
|
||||
| Field | Type | Required | Description |
|
||||
|-------|------|----------|-------------|
|
||||
| `name` | string | Yes | Display name for the model |
|
||||
| `baseUrl` | string | No | Per-model base URL override; when set, `--provider-url` is not required |
|
||||
| `apiKey` | string | No | Per-model API key override |
|
||||
| `reasoning` | bool | No | Whether the model supports reasoning/thinking |
|
||||
| `temperature` | bool | No | Whether the model supports temperature adjustment |
|
||||
| `cost.input` | float | No | Cost per 1K input tokens |
|
||||
@@ -126,7 +130,13 @@ customModels:
|
||||
| `limit.context` | int | Yes | Maximum context window in tokens |
|
||||
| `limit.output` | int | No | Maximum output tokens |
|
||||
|
||||
Use with a custom provider URL:
|
||||
Use with a per-model `baseUrl` (no `--provider-url` needed):
|
||||
|
||||
```bash
|
||||
kit --model custom/my-model "Hello"
|
||||
```
|
||||
|
||||
Or override the base URL at runtime:
|
||||
|
||||
```bash
|
||||
kit --provider-url "http://localhost:8080/v1" --model custom/my-model "Hello"
|
||||
|
||||
Reference in New Issue
Block a user