Compare commits

...

17 Commits

Author SHA1 Message Date
Ed Zynda b2bd016135 fix(tui): redirect log output to file to prevent TUI corruption
- Add tea.LogToFile in runInteractiveModeBubbleTea to send stdlib log
  output to /tmp/kit/kit.log instead of stderr
- Replace charmbracelet/log with stdlib log in extensions loader,
  runner, watcher, prompts loader, and pkg/kit so all log calls go
  through the redirected stdlib logger
- Leave charmbracelet/log in CLI-only commands (install, acp) and
  acpserver where stderr logging is correct
2026-04-07 21:20:04 +03:00
Ed Zynda 812dedaea2 feat(pkg/kit): add SessionManager interface for custom session backends
Add SessionManager interface to allow pluggable session storage backends.
This enables users to implement custom session managers for databases,
cloud storage, or other persistence mechanisms instead of the default
JSONL file-based TreeManager.

Changes:
- Add SessionManager interface with methods for message storage,
  tree navigation, compaction, and extension data
- Add treeManagerAdapter to wrap existing TreeManager for backward compatibility
- Update Kit struct to use SessionManager interface instead of concrete type
- Add SessionManager option to Options struct
- Update all session-related methods to use interface
- Add documentation for custom SessionManager usage

The default behavior is preserved - when no SessionManager is provided,
Kit automatically uses the TreeManager via the adapter.
2026-04-07 17:41:46 +03:00
Ed Zynda f65b6737f2 feat(sdk): add SkipConfig and DisableCoreTools options
Add two new Options fields for programmatic SDK usage:

- SkipConfig: Skip .kit.yml file loading while still using viper defaults
  and environment variables. Useful for fully programmatic configuration.

- DisableCoreTools: Allow creating agents with 0 tools (chat-only mode) or
  with only custom tools. When true and Tools is empty, no tools are loaded.
  When combined with custom Tools, only those tools are loaded.

Updates documentation in README, pkg/kit/README, skills/kit-sdk/SKILL,
and www/pages/sdk/options.
2026-04-07 17:10:58 +03:00
Ed Zynda 5d45aa196b fix(watcher): remove debug logging that corrupts TUI
Remove charmbracelet/log debug statements from the file watcher that
were writing directly to stderr, corrupting the Bubble Tea terminal UI.

- Remove log.Debug calls for directory operations and file changes
- Remove log.Warn for watcher errors (silently ignore instead)
- Remove the charmbracelet/log import entirely
2026-04-07 16:31:29 +03:00
Ed Zynda debb39f56c fix(ui): show MCP tools in /tools and status bar after async loading
Background MCP tool loading (added in 7e54710) caused tools to not appear
in the UI because tool names and counts were captured at startup before
loading completed. This adds:

- MCPToolsReadyEvent and MCPServerLoadedEvent for progress notifications
- Dynamic GetToolNames/GetMCPToolCount callbacks for live updates
- Per-server status messages as each MCP server finishes loading
- Refresh handlers to update /tools output and status bar when ready
2026-04-07 16:29:09 +03:00
Ed Zynda 7ce6f4fd9e fix(watcher): dynamically watch new subdirectories for skill/prompt reload
- Detect new subdirectory creation in the fsnotify event loop and add
  it to the watcher so files created inside trigger reload events
- Handle cp -r case by checking if new directories already contain
  matching files and scheduling an immediate debounced reload
- Add dirContainsMatchingFiles helper method
- Add tests for both new-subdirectory and copy-with-existing-files cases
2026-04-07 15:01:18 +03:00
Ed Zynda c2f2bdb3d3 feat: auto-reload custom prompts and skills on file change
- Add internal/watcher package with general-purpose ContentWatcher
  using fsnotify, configurable file extensions, and debouncing
- Add ContentReloadEvent and App.NotifyContentReload() for TUI signaling
- Add GetPromptTemplates/GetSkillItems callback fields on AppModelOptions
  following the existing GetExtensionCommands lazy-provider pattern
- Add Kit.ReloadSkills() to re-discover skills from disk
- Wire fsnotify watcher for .kit/prompts/, .kit/skills/, .agents/skills/,
  and global config directories, triggering on .md/.txt changes
- TUI refreshes autocomplete entries and skill list on reload
2026-04-07 14:09:59 +03:00
Ed Zynda 201d14804e fix(ui): prevent double-rendered messages after reasoning-only responses
- Always fire onResponse callback even when response text is empty so
  ResponseCompleteEvent reaches the TUI and resets the StreamComponent
- Check for existing StreamingMessageItem in flushStreamAndPendingUserMessages
  before creating a new StyledMessageItem to avoid duplicate content
- Mark trailing StreamingMessageItem complete on StepComplete, StepCancelled,
  and StepError to freeze live timers and prevent dangling streaming state
2026-04-07 13:52:30 +03:00
Ed Zynda 7e54710d4a perf(agent): load MCP tools asynchronously to speed up startup
Load MCP server tools in the background so the UI appears immediately
instead of blocking until all servers connect. The first LLM call
automatically waits for tools to be ready before proceeding.

Key changes:
- NewAgent() starts MCP loading in a background goroutine and returns
  immediately with core/extension tools only
- GenerateWithLoop() calls ensureMCPTools() to lazily wait and rebuild
  the fantasy agent with full tool set before first LLM call
- Parallelize LoadTools() across all configured MCP servers
- Add WaitForMCPTools() and MCPToolsReady() for status checking
- Refactor SetModel/SetExtraTools to use shared rebuildFantasyAgent()
- Expose async MCP status methods in public SDK
2026-04-07 13:36:10 +03:00
Ed Zynda 88870be4d2 feat: add frequency-penalty and presence-penalty parameters
- Add --frequency-penalty and --presence-penalty CLI flags (0.0-2.0)
- Wire through config, viper, ProviderConfig, and fantasy agent options
- Support in config file, env vars (KIT_FREQUENCY_PENALTY), and SDK
- Pass to Ollama via options map (frequency_penalty, presence_penalty)
- Apply on both initial agent creation and runtime model swap
2026-04-06 10:52:33 +03:00
Ed Zynda 46bf809715 chore(models): update embedded models.json from models.dev
- Providers: 97 -> 109 (+12 new)
- Models: 3039 -> 4156 (+1117 new)
- New providers: alibaba-coding-plan, alibaba-coding-plan-cn, clarifai,
  dinference, drun, llmgateway, perplexity-agent, tencent-coding-plan,
  the-grid-ai, xiaomi-token-plan-ams, xiaomi-token-plan-cn,
  xiaomi-token-plan-sgp
2026-04-06 09:50:43 +03:00
Ed Zynda e19e9642a2 feat(session): include system prompt and model in shared sessions
Add SystemPromptEntry type to capture system prompt, model, and provider
when sharing sessions via /share command. The entry is inserted into the
JSONL after the header and displayed in the web viewer as a collapsible
section with a model badge.

- Add SystemPromptEntry with Content, Model, and Provider fields
- Capture current system prompt and model at share time
- Display in web viewer with collapsible UI and model badge
- Update documentation for /share command
2026-04-04 19:33:02 +03:00
Ed Zynda 32675b8b35 chore(deps): update all go module dependencies
- mcp-go v0.46.0 → v0.47.0
- herald v0.11.0 → v0.13.0
- herald-md v0.2.0 → v0.3.0
- smithy-go v1.24.2 → v1.24.3
- otel v1.42.0 → v1.43.0
- googleapis/gax-go v2.20.0 → v2.21.0
- google.golang.org/api v0.273.1 → v0.274.0
- runewidth v0.0.21 → v0.0.22
- azure-sdk-internal v1.11.2 → v1.12.0
- various aws-sdk-go-v2 sub-modules patched
2026-04-04 18:11:56 +03:00
Ed Zynda aecce001ee feat(mcp): add OAuth support for remote MCP servers
- Add MCPAuthHandler interface at SDK level (pkg/kit/) so all consumers
  (CLI, TUI, SDK embedders) control the OAuth UX through one interface
- Default handler opens system browser + local callback server with PKCE
- CLIMCPAuthHandler wraps default with status messages (stderr pre-TUI,
  system messages via TUI event system once running)
- Always enable OAuth on remote transports (streamable HTTP, SSE) when
  handler is configured; harmless for servers that don't need it
- Dynamic client registration when no client ID is pre-configured
- File-based TokenStore persists tokens to ~/.config/.kit/mcp_tokens.json
  keyed by server URL so users don't re-auth on restart
- Catch OAuthAuthorizationRequiredError at connection init (startup) and
  tool execution (mid-session token expiry), run auth flow, retry once
- Fix error wrapping (%v -> %w) in connection pool so errors.As can
  unwrap through the chain to find OAuth errors
- Thread AuthHandler through MCPToolManager -> AgentConfig ->
  AgentCreationOptions -> AgentSetupOptions -> kit.Options
2026-04-04 17:41:57 +03:00
Ed Zynda 32d73171fd fix(extensions): write manifest Include in single pass and preserve on update
- InstallWithInclude wrote manifest twice via two different code paths,
  with the first write missing Include; unify into shared install() method
  that writes the manifest once with all fields including Include
- Update() now reads the existing manifest entry to preserve Include and
  Installed timestamp instead of constructing a fresh entry from scratch
2026-04-04 17:19:00 +03:00
Ed Zynda 265fd2ec0c fix(extensions): skip _test.go files and non-extension examples/ subdirs
- Filter out _test.go files in findExtensionsInDir, findExtensionsInRepo,
  and ScanForExtensions to prevent Yaegi from loading test files
- Narrow examples/ traversal so only recognized extension directories
  (extensions/, ext/, *-ext/, *-extensions/) are scanned, not arbitrary
  subdirs like examples/sdk/ that import pkg/kit
2026-04-04 16:44:13 +03:00
Ed Zynda efebf2eba6 fix(kit-telegram): add typing indicator and config fallback to global path
- Send sendChatAction("typing") every 4s while agent is processing,
  started on AgentStart and stopped on AgentEnd/SessionShutdown
- configPath() now checks project-local .kit/ first, then falls back
  to ~/.config/kit/kit-telegram.json for cross-project portability
2026-04-04 16:33:08 +03:00
44 changed files with 3416 additions and 560 deletions
+6 -1
View File
@@ -531,7 +531,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
+176 -17
View File
@@ -7,6 +7,7 @@ import (
"image/color"
"log"
"os"
"path/filepath"
"strings"
tea "charm.land/bubbletea/v2"
@@ -18,6 +19,7 @@ import (
"github.com/mark3labs/kit/internal/prompts"
"github.com/mark3labs/kit/internal/ui"
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/watcher"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -48,12 +50,14 @@ var (
noSessionFlag bool // --no-session: ephemeral mode, no persistence
// Model generation parameters
maxTokens int
temperature float32
topP float32
topK int32
stopSequences []string
thinkingLevel string
maxTokens int
temperature float32
topP float32
topK int32
frequencyPenalty float32
presencePenalty float32
stopSequences []string
thinkingLevel string
// Ollama-specific parameters
numGPU int32
@@ -291,6 +295,8 @@ func init() {
flags.Float32Var(&temperature, "temperature", 0.7, "controls randomness in responses (0.0-1.0)")
flags.Float32Var(&topP, "top-p", 0.95, "controls diversity via nucleus sampling (0.0-1.0)")
flags.Int32Var(&topK, "top-k", 40, "controls diversity by limiting top K tokens to sample from")
flags.Float32Var(&frequencyPenalty, "frequency-penalty", 0.0, "penalizes tokens based on frequency of appearance (0.0-2.0)")
flags.Float32Var(&presencePenalty, "presence-penalty", 0.0, "penalizes tokens based on whether they have appeared (0.0-2.0)")
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, minimal, low, medium, high")
@@ -313,6 +319,8 @@ func init() {
_ = viper.BindPFlag("temperature", rootCmd.PersistentFlags().Lookup("temperature"))
_ = viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
_ = viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
_ = viper.BindPFlag("frequency-penalty", rootCmd.PersistentFlags().Lookup("frequency-penalty"))
_ = viper.BindPFlag("presence-penalty", rootCmd.PersistentFlags().Lookup("presence-penalty"))
_ = viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
_ = viper.BindPFlag("thinking-level", rootCmd.PersistentFlags().Lookup("thinking-level"))
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
@@ -717,13 +725,33 @@ func runNormalMode(ctx context.Context) error {
// Build Kit options from CLI flags and create the SDK instance.
// kit.New() handles: config → skills → agent → session → extension bridge.
authHandler, authErr := kit.NewCLIMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote MCP servers.
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
}
// appInstancePtr is used to break the circular dependency between
// kit.New (which needs the OnMCPServerLoaded callback) and app.New
// (which is needed by the callback to send events to the TUI).
var appInstancePtr *app.App
kitOpts := &kit.Options{
Quiet: quietFlag,
Debug: debugMode,
NoSession: noSessionFlag,
Continue: continueFlag,
SessionPath: sessionPath,
AutoCompact: autoCompactFlag,
Quiet: quietFlag,
Debug: debugMode,
NoSession: noSessionFlag,
Continue: continueFlag,
SessionPath: sessionPath,
AutoCompact: autoCompactFlag,
MCPAuthHandler: authHandler,
// This callback is called when each MCP server finishes loading.
// We use a closure that captures appInstancePtr which is set after
// app.New() is called below.
OnMCPServerLoaded: func(serverName string, toolCount int, err error) {
if appInstancePtr != nil {
appInstancePtr.NotifyMCPServerLoaded(serverName, toolCount, err)
}
},
CLI: &kit.CLIOptions{
MCPConfig: mcpConfig,
ShowSpinner: true,
@@ -794,8 +822,16 @@ func runNormalMode(ctx context.Context) error {
}
appInstance := app.New(appOpts, messages)
appInstancePtr = appInstance // Wire up the MCP server loaded callback.
defer appInstance.Close()
// Wire OAuth handler to route messages through the TUI once it's running.
if authHandler != nil {
authHandler.NotifyFunc = func(serverName, message string) {
appInstance.PrintFromExtension("info", message)
}
}
// Buffer for extension messages during startup (printed after startup banner).
var startupExtensionMessages []string
@@ -1600,6 +1636,49 @@ func runNormalMode(ctx context.Context) error {
})
}
// Build prompt template and skill item provider callbacks for hot-reload.
// These are called by the TUI when ContentReloadEvent fires.
getPromptTemplates := func() []*prompts.PromptTemplate {
if noPromptTemplates {
return nil
}
homeDir, _ := os.UserHomeDir()
cwd, _ := os.Getwd()
tpls, _, err := prompts.LoadAll(prompts.LoadOptions{
Cwd: cwd,
HomeDir: homeDir,
ExtraPaths: promptTemplatePaths,
ConfigPaths: viper.GetStringSlice("prompts"),
IncludeDefaults: true,
})
if err != nil {
log.Printf("Warning: failed to reload prompt templates: %v", err)
}
return tpls
}
getSkillItems := func() []ui.SkillItem {
// Re-discover skills from disk.
if err := kitInstance.ReloadSkills(); err != nil {
log.Printf("Warning: failed to reload skills: %v", err)
return nil
}
cwd, _ := os.Getwd()
var items []ui.SkillItem
for _, s := range kitInstance.GetSkills() {
source := "user"
if strings.HasPrefix(s.Path, cwd) {
source = "project"
}
items = append(items, ui.SkillItem{
Name: s.Name,
Path: s.Path,
Source: source,
})
}
return items
}
// Build extension UI providers once (shared between both modes).
getWidgets := widgetProviderForUI(kitInstance)
getHeader := headerProviderForUI(kitInstance)
@@ -1615,6 +1694,25 @@ func runNormalMode(ctx context.Context) error {
return extensionCommandsForUI(kitInstance)
}
// Build dynamic tool name and MCP tool count providers. These are called
// by the TUI when MCPToolsReadyEvent fires to refresh the /tools list
// and startup info bar after background MCP tool loading completes.
getToolNames := func() []string {
return kitInstance.GetToolNames()
}
getMCPToolCount := func() int {
return kitInstance.GetMCPToolCount()
}
// Start a goroutine that waits for background MCP tool loading to
// complete and notifies the TUI so it can refresh tool names and counts.
if len(mcpConfig.MCPServers) > 0 {
go func() {
_ = kitInstance.WaitForMCPTools()
appInstance.NotifyMCPToolsReady()
}()
}
// Build model switching callbacks for the /model command.
setModelForUI := func(modelString string) error {
err := kitInstance.SetModel(context.Background(), modelString)
@@ -1695,9 +1793,54 @@ func runNormalMode(ctx context.Context) error {
}
}
// Start file watchers for automatic prompt and skill hot-reload.
{
homeDir, _ := os.UserHomeDir()
cwd, _ := os.Getwd()
// Collect prompt template directories.
promptDirs := watcher.CollectDirs(
[]string{
filepath.Join(homeDir, ".kit", "prompts"),
filepath.Join(cwd, ".kit", "prompts"),
},
append(promptTemplatePaths, viper.GetStringSlice("prompts")...),
)
// Collect skill directories.
skillDirs := watcher.CollectDirs(
[]string{
filepath.Join(homeDir, ".config", "kit", "skills"),
filepath.Join(cwd, ".agents", "skills"),
filepath.Join(cwd, ".kit", "skills"),
},
nil,
)
// Combine all content directories and start a single watcher.
allContentDirs := append(promptDirs, skillDirs...)
if len(allContentDirs) > 0 {
contentWatcher, watchErr := watcher.New(watcher.Options{
Dirs: allContentDirs,
Extensions: []string{".md", ".txt"},
Label: "prompts/skills",
OnReload: func() {
log.Printf("auto-reloading prompts and skills")
appInstance.NotifyContentReload()
},
})
if watchErr != nil {
log.Printf("content file watcher not started: %v", watchErr)
} else {
go contentWatcher.Start(ctx)
defer func() { _ = contentWatcher.Close() }()
}
}
}
// Check if running in non-interactive mode
if positionalPrompt != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
}
// Quiet mode is not allowed in interactive mode
@@ -1705,7 +1848,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet requires a prompt")
}
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
}
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
@@ -1718,7 +1861,7 @@ func runNormalMode(ctx context.Context) error {
//
// When --no-exit is set, after the prompt completes the interactive BubbleTea
// TUI is started so the user can continue the conversation.
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
// Expand @file references in the prompt before sending to the agent.
if cwd, err := os.Getwd(); err == nil {
prompt = ui.ProcessFileAttachments(prompt, cwd)
@@ -1761,7 +1904,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
// If --no-exit was requested, hand off to the interactive TUI.
if noExit {
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
}
return nil
@@ -1859,7 +2002,19 @@ func writeJSONError(err error) {
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
// background goroutines (watchers, extension handlers, SDK internals)
// will visually corrupt the terminal.
logDir := filepath.Join(os.TempDir(), "kit")
_ = os.MkdirAll(logDir, 0o700)
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
if logErr == nil {
defer func() { _ = logFile.Close() }()
}
// Determine terminal size; fall back gracefully.
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil || termWidth == 0 {
@@ -1878,13 +2033,17 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
Height: termHeight,
ServerNames: serverNames,
ToolNames: toolNames,
GetToolNames: getToolNames,
GetMCPToolCount: getMCPToolCount,
MCPToolCount: mcpToolCount,
ExtensionToolCount: extensionToolCount,
UsageTracker: usageTracker,
ExtensionCommands: extCommands,
PromptTemplates: promptTemplates,
GetPromptTemplates: getPromptTemplates,
ContextPaths: contextPaths,
SkillItems: skillItems,
GetSkillItems: getSkillItems,
StartupExtensionMessages: startupExtensionMessages,
GetWidgets: getWidgets,
GetHeader: getHeader,
+74 -1
View File
@@ -168,6 +168,10 @@ var (
// Test
pendingTest *PendingTest
// Typing indicator
typingTicker *time.Ticker
typingStop chan struct{}
// Latest context for background goroutines
latestCtx ext.Context
latestCtxSet bool
@@ -203,8 +207,23 @@ func configDir() string {
return filepath.Join(home, ".config", "kit")
}
func globalConfigDir() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".config", "kit")
}
func configPath() string {
return filepath.Join(configDir(), "kit-telegram.json")
// Prefer project-local config, fall back to global config.
local := filepath.Join(configDir(), "kit-telegram.json")
if _, err := os.Stat(local); err == nil {
return local
}
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
if _, err := os.Stat(global); err == nil {
return global
}
// Neither exists — return local path (will be created on connect).
return local
}
func failureLogDir() string {
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
return &msg, nil
}
func tgSendChatAction(token string, chatID int64, action string) error {
_, err := telegramRequest(token, "sendChatAction", map[string]any{
"chat_id": chatID,
"action": action,
}, 15)
return err
}
// ──────────────────────────────────────────────
// Error classification
// ──────────────────────────────────────────────
@@ -637,6 +664,48 @@ func clearHealthTimer() {
}
}
// ──────────────────────────────────────────────
// Typing indicator
// ──────────────────────────────────────────────
func startTypingLoop() {
mu.Lock()
defer mu.Unlock()
if typingTicker != nil {
return
}
cfg := config
if cfg == nil || !cfg.Enabled {
return
}
token := cfg.BotToken
chatID := cfg.ChatID
typingTicker = time.NewTicker(4 * time.Second)
typingStop = make(chan struct{})
// Send immediately, then every 4 seconds.
go func() {
tgSendChatAction(token, chatID, "typing")
for {
select {
case <-typingTicker.C:
tgSendChatAction(token, chatID, "typing")
case <-typingStop:
return
}
}
}()
}
func stopTypingLoop() {
mu.Lock()
defer mu.Unlock()
if typingTicker != nil {
typingTicker.Stop()
close(typingStop)
typingTicker = nil
}
}
// ──────────────────────────────────────────────
// Polling lifecycle
// ──────────────────────────────────────────────
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
mu.Unlock()
sendShutdownDisconnectedMessage()
stopTypingLoop()
stopPolling()
clearHealthTimer()
clearFooter()
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
mu.Unlock()
report("run.start", fmt.Sprintf("runId=%d", run.ID))
startTypingLoop()
ensureProgressMessage()
updateProgressMessage()
})
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
run := activeRun
mu.Unlock()
stopTypingLoop()
if run != nil {
// Capture final response from event
if e.Response != "" {
+19 -20
View File
@@ -14,10 +14,14 @@ require (
github.com/charmbracelet/fang v1.0.0
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b
github.com/clipperhouse/displaywidth v0.11.0
github.com/clipperhouse/uax29/v2 v2.7.0
github.com/coder/acp-go-sdk v0.6.3
github.com/indaco/herald v0.11.0
github.com/indaco/herald-md v0.2.0
github.com/mark3labs/mcp-go v0.46.0
github.com/fsnotify/fsnotify v1.9.0
github.com/indaco/herald v0.13.0
github.com/indaco/herald-md v0.3.0
github.com/mark3labs/mcp-go v0.47.0
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/traefik/yaegi v0.16.1
@@ -31,11 +35,11 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.13 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
@@ -43,17 +47,16 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/aws/smithy-go v1.24.3 // indirect
github.com/catppuccin/go v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
github.com/charmbracelet/colorprofile v0.4.3 // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
@@ -62,25 +65,21 @@ require (
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kaptinlin/go-i18n v0.3.0 // indirect
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
@@ -106,16 +105,16 @@ require (
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/otel v1.42.0 // indirect
go.opentelemetry.io/otel/metric v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/api v0.273.1 // indirect
google.golang.org/api v0.274.0 // indirect
google.golang.org/genai v1.52.1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
@@ -130,7 +129,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.21 // indirect
github.com/mattn/go-runewidth v0.0.22 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
+34 -34
View File
@@ -18,12 +18,12 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
@@ -38,10 +38,10 @@ github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg=
github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY=
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk=
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
@@ -56,14 +56,14 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3x
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
@@ -173,18 +173,18 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/indaco/herald v0.11.0 h1:tJZc6DAzfUYVWQsU9Lik4RcKR7TtiRfnBIu/oXjp/WA=
github.com/indaco/herald v0.11.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.2.0 h1:kGFsKE+Swzf7EyTUFx7FL1d1jwiKoJRcxqYo2bhUgS0=
github.com/indaco/herald-md v0.2.0/go.mod h1:64DKh1wSQUsWXTuIYklFzSheJKkW0+FpaqyKqwids3g=
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
@@ -201,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mark3labs/mcp-go v0.47.0 h1:h44yeM3DduDyQgzImYWu4pt6VRkqP/0p/95AGhWngnA=
github.com/mark3labs/mcp-go v0.47.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.22 h1:76lXsPn6FyHtTY+jt2fTTvsMUCZq1k0qwRsAMuxzKAk=
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -276,16 +276,16 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.6
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
@@ -309,8 +309,8 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.273.1 h1:L7G/TmpAMz0nKx/ciAVssVmWQiOF6+pOuXeKrWVsquY=
google.golang.org/api v0.273.1/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
+223 -136
View File
@@ -25,11 +25,21 @@ type AgentConfig struct {
StreamingEnabled bool
DebugLogger tools.DebugLogger
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
// If nil, remote MCP servers that require OAuth will fail to connect.
AuthHandler tools.MCPAuthHandler
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. This allows SDK users to provide a custom tool set (e.g.
// CodingTools or tools with a custom WorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper is an optional function that wraps the combined tool list
// before it is passed to the LLM agent. Used by the extensions system
// to intercept tool calls/results.
@@ -38,6 +48,11 @@ type AgentConfig struct {
// ExtraTools are additional tools to include alongside core and MCP tools.
// Used by extensions to register custom tools.
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). The callback receives the server
// name, tool count, and any error. Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// ToolCallHandler is a function type for handling tool calls as they happen.
@@ -83,6 +98,10 @@ type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCrea
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
// AgentTool implementations — no MCP layer, no serialization overhead.
// Additional tools from external MCP servers can be loaded alongside core tools.
//
// When MCP servers are configured, tool loading happens in the background so the
// agent (and UI) can start immediately. The first LLM call automatically waits
// for MCP tools to finish loading before proceeding.
type Agent struct {
toolManager *tools.MCPToolManager
fantasyAgent fantasy.Agent
@@ -96,6 +115,18 @@ type Agent struct {
coreTools []fantasy.AgentTool
extraTools []fantasy.AgentTool
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool // stored for SetModel rebuild
// providerOptions and modelConfig are stored for rebuilding the fantasy
// agent when MCP tools arrive asynchronously or on SetModel.
providerOptions fantasy.ProviderOptions
skipMaxOutputTokens bool
modelConfig *models.ProviderConfig
// mcpReady is closed when background MCP tool loading completes (success
// or failure). nil when no MCP servers are configured.
mcpReady chan struct{}
// mcpErr holds any error from background MCP loading.
mcpErr error
}
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
@@ -114,7 +145,10 @@ type GenerateWithLoopResult struct {
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
// External MCP tools are loaded from the config if any MCP servers are configured.
// If MCP servers are configured, their tools are loaded in the background —
// the agent returns immediately and is usable with core tools only. The first
// LLM call (GenerateWithLoop) automatically waits for MCP tools to finish
// loading and rebuilds the agent with the full tool set.
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Create the LLM provider
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
@@ -124,34 +158,22 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Register core tools (direct AgentTool implementations, no MCP overhead).
// Use caller-provided tools if set, otherwise default to all core tools.
coreTools := agentConfig.CoreTools
if len(coreTools) == 0 {
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
var coreTools []fantasy.AgentTool
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
// Explicitly zero tools - chat-only mode
coreTools = nil
} else if len(agentConfig.CoreTools) > 0 {
// Custom tools provided - use them
coreTools = agentConfig.CoreTools
} else {
// Default: load all core tools
coreTools = core.AllTools()
}
// Build the combined tool list: core tools + any external MCP tools
// Build the initial tool list: core tools + extension tools (no MCP yet).
allTools := make([]fantasy.AgentTool, len(coreTools))
copy(allTools, coreTools)
// Load external MCP tools if configured
var toolManager *tools.MCPToolManager
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
toolManager = tools.NewMCPToolManager()
toolManager.SetModel(providerResult.Model)
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
// MCP tool loading failures are non-fatal; core tools still work
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
} else {
mcpTools := toolManager.GetTools()
allTools = append(allTools, mcpTools...)
}
}
// Append any extra tools provided by extensions.
if len(agentConfig.ExtraTools) > 0 {
allTools = append(allTools, agentConfig.ExtraTools...)
@@ -163,6 +185,144 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
}
// Build agent options
agentOpts := buildAgentOptions(agentConfig, providerResult, allTools)
// Create the agent
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Determine provider type from model string
providerType := "default"
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
providerType = p
}
}
a := &Agent{
fantasyAgent: fantasyAgent,
model: providerResult.Model,
providerCloser: providerResult.Closer,
maxSteps: agentConfig.MaxSteps,
systemPrompt: agentConfig.SystemPrompt,
loadingMessage: providerResult.Message,
providerType: providerType,
streamingEnabled: agentConfig.StreamingEnabled,
coreTools: coreTools,
extraTools: agentConfig.ExtraTools,
toolWrapper: agentConfig.ToolWrapper,
providerOptions: providerResult.ProviderOptions,
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
modelConfig: agentConfig.ModelConfig,
}
// Start MCP tool loading in the background if servers are configured.
// The mcpReady channel is closed when loading completes (success or failure).
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
toolManager := tools.NewMCPToolManager()
toolManager.SetModel(providerResult.Model)
if agentConfig.AuthHandler != nil {
toolManager.SetAuthHandler(agentConfig.AuthHandler)
}
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
// Set per-server loaded callback if provided.
if agentConfig.OnMCPServerLoaded != nil {
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
}
a.toolManager = toolManager
a.mcpReady = make(chan struct{})
go func() {
defer close(a.mcpReady)
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
a.mcpErr = err
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
}
}()
}
return a, nil
}
// WaitForMCPTools blocks until background MCP tool loading completes.
// Returns nil if no MCP servers are configured or if loading succeeded.
// Returns the loading error if all servers failed. Safe to call multiple times.
func (a *Agent) WaitForMCPTools() error {
if a.mcpReady == nil {
return nil
}
<-a.mcpReady
return a.mcpErr
}
// MCPToolsReady returns true if MCP tool loading has completed (or was never
// started). This is a non-blocking check useful for UI status display.
func (a *Agent) MCPToolsReady() bool {
if a.mcpReady == nil {
return true
}
select {
case <-a.mcpReady:
return true
default:
return false
}
}
// ensureMCPTools waits for MCP tools to load and rebuilds the fantasy agent
// with the full tool set. Called lazily before the first LLM call.
// This is idempotent — subsequent calls after the first rebuild are no-ops.
func (a *Agent) ensureMCPTools() {
if a.mcpReady == nil {
return
}
<-a.mcpReady
// If there are MCP tools, rebuild the fantasy agent to include them.
if a.toolManager != nil && len(a.toolManager.GetTools()) > 0 {
a.rebuildFantasyAgent()
}
// Nil out the channel so future calls are instant no-ops and we
// don't rebuild again.
a.mcpReady = nil
}
// rebuildFantasyAgent reconstructs the fantasy agent with the current full
// tool set (core + MCP + extension tools). Used after MCP tools arrive
// asynchronously and by SetModel.
func (a *Agent) rebuildFantasyAgent() {
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
providerResult := &models.ProviderResult{
Model: a.model,
ProviderOptions: a.providerOptions,
SkipMaxOutputTokens: a.skipMaxOutputTokens,
}
agentOpts := buildAgentOptions(&AgentConfig{
ModelConfig: a.modelConfig,
SystemPrompt: a.systemPrompt,
MaxSteps: a.maxSteps,
}, providerResult, allTools)
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
}
// buildAgentOptions constructs the fantasy.AgentOption slice from config,
// provider result, and the combined tool list. Shared by NewAgent,
// rebuildFantasyAgent, and SetModel.
func buildAgentOptions(agentConfig *AgentConfig, providerResult *models.ProviderResult, allTools []fantasy.AgentTool) []fantasy.AgentOption {
var agentOpts []fantasy.AgentOption
if agentConfig.SystemPrompt != "" {
@@ -200,33 +360,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.ModelConfig.TopK != nil {
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
}
}
// Create the agent
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Determine provider type from model string
providerType := "default"
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
providerType = p
if agentConfig.ModelConfig.FrequencyPenalty != nil {
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
}
if agentConfig.ModelConfig.PresencePenalty != nil {
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
}
}
return &Agent{
toolManager: toolManager,
fantasyAgent: fantasyAgent,
model: providerResult.Model,
providerCloser: providerResult.Closer,
maxSteps: agentConfig.MaxSteps,
systemPrompt: agentConfig.SystemPrompt,
loadingMessage: providerResult.Message,
providerType: providerType,
streamingEnabled: agentConfig.StreamingEnabled,
coreTools: coreTools,
extraTools: agentConfig.ExtraTools,
toolWrapper: agentConfig.ToolWrapper,
}, nil
return agentOpts
}
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
@@ -251,6 +393,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onStepUsage StepUsageHandler,
) (*GenerateWithLoopResult, error) {
// Wait for background MCP tool loading to complete and rebuild the
// fantasy agent with the full tool set. This is a no-op when no MCP
// servers are configured or tools have already been integrated.
a.ensureMCPTools()
// Inject tool output handler into context for use by core tools (e.g., bash).
if onToolOutput != nil {
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
@@ -451,9 +598,12 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return nil, err
}
// Fire the response callback for callers that use it (e.g. non-streaming
// callers that still want the final response notification).
if onResponse != nil && result.Response.Content.Text() != "" {
// Fire the response callback so callers (e.g. the TUI) can reset
// streaming state. This must fire even when the response text is
// empty (e.g. reasoning-only responses) so the UI properly resets
// the stream component and avoids duplicate content on the next
// flush.
if onResponse != nil {
onResponse(result.Response.Content.Text())
}
@@ -470,8 +620,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return nil, err
}
// For non-streaming, fire the response callback with the final text
if onResponse != nil && result.Response.Content.Text() != "" {
// For non-streaming, fire the response callback so callers can reset
// streaming state (see streaming path comment above).
if onResponse != nil {
onResponse(result.Response.Content.Text())
}
@@ -642,38 +793,9 @@ func (a *Agent) GetExtensionToolCount() int {
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
// tools) and rebuilds the internal agent with the updated tool list. The
// model, system prompt, and all other configuration are preserved.
func (a *Agent) SetExtraTools(tools []fantasy.AgentTool) {
a.extraTools = tools
// Rebuild tool list (same as NewAgent / SetModel).
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
// Rebuild agent options with the existing model.
var agentOpts []fantasy.AgentOption
if a.systemPrompt != "" {
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
}
if len(allTools) > 0 {
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
}
if a.maxSteps > 0 {
agentOpts = append(agentOpts, fantasy.WithStopConditions(
fantasy.StepCountIs(a.maxSteps),
))
}
// Swap the fantasy agent (model and provider are unchanged).
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
a.extraTools = extraTools
a.rebuildFantasyAgent()
}
// GetLoadingMessage returns the loading message from provider creation.
@@ -693,60 +815,14 @@ func (a *Agent) GetLoadedServerNames() []string {
// system prompt, and configuration are preserved. The old provider is closed
// if it has a closer. Returns the previous model string for notification.
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
// before the first LLM call).
a.ensureMCPTools()
providerResult, err := models.CreateProvider(ctx, config)
if err != nil {
return fmt.Errorf("failed to create model provider: %v", err)
}
// Rebuild tool list (same as NewAgent).
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
// Rebuild agent options.
var agentOpts []fantasy.AgentOption
if a.systemPrompt != "" {
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
}
if len(allTools) > 0 {
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
}
if a.maxSteps > 0 {
agentOpts = append(agentOpts, fantasy.WithStopConditions(
fantasy.StepCountIs(a.maxSteps),
))
}
// Pass provider-specific options (e.g. OpenAI Responses API reasoning settings).
if providerResult.ProviderOptions != nil {
agentOpts = append(agentOpts, fantasy.WithProviderOptions(providerResult.ProviderOptions))
}
// Pass generation parameters when available.
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
}
if config.Temperature != nil {
agentOpts = append(agentOpts, fantasy.WithTemperature(float64(*config.Temperature)))
}
if config.TopP != nil {
agentOpts = append(agentOpts, fantasy.WithTopP(float64(*config.TopP)))
}
if config.TopK != nil {
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
}
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Close old provider.
if a.providerCloser != nil {
_ = a.providerCloser.Close()
@@ -758,9 +834,11 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
}
// Swap fields.
a.fantasyAgent = newFantasyAgent
a.model = providerResult.Model
a.providerCloser = providerResult.Closer
a.providerOptions = providerResult.ProviderOptions
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
a.modelConfig = config
// Update provider type.
if config.ModelString != "" {
@@ -769,6 +847,9 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
}
}
// Rebuild the fantasy agent with the new model and current tool set.
a.rebuildFantasyAgent()
return nil
}
@@ -778,7 +859,13 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
}
// Close closes the agent and cleans up resources.
// If MCP tools are still loading in the background, Close waits for them
// to finish before closing connections to avoid resource leaks.
func (a *Agent) Close() error {
// Wait for background MCP loading to finish before closing connections.
if a.mcpReady != nil {
<-a.mcpReady
}
var toolErr error
if a.toolManager != nil {
toolErr = a.toolManager.Close()
+21 -9
View File
@@ -36,13 +36,22 @@ type AgentCreationOptions struct {
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
// DebugLogger is an optional logger for debugging MCP communications
DebugLogger tools.DebugLogger // Optional debug logger
// AuthHandler handles OAuth authorization for remote MCP servers
AuthHandler tools.MCPAuthHandler
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used.
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper wraps the combined tool list before agent creation.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include (e.g. from extensions).
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// CreateAgent creates an agent with optional spinner for Ollama models.
@@ -50,15 +59,18 @@ type AgentCreationOptions struct {
// Returns the created agent or an error if creation fails.
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
agentConfig := &AgentConfig{
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
CoreTools: opts.CoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
var agent *Agent
+41
View File
@@ -997,6 +997,47 @@ func (a *App) NotifyWidgetUpdate() {
}
}
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
// prompt templates and skills from their provider callbacks. Called by file
// watchers when .md/.txt files change in prompt or skill directories.
// In non-interactive mode this is a no-op.
func (a *App) NotifyContentReload() {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(ContentReloadEvent{})
}
}
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
// tool names and MCP tool count from provider callbacks. Called when background
// MCP tool loading completes. In non-interactive mode this is a no-op.
func (a *App) NotifyMCPToolsReady() {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPToolsReadyEvent{})
}
}
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
// display a system message when a single MCP server finishes loading. Called
// per server as background MCP tool loading progresses.
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPServerLoadedEvent{
ServerName: serverName,
ToolCount: toolCount,
Error: err,
})
}
}
// SendEvent sends a tea.Msg to the registered program. Safe to call from
// any goroutine. No-op when no program is registered.
//
+19
View File
@@ -167,6 +167,25 @@ type ModelChangedEvent struct {
// from its WidgetProvider on the next render cycle.
type WidgetUpdateEvent struct{}
// ContentReloadEvent is sent when prompt templates or skills are reloaded
// from disk (e.g. by a file watcher detecting changes). The TUI refreshes
// its autocomplete entries and internal state from the provider callbacks.
type ContentReloadEvent struct{}
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
// The TUI refreshes its tool names and MCP tool count from provider callbacks
// so that /tools and the startup info bar reflect the loaded MCP tools.
type MCPToolsReadyEvent struct{}
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
// (successfully or with error). The TUI displays a system message so users
// see real-time progress as each server initializes.
type MCPServerLoadedEvent struct {
ServerName string
ToolCount int
Error error // nil on success
}
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
// pre-fill the input editor with text. The TUI handles this by setting the
// textarea content and moving the cursor to the end.
+9 -5
View File
@@ -199,11 +199,13 @@ type Config struct {
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
Theme any `json:"theme" yaml:"theme"`
// Model generation parameters
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
// Thinking / extended reasoning
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
@@ -370,6 +372,8 @@ mcpServers:
# temperature: 0.7 # Randomness (0.0-1.0)
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
# top-k: 40 # Top K sampling
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
# API Configuration (can also use environment variables)
+25 -26
View File
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
// Install clones a git repository to the appropriate scope.
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
return i.install(source, scope, nil)
}
// install is the internal implementation that supports optional include paths.
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
targetDir := i.getInstallPath(source, scope)
// Check if already installed
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
Pinned: source.Pinned,
Scope: scope,
Installed: time.Now(),
Include: includePaths,
}
if err := i.addToManifest(entry, scope); err != nil {
// Don't fail the install, just log the error
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
cleanCmd.Dir = targetDir
_ = cleanCmd.Run() // Ignore errors - clean is best effort
// Update manifest timestamp
// Update manifest timestamp, preserving existing fields like Include
existing, _ := i.loadManifest(scope)
var include []string
var installed time.Time
if existing != nil {
for _, p := range existing.Packages {
if p.Host+"/"+p.Path == source.Identity() {
include = p.Include
installed = p.Installed
break
}
}
}
if installed.IsZero() {
installed = time.Now()
}
entry := ManifestEntry{
Source: source.String(),
Repo: source.Repo,
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
Ref: "",
Pinned: false,
Scope: scope,
Installed: time.Now(),
Installed: installed,
Updated: time.Now(),
Include: include,
}
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
// InstallWithInclude clones a repo and installs only the specified extensions.
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
// First, do a regular install
if err := i.Install(source, scope); err != nil {
return err
}
// If specific includes were requested, update the manifest
if len(includePaths) > 0 {
entry := ManifestEntry{
Source: source.String(),
Repo: source.Repo,
Host: source.Host,
Path: source.Path,
Ref: source.Ref,
Pinned: source.Pinned,
Scope: scope,
Include: includePaths,
}
if err := addEntryToManifest(entry, scope); err != nil {
return fmt.Errorf("updating manifest with includes: %w", err)
}
}
return nil
return i.install(source, scope, includePaths)
}
// CleanupTempDir removes a temporary directory used for preview.
+11 -18
View File
@@ -2,11 +2,11 @@ package extensions
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
"github.com/traefik/yaegi/interp"
"github.com/traefik/yaegi/stdlib"
"github.com/traefik/yaegi/stdlib/unrestricted"
@@ -34,15 +34,11 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
for _, p := range paths {
ext, err := loadSingleExtension(p)
if err != nil {
log.Warn("skipping extension", "path", p, "err", err)
log.Printf("WARN skipping extension: path=%s err=%v", p, err)
continue
}
loaded = append(loaded, *ext)
log.Debug("loaded extension", "path", p,
"handlers", countHandlers(ext),
"tools", len(ext.Tools),
"commands", len(ext.Commands),
"tool_renderers", len(ext.ToolRenderers))
log.Printf("DEBUG loaded extension: path=%s handlers=%d tools=%d commands=%d tool_renderers=%d", p, countHandlers(ext), len(ext.Tools), len(ext.Commands), len(ext.ToolRenderers))
}
return loaded, nil
}
@@ -133,7 +129,7 @@ func findExtensionsInDir(dir string) []string {
for _, entry := range entries {
full := filepath.Join(dir, entry.Name())
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
results = append(results, full)
} else if entry.IsDir() {
main := filepath.Join(full, "main.go")
@@ -190,9 +186,13 @@ func findExtensionsInRepo(repoPath string) []string {
isExtDir := base == "extensions" || base == "ext" ||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
// Allow walking into examples/ so we can reach examples/extensions/ etc,
// but don't treat examples/ itself or non-extension subdirs as extension locations.
if relPath == "examples" {
return nil
}
if !isExtDir && !isExamplesSubdir {
if !isExtDir {
mainPath := filepath.Join(path, "main.go")
if _, err := os.Stat(mainPath); err == nil {
if relPath == base { // Top-level directory
@@ -202,13 +202,6 @@ func findExtensionsInRepo(repoPath string) []string {
}
return filepath.SkipDir
}
if isExamplesSubdir || isExtDir {
if !multiFileDirs[relPath] {
multiFileDirs[relPath] = true
results = append(results, mainPath)
}
return filepath.SkipDir
}
}
return filepath.SkipDir
}
@@ -227,7 +220,7 @@ func findExtensionsInRepo(repoPath string) []string {
}
// It's a file
if !strings.HasSuffix(info.Name(), ".go") {
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
return nil
}
+7 -16
View File
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
isExtDir := base == "extensions" || base == "ext" ||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
// Or check if it's a subdirectory of examples/ that might contain extensions
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
// Allow walking into examples/ so we can reach examples/extensions/ etc,
// but don't treat examples/ itself or non-extension subdirs as extension locations.
if relPath == "examples" {
return nil
}
if !isExtDir && !isExamplesSubdir {
if !isExtDir {
// Check for main.go before skipping
mainPath := filepath.Join(path, "main.go")
if _, err := os.Stat(mainPath); err == nil {
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
}
return filepath.SkipDir
}
// Inside a valid extensions directory
if isExamplesSubdir || isExtDir {
if !multiFileDirs[relPath] {
multiFileDirs[relPath] = true
previews = append(previews, ExtensionPreview{
Path: "./" + relPath + "/main.go",
Name: deriveExtensionName(relPath+"/main.go", true),
IsMain: true,
})
}
return filepath.SkipDir
}
}
// Not an extension location
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
}
// It's a file - check if it's a valid extension
if !strings.HasSuffix(info.Name(), ".go") {
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
return nil
}
+3 -8
View File
@@ -2,12 +2,12 @@ package extensions
import (
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
"github.com/charmbracelet/log"
"github.com/spf13/viper"
)
@@ -370,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
log.Warn("extension handler error",
"path", ext.Path,
"event", event.Type(),
"err", err)
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
continue
}
if result == nil {
@@ -707,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
safeInvoke := func(h func(string)) {
defer func() {
if rec := recover(); rec != nil {
log.Warn("custom event handler panicked",
"event", name,
"err", fmt.Sprintf("%v", rec))
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
}
}()
h(data)
+6 -6
View File
@@ -3,13 +3,13 @@ package extensions
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/fsnotify/fsnotify"
)
@@ -39,7 +39,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
for _, dir := range dirs {
// Watch the directory itself.
if err := fsw.Add(dir); err != nil {
log.Debug("watcher: skipping directory", "dir", dir, "err", err)
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
continue
}
@@ -52,7 +52,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Debug("watcher: skipping subdirectory", "dir", subdir, "err", err)
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
}
}
}
@@ -101,7 +101,7 @@ func (w *Watcher) Start(ctx context.Context) {
continue
}
log.Debug("watcher: file changed", "file", event.Name, "op", event.Op)
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
// Debounce: reset timer on each event.
if timer != nil {
@@ -113,14 +113,14 @@ func (w *Watcher) Start(ctx context.Context) {
case <-timerC:
timerC = nil
timer = nil
log.Debug("watcher: reloading extensions")
log.Printf("DEBUG watcher: reloading extensions")
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Warn("watcher: error", "err", err)
log.Printf("WARN watcher: error: %v", err)
}
}
}
+42 -25
View File
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ExtraTools are additional tools added alongside core, MCP, and extension
// tools. They do not replace the defaults — they extend them.
ExtraTools []fantasy.AgentTool
@@ -58,6 +62,12 @@ type AgentSetupOptions struct {
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
// is also set.
StreamingEnabled bool
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
AuthHandler tools.MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// AgentSetupResult bundles the created agent and any debug logger so the caller
@@ -81,23 +91,27 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
temperature := float32(viper.GetFloat64("temperature"))
topP := float32(viper.GetFloat64("top-p"))
topK := int32(viper.GetInt("top-k"))
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
numGPU := int32(viper.GetInt("num-gpu-layers"))
mainGPU := int32(viper.GetInt("main-gpu"))
cfg := &models.ProviderConfig{
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &frequencyPenalty,
PresencePenalty: &presencePenalty,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
}
return cfg, systemPrompt, nil
@@ -176,18 +190,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
}
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: maxSteps,
StreamingEnabled: streamingEnabled,
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
CoreTools: opts.CoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: maxSteps,
StreamingEnabled: streamingEnabled,
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)
File diff suppressed because one or more lines are too long
+22 -14
View File
@@ -143,20 +143,22 @@ func ParseThinkingLevel(s string) ThinkingLevel {
// ProviderConfig holds configuration for creating LLM providers.
type ProviderConfig struct {
ModelString string
SystemPrompt string
ProviderAPIKey string
ProviderURL string
MaxTokens int
Temperature *float32
TopP *float32
TopK *int32
StopSequences []string
NumGPU *int32
MainGPU *int32
TLSSkipVerify bool
ThinkingLevel ThinkingLevel
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
ModelString string
SystemPrompt string
ProviderAPIKey string
ProviderURL string
MaxTokens int
Temperature *float32
TopP *float32
TopK *int32
FrequencyPenalty *float32
PresencePenalty *float32
StopSequences []string
NumGPU *int32
MainGPU *int32
TLSSkipVerify bool
ThinkingLevel ThinkingLevel
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
}
// ProviderResult contains the result of provider creation.
@@ -1164,6 +1166,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
if config.TopK != nil {
options["top_k"] = int(*config.TopK)
}
if config.FrequencyPenalty != nil {
options["frequency_penalty"] = *config.FrequencyPenalty
}
if config.PresencePenalty != nil {
options["presence_penalty"] = *config.PresencePenalty
}
if len(config.StopSequences) > 0 {
options["stop"] = config.StopSequences
}
+2 -6
View File
@@ -2,11 +2,10 @@ package prompts
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
)
// LoadOptions configures how templates are discovered and loaded.
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
DroppedPath: tpl.FilePath,
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
})
log.Debug("template collision",
"name", tpl.Name,
"dropped", tpl.FilePath,
"kept", existing.FilePath)
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
} else {
tpl.Source = source
seen[tpl.Name] = tpl
+33
View File
@@ -24,6 +24,7 @@ const (
EntryTypeSessionInfo EntryType = "session_info"
EntryTypeExtensionData EntryType = "extension_data"
EntryTypeCompaction EntryType = "compaction"
EntryTypeSystemPrompt EntryType = "system_prompt"
)
// CurrentVersion is the session format version for JSONL tree sessions.
@@ -117,6 +118,19 @@ type CompactionEntry struct {
ModifiedFiles []string `json:"modified_files,omitempty"`
}
// SystemPromptEntry records the system prompt and model used for the session.
// This is primarily for sharing/debugging to see what instructions were
// active during the conversation. It does NOT participate in the tree
// structure (no ParentID) and is not used when building LLM context.
type SystemPromptEntry struct {
Type EntryType `json:"type"` // always "system_prompt"
ID string `json:"id"` // unique entry ID
Timestamp time.Time `json:"timestamp"` // when captured
Content string `json:"content"` // the system prompt text
Model string `json:"model"` // the model used (e.g., "claude-sonnet-4-5")
Provider string `json:"provider"` // the provider used (e.g., "anthropic")
}
// GenerateEntryID creates a unique entry identifier (16 hex chars).
func GenerateEntryID() string {
bytes := make([]byte, 8)
@@ -217,6 +231,18 @@ func NewCompactionEntry(parentID, summary, firstKeptEntryID string, tokensBefore
}
}
// NewSystemPromptEntry creates a SystemPromptEntry.
func NewSystemPromptEntry(content, model, provider string) *SystemPromptEntry {
return &SystemPromptEntry{
Type: EntryTypeSystemPrompt,
ID: GenerateEntryID(),
Timestamp: time.Now(),
Content: content,
Model: model,
Provider: provider,
}
}
// --- JSONL marshaling helpers ---
// MarshalEntry serializes any entry to a JSON line (no trailing newline).
@@ -295,6 +321,13 @@ func UnmarshalEntry(data []byte) (any, error) {
}
return &e, nil
case EntryTypeSystemPrompt:
var e SystemPromptEntry
if err := json.Unmarshal(data, &e); err != nil {
return nil, fmt.Errorf("failed to unmarshal system_prompt entry: %w", err)
}
return &e, nil
default:
return nil, fmt.Errorf("unknown entry type: %q", env.Type)
}
+113
View File
@@ -0,0 +1,113 @@
package session
import (
"encoding/json"
"testing"
)
func TestSystemPromptEntry(t *testing.T) {
// Test creation
content := "You are a helpful coding assistant."
model := "claude-sonnet-4-5"
provider := "anthropic"
entry := NewSystemPromptEntry(content, model, provider)
if entry.Type != EntryTypeSystemPrompt {
t.Errorf("Expected type %q, got %q", EntryTypeSystemPrompt, entry.Type)
}
if entry.Content != content {
t.Errorf("Expected content %q, got %q", content, entry.Content)
}
if entry.Model != model {
t.Errorf("Expected model %q, got %q", model, entry.Model)
}
if entry.Provider != provider {
t.Errorf("Expected provider %q, got %q", provider, entry.Provider)
}
if entry.ID == "" {
t.Error("Expected non-empty ID")
}
// Test marshaling
data, err := MarshalEntry(entry)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Test unmarshaling
unmarshaled, err := UnmarshalEntry(data)
if err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
sysPrompt, ok := unmarshaled.(*SystemPromptEntry)
if !ok {
t.Fatalf("Expected *SystemPromptEntry, got %T", unmarshaled)
}
if sysPrompt.Type != EntryTypeSystemPrompt {
t.Errorf("Unmarshaled: expected type %q, got %q", EntryTypeSystemPrompt, sysPrompt.Type)
}
if sysPrompt.Content != content {
t.Errorf("Unmarshaled: expected content %q, got %q", content, sysPrompt.Content)
}
if sysPrompt.Model != model {
t.Errorf("Unmarshaled: expected model %q, got %q", model, sysPrompt.Model)
}
if sysPrompt.Provider != provider {
t.Errorf("Unmarshaled: expected provider %q, got %q", provider, sysPrompt.Provider)
}
if sysPrompt.ID != entry.ID {
t.Errorf("Unmarshaled: expected ID %q, got %q", entry.ID, sysPrompt.ID)
}
}
func TestSystemPromptEntryJSONStructure(t *testing.T) {
content := "Test system prompt content"
model := "gpt-4o"
provider := "openai"
entry := NewSystemPromptEntry(content, model, provider)
data, err := MarshalEntry(entry)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("Failed to unmarshal to raw map: %v", err)
}
if raw["type"] != "system_prompt" {
t.Errorf("Expected type 'system_prompt', got %v", raw["type"])
}
if raw["content"] != content {
t.Errorf("Expected content %q, got %v", content, raw["content"])
}
if raw["model"] != model {
t.Errorf("Expected model %q, got %v", model, raw["model"])
}
if raw["provider"] != provider {
t.Errorf("Expected provider %q, got %v", provider, raw["provider"])
}
if raw["id"] == "" || raw["id"] == nil {
t.Error("Expected non-empty id field")
}
if raw["timestamp"] == "" || raw["timestamp"] == nil {
t.Error("Expected non-empty timestamp field")
}
}
+84 -10
View File
@@ -68,6 +68,7 @@ type MCPConnectionPool struct {
cancel context.CancelFunc
debug bool
debugLogger DebugLogger
oauthFlow *OAuthFlowRunner
}
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
@@ -75,7 +76,7 @@ type MCPConnectionPool struct {
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
}
@@ -90,6 +91,10 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
debug: debug,
}
if authHandler != nil {
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
}
go pool.startHealthCheck()
return pool
}
@@ -103,6 +108,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
p.debugLogger = logger
}
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
// When set, the pool can trigger OAuth re-authorization when a tool call fails
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
p.mu.Lock()
defer p.mu.Unlock()
p.oauthFlow = flow
}
// GetConnection retrieves or creates a connection for the specified MCP server.
// If a healthy, non-idle connection exists in the pool, it will be reused.
// Otherwise, a new connection is created and added to the pool.
@@ -230,18 +244,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
// createConnection creates a new connection
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
client, err := p.createMCPClient(ctx, serverName, serverConfig)
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return nil, err
// SSE transport can return OAuth error during Start()
if p.oauthFlow != nil && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
// Retry after successful auth
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
if err := p.initializeClient(ctx, client); err != nil {
_ = client.Close()
return nil, err
if err := p.initializeClient(ctx, mcpClient); err != nil {
// Streamable HTTP transport returns OAuth error during Initialize()
if p.oauthFlow != nil && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
_ = mcpClient.Close()
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
// Retry initialization after successful auth
if err := p.initializeClient(ctx, mcpClient); err != nil {
_ = mcpClient.Close()
return nil, err
}
} else {
_ = mcpClient.Close()
return nil, err
}
}
conn := &MCPConnection{
client: client,
client: mcpClient,
serverName: serverName,
serverConfig: serverConfig,
lastUsed: time.Now(),
@@ -323,13 +362,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithOAuth(transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
if err != nil {
return nil, err
}
if err := sseClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start SSE client: %v", err)
return nil, fmt.Errorf("failed to start SSE client: %w", err)
}
return sseClient, nil
@@ -354,13 +409,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
if err != nil {
return nil, err
}
if err := streamableClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
}
return streamableClient, nil
@@ -381,7 +452,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
_, err := client.Initialize(initCtx, initRequest)
if err != nil {
return fmt.Errorf("initialization timeout or failed: %v", err)
return fmt.Errorf("initialization timeout or failed: %w", err)
}
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
@@ -539,6 +610,9 @@ func (p *MCPConnectionPool) Close() error {
// isConnectionError checks if the error is connection-related
func isConnectionError(err error) bool {
if IsOAuthError(err) {
return false // OAuth errors are recoverable, not connection failures
}
errStr := err.Error()
return strings.Contains(errStr, "Connection not found") ||
strings.Contains(errStr, "transport error") ||
+24 -3
View File
@@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas
},
})
if err != nil {
// Mark connection as unhealthy for automatic recovery
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
// Handle OAuth re-authorization: token may have expired mid-session.
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
}
// Retry the tool call after successful re-auth.
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{
Method: "tools/call",
},
Params: mcp.CallToolParams{
Name: t.mapping.originalName,
Arguments: arguments,
},
})
if err != nil {
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
}
} else {
// Mark connection as unhealthy for automatic recovery
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
}
}
// Marshal the MCP result to JSON string
+86 -21
View File
@@ -4,8 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"slices"
"strings"
"sync"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
@@ -21,10 +23,16 @@ type MCPToolManager struct {
connectionPool *MCPConnectionPool
tools []fantasy.AgentTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
mu sync.Mutex // protects tools and toolMap during parallel loading
model fantasy.LanguageModel // LLM model for sampling
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
config *config.Config
debug bool
debugLogger DebugLogger
// onServerLoaded, if non-nil, is called when each server finishes loading.
// Called with server name, tool count, and error (nil on success).
onServerLoaded func(serverName string, toolCount int, err error)
}
// toolMapping stores the mapping between prefixed tool names and their original details
@@ -53,6 +61,14 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
m.model = model
}
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
// support, enabling automatic authorization flows when servers require authentication.
// This method should be called before LoadTools.
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
m.authHandler = handler
}
// SetDebugLogger sets the debug logger for the tool manager.
// The logger will be used to output detailed debugging information about MCP connections,
// tool loading, and execution. If a connection pool exists, it will also be configured
@@ -64,48 +80,87 @@ func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
}
}
// SetOnServerLoaded sets the callback that's invoked when each MCP server finishes
// loading. The callback receives the server name, tool count, and any error.
// Call this before LoadTools to receive per-server notifications.
func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount int, err error)) {
m.onServerLoaded = cb
}
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
// This method is thread-safe and idempotent.
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) error {
// Initialize connection pool
m.config = config
m.debug = config.Debug
m.config = cfg
m.debug = cfg.Debug
if m.debugLogger == nil {
m.debugLogger = NewSimpleDebugLogger(config.Debug)
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
}
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
m.connectionPool.SetDebugLogger(m.debugLogger)
var loadErrors []string
// Load all servers in parallel. Each server connection (subprocess
// spawn, MCP initialize handshake, ListTools) is independent and
// typically dominated by process startup latency. Running them
// concurrently reduces total wall-clock time from O(n * avg) to
// O(max).
type serverResult struct {
name string
err error
}
for serverName, serverConfig := range config.MCPServers {
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
continue
results := make(chan serverResult, len(cfg.MCPServers))
var wg sync.WaitGroup
for serverName, serverConfig := range cfg.MCPServers {
wg.Add(1)
go func(name string, sc config.MCPServerConfig) {
defer wg.Done()
count, err := m.loadServerTools(ctx, name, sc)
results <- serverResult{name: name, err: err}
// Notify callback if set (for real-time UI updates).
if m.onServerLoaded != nil {
m.onServerLoaded(name, count, err)
}
}(serverName, serverConfig)
}
// Close results channel once all goroutines finish.
go func() {
wg.Wait()
close(results)
}()
var loadErrors []string
for r := range results {
if r.err != nil {
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", r.name, r.err))
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", r.name, r.err)
}
}
// If all servers failed to load, return an error
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
if len(loadErrors) == len(cfg.MCPServers) && len(cfg.MCPServers) > 0 {
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
}
return nil
}
// loadServerTools loads tools from a single MCP server
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
// loadServerTools loads tools from a single MCP server.
// Thread-safe: may be called concurrently for different servers.
// Returns the number of tools loaded from this server, or -1 on error.
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (int, error) {
// Add debug logging
m.debugLogConnectionInfo(serverName, serverConfig)
// Get connection from pool
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
if err != nil {
return fmt.Errorf("failed to get connection from pool: %v", err)
return -1, fmt.Errorf("failed to get connection from pool: %v", err)
}
// Get tools from this server
@@ -113,7 +168,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
if err != nil {
// Handle connection error
m.connectionPool.HandleConnectionError(serverName, err)
return fmt.Errorf("failed to list tools: %v", err)
return -1, fmt.Errorf("failed to list tools: %v", err)
}
// Create name set for allowed tools
@@ -125,6 +180,10 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
}
}
// Build tools locally before acquiring the lock.
var localTools []fantasy.AgentTool
localMap := make(map[string]*toolMapping)
// Convert MCP tools to fantasy AgentTools with prefixed names
for _, mcpTool := range listResults.Tools {
// Filter tools based on allowedTools/excludedTools
@@ -142,7 +201,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Convert MCP InputSchema to map[string]any for fantasy ToolInfo
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
// Fix for JSON Schema draft-07 vs draft-04 compatibility
@@ -151,7 +210,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Parse into map[string]any for fantasy's parameters format
var schemaMap map[string]any
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Extract properties and required from the schema
@@ -184,7 +243,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
serverConfig: serverConfig,
manager: m,
}
m.toolMap[prefixedName] = mapping
localMap[prefixedName] = mapping
// Create fantasy AgentTool
fantasyTool := &mcpFantasyTool{
@@ -197,10 +256,16 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
mapping: mapping,
}
m.tools = append(m.tools, fantasyTool)
localTools = append(localTools, fantasyTool)
}
return nil
// Merge into the manager under the lock.
m.mu.Lock()
maps.Copy(m.toolMap, localMap)
m.tools = append(m.tools, localTools...)
m.mu.Unlock()
return len(localTools), nil
}
// GetTools returns all loaded tools as fantasy AgentTools from all configured MCP servers.
+109
View File
@@ -0,0 +1,109 @@
package tools
import (
"context"
"fmt"
"net/url"
"github.com/mark3labs/mcp-go/client"
)
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go
// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK.
type MCPAuthHandler interface {
// RedirectURI returns the OAuth redirect URI for transport setup.
RedirectURI() string
// HandleAuth is called when a server requires OAuth authorization.
// It receives the server name and the authorization URL the user must visit.
// It returns the full callback URL (containing code and state query params)
// after the user completes authorization.
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
// registration, PKCE generation, user authorization (via MCPAuthHandler),
// and token exchange.
type OAuthFlowRunner struct {
handler MCPAuthHandler
}
// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler.
func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner {
return &OAuthFlowRunner{handler: handler}
}
// RunAuthFlow executes the OAuth authorization flow for the given server.
// It extracts the OAuthHandler from the error, performs dynamic client registration
// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user
// interaction, and exchanges the authorization code for a token.
func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error {
// Extract the OAuthHandler from the authorization-required error.
oauthHandler := client.GetOAuthHandler(authErr)
if oauthHandler == nil {
return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr)
}
// Perform dynamic client registration if no client ID is configured yet.
if oauthHandler.GetClientID() == "" {
if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil {
return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err)
}
}
// Generate PKCE code verifier and challenge.
codeVerifier, err := client.GenerateCodeVerifier()
if err != nil {
return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err)
}
codeChallenge := client.GenerateCodeChallenge(codeVerifier)
// Generate a random state parameter for CSRF protection.
state, err := client.GenerateState()
if err != nil {
return fmt.Errorf("oauth flow: failed to generate state: %w", err)
}
// Build the authorization URL the user needs to visit.
authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge)
if err != nil {
return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err)
}
// Delegate to the MCPAuthHandler for user-facing authorization (e.g. open
// browser, wait for redirect). It returns the full callback URL containing
// the authorization code and state.
callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL)
if err != nil {
return fmt.Errorf("oauth flow: user authorization failed: %w", err)
}
// Parse the callback URL to extract the authorization code and state.
parsed, err := url.Parse(callbackURL)
if err != nil {
return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err)
}
code := parsed.Query().Get("code")
returnedState := parsed.Query().Get("state")
if code == "" {
return fmt.Errorf("oauth flow: callback URL missing 'code' parameter")
}
if returnedState == "" {
return fmt.Errorf("oauth flow: callback URL missing 'state' parameter")
}
// Exchange the authorization code for an access token.
if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil {
return fmt.Errorf("oauth flow: token exchange failed: %w", err)
}
return nil
}
// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError.
func IsOAuthError(err error) bool {
return client.IsOAuthAuthorizationRequiredError(err)
}
+155
View File
@@ -0,0 +1,155 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/mark3labs/mcp-go/client/transport"
)
// Compile-time check that FileTokenStore implements transport.TokenStore.
var _ transport.TokenStore = (*FileTokenStore)(nil)
// FileTokenStore is a file-backed implementation of transport.TokenStore that
// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file
// keyed by server URL, allowing multiple MCP servers to maintain independent tokens.
//
// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back
// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set.
//
// FileTokenStore is safe for concurrent use.
type FileTokenStore struct {
serverKey string
filePath string
mu sync.RWMutex
}
// NewFileTokenStore creates a new FileTokenStore for the given server URL.
// The serverKey is used as the map key in the shared token file, and should
// typically be the MCP server's base URL.
//
// Returns an error if the token file path cannot be resolved.
func NewFileTokenStore(serverKey string) (*FileTokenStore, error) {
filePath, err := resolveTokenFilePath()
if err != nil {
return nil, fmt.Errorf("resolving token file path: %w", err)
}
return &FileTokenStore{
serverKey: serverKey,
filePath: filePath,
}, nil
}
// GetToken returns the stored token for this store's server key.
// Returns transport.ErrNoToken if no token exists for the server key or if
// the token file does not yet exist.
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
s.mu.RLock()
defer s.mu.RUnlock()
tokens, err := readTokenFile(s.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, transport.ErrNoToken
}
return nil, fmt.Errorf("reading token file: %w", err)
}
token, ok := tokens[s.serverKey]
if !ok {
return nil, transport.ErrNoToken
}
return token, nil
}
// SaveToken persists the given token for this store's server key.
// If the token file or its parent directories do not exist, they are created.
// Existing tokens for other server keys are preserved.
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error {
if err := ctx.Err(); err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
tokens, err := readTokenFile(s.filePath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("reading token file: %w", err)
}
if tokens == nil {
tokens = make(map[string]*transport.Token)
}
tokens[s.serverKey] = token
if err := writeTokenFile(s.filePath, tokens); err != nil {
return fmt.Errorf("writing token file: %w", err)
}
return nil
}
// resolveTokenFilePath determines the path to the token file using
// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/.
func resolveTokenFilePath() (string, error) {
configDir := os.Getenv("XDG_CONFIG_HOME")
if configDir == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("determining user home directory: %w", err)
}
configDir = filepath.Join(home, ".config")
}
return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil
}
// readTokenFile reads and unmarshals the token file into a server-keyed map.
// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist.
func readTokenFile(path string) (map[string]*transport.Token, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var tokens map[string]*transport.Token
if err := json.Unmarshal(data, &tokens); err != nil {
return nil, fmt.Errorf("unmarshaling token file: %w", err)
}
return tokens, nil
}
// writeTokenFile marshals the token map and writes it to disk, creating
// parent directories as needed. The file is written with 0600 permissions
// to protect sensitive token data.
func writeTokenFile(path string, tokens map[string]*transport.Token) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("creating token directory %s: %w", dir, err)
}
data, err := json.MarshalIndent(tokens, "", " ")
if err != nil {
return fmt.Errorf("marshaling tokens: %w", err)
}
if err := os.WriteFile(path, data, 0600); err != nil {
return fmt.Errorf("writing token file %s: %w", path, err)
}
return nil
}
+229 -25
View File
@@ -12,6 +12,7 @@ import (
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/spf13/viper"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/core"
@@ -280,6 +281,16 @@ type AppModelOptions struct {
// ToolNames holds available tool names for the /tools command.
ToolNames []string
// GetToolNames, if non-nil, returns the current tool names. Called on
// MCPToolsReadyEvent to refresh the tool list after background MCP tool
// loading completes. May be nil if dynamic tool refresh is not needed.
GetToolNames func() []string
// GetMCPToolCount, if non-nil, returns the current MCP tool count.
// Called on MCPToolsReadyEvent to refresh the startup info bar.
// May be nil if dynamic tool refresh is not needed.
GetMCPToolCount func() int
// UsageTracker provides token usage statistics for /usage and /reset-usage.
// May be nil if usage tracking is unavailable for the current model.
UsageTracker *UsageTracker
@@ -293,6 +304,11 @@ type AppModelOptions struct {
// and are expanded when submitted (e.g., /review → full prompt text).
PromptTemplates []*prompts.PromptTemplate
// GetPromptTemplates, if non-nil, returns the current prompt templates.
// Called on ContentReloadEvent to refresh the template list after a file
// watcher detects changes. May be nil if prompt hot-reload is not needed.
GetPromptTemplates func() []*prompts.PromptTemplate
// ContextPaths lists absolute paths of loaded context files (e.g.
// AGENTS.md). Displayed in the [Context] startup section.
ContextPaths []string
@@ -300,6 +316,11 @@ type AppModelOptions struct {
// SkillItems lists loaded skills for the [Skills] startup section.
SkillItems []SkillItem
// GetSkillItems, if non-nil, returns the current skill items.
// Called on ContentReloadEvent to refresh the skill list after a file
// watcher detects changes. May be nil if skill hot-reload is not needed.
GetSkillItems func() []SkillItem
// MCPToolCount is the number of tools loaded from external MCP servers.
MCPToolCount int
@@ -484,8 +505,12 @@ type AppModel struct {
loadingMessage string
// serverNames, toolNames are used by /servers and /tools commands.
serverNames []string
toolNames []string
serverNames []string
toolNames []string
getToolNames func() []string // dynamic tool name provider (for MCP refresh)
// getMCPToolCount returns the current MCP tool count dynamically.
getMCPToolCount func() int
// usageTracker provides token usage stats for /usage and /reset-usage.
// May be nil when usage tracking is unavailable.
@@ -499,6 +524,10 @@ type AppModel struct {
// They appear in autocomplete and are expanded when submitted.
promptTemplates []*prompts.PromptTemplate
// getPromptTemplates returns the current prompt templates. Used to
// refresh the template list after content hot-reload. May be nil.
getPromptTemplates func() []*prompts.PromptTemplate
// treeSelector is the tree navigation overlay, active in stateTreeSelector.
treeSelector *TreeSelectorComponent
@@ -507,6 +536,10 @@ type AppModel struct {
contextPaths []string
skillItems []SkillItem
// getSkillItems returns the current skill items. Used to refresh the
// skill list after content hot-reload. May be nil.
getSkillItems func() []SkillItem
// mcpToolCount and extensionToolCount track tool counts by source for
// the startup info display.
mcpToolCount int
@@ -703,23 +736,26 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
rdr := mr
m := &AppModel{
state: stateInput,
appCtrl: appCtrl,
renderer: rdr,
modelName: opts.ModelName,
providerName: opts.ProviderName,
loadingMessage: opts.LoadingMessage,
serverNames: opts.ServerNames,
toolNames: opts.ToolNames,
usageTracker: opts.UsageTracker,
cwd: opts.Cwd,
width: width,
height: height,
state: stateInput,
appCtrl: appCtrl,
renderer: rdr,
modelName: opts.ModelName,
providerName: opts.ProviderName,
loadingMessage: opts.LoadingMessage,
serverNames: opts.ServerNames,
toolNames: opts.ToolNames,
getToolNames: opts.GetToolNames,
getMCPToolCount: opts.GetMCPToolCount,
usageTracker: opts.UsageTracker,
cwd: opts.Cwd,
width: width,
height: height,
}
// Store extension commands for dispatch.
m.extensionCommands = opts.ExtensionCommands
m.promptTemplates = opts.PromptTemplates
m.getPromptTemplates = opts.GetPromptTemplates
m.getWidgets = opts.GetWidgets
m.getHeader = opts.GetHeader
m.getFooter = opts.GetFooter
@@ -745,6 +781,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
// Store context/skills metadata and tool counts for startup display.
m.contextPaths = opts.ContextPaths
m.skillItems = opts.SkillItems
m.getSkillItems = opts.GetSkillItems
m.mcpToolCount = opts.MCPToolCount
m.extensionToolCount = opts.ExtensionToolCount
m.startupExtensionMessages = opts.StartupExtensionMessages
@@ -1699,6 +1736,13 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.stream, _ = updated.(streamComponentIface)
cmds = append(cmds, cmd)
}
// Mark any trailing StreamingMessageItem as complete so its live
// timer freezes and it is not left in a dangling streaming state.
if len(m.messages) > 0 {
if streamMsg, ok := m.messages[len(m.messages)-1].(*StreamingMessageItem); ok {
streamMsg.MarkComplete()
}
}
m.state = stateInput
m.canceling = false
@@ -1710,6 +1754,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.stream, _ = updated.(streamComponentIface)
cmds = append(cmds, cmd)
}
// Mark any trailing StreamingMessageItem as complete (see StepCompleteEvent).
if len(m.messages) > 0 {
if streamMsg, ok := m.messages[len(m.messages)-1].(*StreamingMessageItem); ok {
streamMsg.MarkComplete()
}
}
m.state = stateInput
m.canceling = false
@@ -1722,6 +1772,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.stream, _ = updated.(streamComponentIface)
cmds = append(cmds, cmd)
}
// Mark any trailing StreamingMessageItem as complete (see StepCompleteEvent).
if len(m.messages) > 0 {
if streamMsg, ok := m.messages[len(m.messages)-1].(*StreamingMessageItem); ok {
streamMsg.MarkComplete()
}
}
if msg.Err != nil {
m.printErrorResponse(msg)
}
@@ -1797,6 +1853,27 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
}
case app.ContentReloadEvent:
// Prompt templates or skills changed on disk — refresh from providers.
m.refreshPromptTemplates()
m.refreshSkillItems()
m.printSystemMessage("Prompts and skills reloaded.")
case app.MCPToolsReadyEvent:
// Background MCP tool loading completed — refresh tool names and count.
m.refreshToolNames()
m.refreshMCPToolCount()
case app.MCPServerLoadedEvent:
// A single MCP server finished loading — display a system message.
if msg.Error != nil {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' failed to load: %v", msg.ServerName, msg.Error))
} else if msg.ToolCount > 0 {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' loaded with %d tools", msg.ServerName, msg.ToolCount))
} else {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' loaded (no tools)", msg.ServerName))
}
case app.EditorTextSetEvent:
// Extension wants to pre-fill the input editor with text.
if ic, ok := m.input.(*InputComponent); ok {
@@ -2690,6 +2767,61 @@ func (m *AppModel) expandPromptTemplate(text string) (string, bool) {
return text, false
}
// refreshPromptTemplates reloads prompt templates from the provider callback
// and updates the autocomplete entries. Called on ContentReloadEvent.
func (m *AppModel) refreshPromptTemplates() {
if m.getPromptTemplates == nil {
return
}
newTemplates := m.getPromptTemplates()
m.promptTemplates = newTemplates
if ic, ok := m.input.(*InputComponent); ok {
// Remove old prompt commands and add fresh ones.
var kept []commands.SlashCommand
for _, sc := range ic.commands {
if sc.Category != "Prompts" {
kept = append(kept, sc)
}
}
for _, tpl := range newTemplates {
kept = append(kept, commands.SlashCommand{
Name: "/" + tpl.Name,
Description: tpl.Description,
Category: "Prompts",
})
}
ic.commands = kept
}
}
// refreshSkillItems reloads skill items from the provider callback.
// Called on ContentReloadEvent.
func (m *AppModel) refreshSkillItems() {
if m.getSkillItems == nil {
return
}
m.skillItems = m.getSkillItems()
}
// refreshToolNames reloads tool names from the provider callback.
// Called on MCPToolsReadyEvent when background MCP tool loading completes.
func (m *AppModel) refreshToolNames() {
if m.getToolNames == nil {
return
}
m.toolNames = m.getToolNames()
}
// refreshMCPToolCount reloads the MCP tool count from the provider callback.
// Called on MCPToolsReadyEvent when background MCP tool loading completes.
func (m *AppModel) refreshMCPToolCount() {
if m.getMCPToolCount == nil {
return
}
m.mcpToolCount = m.getMCPToolCount()
}
// printHelpMessage renders the help text listing all available slash commands.
func (m *AppModel) printHelpMessage() {
help := "## Available Commands\n\n" +
@@ -2885,12 +3017,27 @@ func (m *AppModel) flushStreamAndPendingUserMessages() {
if content := m.stream.GetRenderedContent(); content != "" {
m.stream.Reset()
// Render styled content using MessageRenderer
styledMsg := m.renderer.RenderAssistantMessage(content, time.Now(), m.modelName)
// Check whether the content is already in the ScrollList as a
// StreamingMessageItem (created by appendStreamingChunk during
// ReasoningChunkEvent / StreamChunkEvent). If so, just mark it
// complete — creating a second StyledMessageItem would duplicate
// the rendered block and shift mouse hit-testing coordinates.
alreadyInList := false
if len(m.messages) > 0 {
if streamMsg, ok := m.messages[len(m.messages)-1].(*StreamingMessageItem); ok {
streamMsg.MarkComplete()
alreadyInList = true
}
}
// Add to in-memory scrollList with styled content
msg := NewStyledMessageItem(generateMessageID(), "assistant", content, styledMsg.Content)
m.messages = append(m.messages, msg)
if !alreadyInList {
// Render styled content using MessageRenderer
styledMsg := m.renderer.RenderAssistantMessage(content, time.Now(), m.modelName)
// Add to in-memory scrollList with styled content
msg := NewStyledMessageItem(generateMessageID(), "assistant", content, styledMsg.Content)
m.messages = append(m.messages, msg)
}
}
}
@@ -3502,13 +3649,29 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
return nil
}
// Copy session to a temp file with a clean name.
// Read the original session file.
data, err := os.ReadFile(srcPath)
if err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to read session file: %v", err))
return nil
}
// Capture the current system prompt and model info.
systemPrompt := viper.GetString("system-prompt")
_, provider, modelID := ts.BuildContext()
if modelID == "" {
// Fallback to viper if no model change recorded in session
modelID = viper.GetString("model")
}
// Create a SystemPromptEntry with both prompt and model info.
sysPromptEntry := session.NewSystemPromptEntry(systemPrompt, modelID, provider)
sysPromptJSON, err := session.MarshalEntry(sysPromptEntry)
if err != nil {
m.printSystemMessage(fmt.Sprintf("Failed to marshal system prompt: %v", err))
return nil
}
name := ts.GetSessionName()
if name == "" {
name = "session"
@@ -3528,12 +3691,53 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
}
tmpPath := tmpFile.Name()
if _, err := tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
// Write the session data with the system prompt entry inserted after the header.
// The header is the first line, so we write:
// 1. First line (header) from original data
// 2. System prompt entry
// 3. Remaining lines from original data
lines := strings.Split(string(data), "\n")
if len(lines) > 0 && lines[len(lines)-1] == "" {
lines = lines[:len(lines)-1] // Remove trailing empty line
}
if len(lines) > 0 {
// Write header (first line)
if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
// Write system prompt entry
if _, err := tmpFile.Write(sysPromptJSON); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
return nil
}
if _, err := tmpFile.WriteString("\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
// Write remaining lines
for i := 1; i < len(lines); i++ {
if lines[i] == "" {
continue // Skip empty lines
}
if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
return nil
}
}
}
_ = tmpFile.Close()
m.printSystemMessage("Uploading session to GitHub Gist...")
+259
View File
@@ -0,0 +1,259 @@
// Package watcher provides a general-purpose file watcher that monitors
// directories for changes to files matching specified extensions. It uses
// fsnotify for kernel-level notifications with debouncing to coalesce
// rapid editor writes.
package watcher
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
// ContentWatcher monitors directories for file changes matching a set of
// extensions and triggers a reload callback when changes are detected.
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
type ContentWatcher struct {
watcher *fsnotify.Watcher
onReload func()
extensions []string // e.g. [".md", ".txt"]
label string // for logging (e.g. "prompts", "skills")
debounce time.Duration
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
}
// Options configures a ContentWatcher.
type Options struct {
// Dirs are the directories to watch.
Dirs []string
// Extensions are the file extensions to watch for (e.g. ".md", ".txt").
// Include the leading dot.
Extensions []string
// OnReload is called when a matching file changes (after debouncing).
OnReload func()
// Label is a human-readable name for logging (e.g. "prompts", "skills").
Label string
// Debounce is the debounce duration. Defaults to 300ms if zero.
Debounce time.Duration
}
// New creates a ContentWatcher that monitors the given directories for
// file changes matching the specified extensions. When a change is detected
// (after debouncing), onReload is called. The watcher must be started with
// Start() and stopped with Close().
func New(opts Options) (*ContentWatcher, error) {
if len(opts.Dirs) == 0 {
return nil, fmt.Errorf("no directories to watch")
}
fsw, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating file watcher: %w", err)
}
for _, dir := range opts.Dirs {
if err := fsw.Add(dir); err != nil {
continue
}
// Also watch immediate subdirectories (for skill/SKILL.md pattern).
entries, err := os.ReadDir(dir)
if err != nil {
continue
}
for _, entry := range entries {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
_ = fsw.Add(subdir)
}
}
}
debounce := opts.Debounce
if debounce == 0 {
debounce = 300 * time.Millisecond
}
return &ContentWatcher{
watcher: fsw,
onReload: opts.OnReload,
extensions: opts.Extensions,
label: opts.Label,
debounce: debounce,
done: make(chan struct{}),
}, nil
}
// Start begins watching for file changes. It blocks until the context
// is cancelled or Close() is called. Typically called in a goroutine.
func (w *ContentWatcher) Start(ctx context.Context) {
w.mu.Lock()
ctx, w.cancel = context.WithCancel(ctx)
w.mu.Unlock()
defer close(w.done)
var timer *time.Timer
var timerC <-chan time.Time
for {
select {
case <-ctx.Done():
if timer != nil {
timer.Stop()
}
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// When a new subdirectory is created, start watching it so
// that files added inside (e.g. new-skill/SKILL.md) trigger
// reload events. Also schedule a reload in case the directory
// was created with matching files already inside.
if event.Op&fsnotify.Create != 0 {
if info, err := os.Stat(event.Name); err == nil && info.IsDir() {
if addErr := w.watcher.Add(event.Name); addErr == nil {
// Check if the new directory already contains matching files.
if w.dirContainsMatchingFiles(event.Name) {
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
}
}
continue
}
}
// Only care about files matching our extensions.
if !w.matchesExtension(event.Name) {
continue
}
// React to write, create, remove, rename events.
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
continue
}
// Debounce: reset timer on each event.
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
case <-timerC:
timerC = nil
timer = nil
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
_ = err
}
}
}
// Close stops the watcher and releases resources.
func (w *ContentWatcher) Close() error {
w.mu.Lock()
cancel := w.cancel
w.mu.Unlock()
if cancel != nil {
cancel()
}
// Wait for the event loop to finish.
<-w.done
return w.watcher.Close()
}
// matchesExtension returns true if the file name ends with one of the
// watched extensions.
func (w *ContentWatcher) matchesExtension(name string) bool {
for _, ext := range w.extensions {
if strings.HasSuffix(name, ext) {
return true
}
}
return false
}
// dirContainsMatchingFiles returns true if the directory contains at least
// one file matching the watched extensions. Used to detect cases where a
// directory is created with files already inside (e.g. cp -r).
func (w *ContentWatcher) dirContainsMatchingFiles(dir string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if !entry.IsDir() && w.matchesExtension(entry.Name()) {
return true
}
}
return false
}
// CollectDirs returns the directories to watch for a given set of standard
// directories and extra paths. Directories are deduplicated by absolute path
// and verified to exist. For explicit file paths, the parent directory is
// watched instead.
func CollectDirs(standardDirs []string, extraPaths []string) []string {
var dirs []string
seen := make(map[string]bool)
add := func(dir string) {
abs, err := filepath.Abs(dir)
if err != nil {
return
}
if seen[abs] {
return
}
// Verify the directory exists.
info, err := os.Stat(abs)
if err != nil || !info.IsDir() {
return
}
seen[abs] = true
dirs = append(dirs, abs)
}
for _, d := range standardDirs {
add(d)
}
for _, p := range extraPaths {
info, err := os.Stat(p)
if err != nil {
continue
}
if info.IsDir() {
add(p)
} else {
// For explicit files, watch the parent directory.
add(filepath.Dir(p))
}
}
return dirs
}
+307
View File
@@ -0,0 +1,307 @@
package watcher
import (
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
func TestContentWatcher_ReloadsOnMatchingFile(t *testing.T) {
dir := t.TempDir()
// Write an initial file so the directory isn't empty.
initial := filepath.Join(dir, "existing.md")
if err := os.WriteFile(initial, []byte("# Hello"), 0644); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Wait for watcher to be ready.
time.Sleep(100 * time.Millisecond)
// Modify the file.
if err := os.WriteFile(initial, []byte("# Updated"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_IgnoresNonMatchingFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write a non-matching file.
if err := os.WriteFile(filepath.Join(dir, "readme.txt"), []byte("hello"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 0 {
t.Errorf("expected 0 reloads for non-matching file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_MultipleExtensions(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md", ".txt"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write a .txt file — should trigger.
if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("notes"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload for .txt file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_Debounces(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 100 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Rapid-fire writes — should debounce into 1 reload.
for i := range 5 {
if err := os.WriteFile(filepath.Join(dir, "test.md"), []byte("v"+string(rune('0'+i))), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(30 * time.Millisecond)
}
time.Sleep(300 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 debounced reload, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesSubdirectories(t *testing.T) {
dir := t.TempDir()
// Create a subdirectory (simulates skill-name/SKILL.md pattern).
subdir := filepath.Join(dir, "my-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write to subdirectory.
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Skill"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload for subdirectory file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectory(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Wait for watcher to be ready.
time.Sleep(100 * time.Millisecond)
// Create a NEW subdirectory after the watcher started (the bug scenario).
subdir := filepath.Join(dir, "new-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
// Give fsnotify time to pick up the new directory.
time.Sleep(100 * time.Millisecond)
// Write a matching file inside the new subdirectory.
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# New Skill"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for file in new subdirectory, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectoryWithExistingFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Create a subdirectory with a matching file already inside (simulates cp -r).
subdir := filepath.Join(dir, "copied-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Copied"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(300 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for copied subdirectory with files, got %d", got)
}
_ = w.Close()
}
func TestCollectDirs_Deduplicates(t *testing.T) {
dir := t.TempDir()
dirs := CollectDirs([]string{dir, dir}, nil)
if len(dirs) != 1 {
t.Errorf("expected 1 deduplicated dir, got %d", len(dirs))
}
}
func TestCollectDirs_FileParent(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "test.md")
if err := os.WriteFile(file, []byte("test"), 0644); err != nil {
t.Fatal(err)
}
dirs := CollectDirs(nil, []string{file})
if len(dirs) != 1 {
t.Fatalf("expected 1 dir, got %d", len(dirs))
}
abs, _ := filepath.Abs(dir)
if dirs[0] != abs {
t.Errorf("expected %s, got %s", abs, dirs[0])
}
}
func TestCollectDirs_SkipsNonexistent(t *testing.T) {
dirs := CollectDirs([]string{"/nonexistent/dir"}, nil)
if len(dirs) != 0 {
t.Errorf("expected 0 dirs for nonexistent path, got %d", len(dirs))
}
}
+24 -2
View File
@@ -68,8 +68,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -172,6 +176,24 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
- `GetSessionID()` - Get session UUID
- `Close()` - Clean up resources
### Options
Key `Options` fields for SDK usage:
| Field | Description |
|-------|-------------|
| `Model` | Override model (e.g., "anthropic/claude-sonnet-4-5-20250929") |
| `SystemPrompt` | Override system prompt |
| `ConfigFile` | Load specific config file (empty = search defaults) |
| `SkipConfig` | Skip `.kit.yml` loading (defaults + env vars still apply) |
| `Tools` | Replace core tools with custom set |
| `ExtraTools` | Add tools alongside defaults |
| `DisableCoreTools` | Use no core tools (0 tools, for chat-only) |
| `NoSession` | Ephemeral mode (no session persistence) |
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `Debug` | Enable debug logging |
## Environment Variables
All CLI environment variables work with the SDK:
+231
View File
@@ -0,0 +1,231 @@
package kit
import (
"strings"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/session"
)
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
inner *session.TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
// AppendMessage implements SessionManager.
func (a *treeManagerAdapter) AppendMessage(msg LLMMessage) (string, error) {
// LLMMessage is just an alias for fantasy.Message, so no conversion needed
return a.inner.AppendLLMMessage(msg)
}
// GetMessages implements SessionManager.
func (a *treeManagerAdapter) GetMessages() []LLMMessage {
// LLMMessage is just an alias for fantasy.Message
return a.inner.GetLLMMessages()
}
// BuildContext implements SessionManager.
func (a *treeManagerAdapter) BuildContext() ([]LLMMessage, string, string) {
msgs, provider, modelID := a.inner.BuildContext()
return msgs, provider, modelID
}
// Branch implements SessionManager.
func (a *treeManagerAdapter) Branch(entryID string) error {
return a.inner.Branch(entryID)
}
// GetCurrentBranch implements SessionManager.
func (a *treeManagerAdapter) GetCurrentBranch() []BranchEntry {
branch := a.inner.GetBranch("")
var result []BranchEntry
for _, entry := range branch {
be := a.convertEntry(entry)
if be != nil {
result = append(result, *be)
}
}
return result
}
// GetChildren implements SessionManager.
func (a *treeManagerAdapter) GetChildren(parentID string) []string {
return a.inner.GetChildren(parentID)
}
// GetEntry implements SessionManager.
func (a *treeManagerAdapter) GetEntry(entryID string) *BranchEntry {
entry := a.inner.GetEntry(entryID)
if entry == nil {
return nil
}
return a.convertEntry(entry)
}
// GetSessionID implements SessionManager.
func (a *treeManagerAdapter) GetSessionID() string {
return a.inner.GetSessionID()
}
// GetSessionName implements SessionManager.
func (a *treeManagerAdapter) GetSessionName() string {
return a.inner.GetSessionName()
}
// SetSessionName implements SessionManager.
func (a *treeManagerAdapter) SetSessionName(name string) error {
_, err := a.inner.AppendSessionInfo(name)
return err
}
// GetCreatedAt implements SessionManager.
func (a *treeManagerAdapter) GetCreatedAt() time.Time {
return a.inner.GetHeader().Timestamp
}
// IsPersisted implements SessionManager.
func (a *treeManagerAdapter) IsPersisted() bool {
return a.inner.IsPersisted()
}
// AppendCompaction implements SessionManager.
func (a *treeManagerAdapter) AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
return a.inner.AppendCompaction(summary, firstKeptEntryID,
tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
}
// GetLastCompaction implements SessionManager.
func (a *treeManagerAdapter) GetLastCompaction() *CompactionEntry {
c := a.inner.GetLastCompaction()
if c == nil {
return nil
}
return &CompactionEntry{
ID: c.ID,
Summary: c.Summary,
FirstKeptEntryID: c.FirstKeptEntryID,
TokensBefore: c.TokensBefore,
TokensAfter: c.TokensAfter,
MessagesRemoved: c.MessagesRemoved,
ReadFiles: c.ReadFiles,
ModifiedFiles: c.ModifiedFiles,
Timestamp: c.Timestamp,
}
}
// AppendExtensionData implements SessionManager.
func (a *treeManagerAdapter) AppendExtensionData(extType, data string) (string, error) {
return a.inner.AppendExtensionData(extType, data)
}
// GetExtensionData implements SessionManager.
func (a *treeManagerAdapter) GetExtensionData(extType string) []ExtensionDataEntry {
entries := a.inner.GetExtensionData(extType)
var result []ExtensionDataEntry
for _, e := range entries {
result = append(result, ExtensionDataEntry{
ID: e.ID,
ExtType: e.ExtType,
Data: e.Data,
Timestamp: e.Timestamp,
})
}
return result
}
// AppendModelChange implements SessionManager.
func (a *treeManagerAdapter) AppendModelChange(provider, modelID string) (string, error) {
return a.inner.AppendModelChange(provider, modelID)
}
// GetContextEntryIDs implements SessionManager.
func (a *treeManagerAdapter) GetContextEntryIDs() []string {
return a.inner.GetContextEntryIDs()
}
// Close implements SessionManager.
func (a *treeManagerAdapter) Close() error {
return a.inner.Close()
}
// Helper: Convert internal entry types to BranchEntry
func (a *treeManagerAdapter) convertEntry(entry any) *BranchEntry {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// Build content text from parts
var content strings.Builder
for _, part := range msg.Parts {
if textPart, ok := part.(TextContent); ok {
content.WriteString(textPart.Text)
}
}
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeMessage,
Role: string(msg.Role),
Content: content.String(),
Model: e.Model,
Provider: e.Provider,
Timestamp: e.Timestamp,
RawParts: msg.Parts,
}
case *session.BranchSummaryEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeBranchSummary,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ModelChangeEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeModelChange,
Content: "Model changed to " + e.Provider + "/" + e.ModelID,
Model: e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp,
}
case *session.CompactionEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeCompaction,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ExtensionDataEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeExtensionData,
Content: "Extension data: " + e.ExtType,
Timestamp: e.Timestamp,
}
default:
return nil
}
}
// convertKitMessagesToFantasy converts kit LLM messages to fantasy messages.
// Since LLMMessage is an alias for fantasy.Message, this is a no-op.
func convertKitMessagesToFantasy(msgs []LLMMessage) []fantasy.Message {
// LLMMessage is just an alias for fantasy.Message, so we can type convert
return msgs
}
+12 -12
View File
@@ -21,9 +21,9 @@ type ContextStats struct {
const defaultReserveTokens = 16384
// EstimateContextTokens returns the estimated token count of the current
// conversation based on tree session messages.
// conversation based on session messages.
func (m *Kit) EstimateContextTokens() int {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
return compaction.EstimateMessageTokens(messages)
}
@@ -42,8 +42,8 @@ func (m *Kit) ShouldCompact() bool {
reserveTokens = m.compactionOpts.ReserveTokens
}
messages := m.treeSession.GetLLMMessages()
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
messages := m.session.GetMessages()
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
}
// GetContextStats returns current context usage statistics including
@@ -55,7 +55,7 @@ func (m *Kit) ShouldCompact() bool {
// because it includes system prompts, tool definitions, and other overhead
// that the heuristic cannot account for.
func (m *Kit) GetContextStats() ContextStats {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
// Prefer the real API-reported input token count when available.
m.lastInputTokensMu.RLock()
@@ -114,7 +114,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
}
}
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
if len(messages) < 2 {
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
}
@@ -145,7 +145,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Carry forward file tracking from previous compaction.
var prev *compaction.PreviousCompaction
if lastCompaction := m.treeSession.GetLastCompaction(); lastCompaction != nil {
if lastCompaction := m.session.GetLastCompaction(); lastCompaction != nil {
prev = &compaction.PreviousCompaction{
ReadFiles: lastCompaction.ReadFiles,
ModifiedFiles: lastCompaction.ModifiedFiles,
@@ -171,7 +171,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Non-destructive: append a CompactionEntry to the session tree instead
// of clearing and rewriting messages.
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if result.CutPoint >= 0 && result.CutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[result.CutPoint]
@@ -188,9 +188,9 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// custom summary. It still determines the cut point and persists a
// CompactionEntry.
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
originalTokens := compaction.EstimateMessageTokens(messages)
originalTokens := compaction.EstimateMessageTokens(convertKitMessagesToFantasy(messages))
cutPoint := compaction.FindCutPoint(messages, opts.KeepRecentTokens)
cutPoint := compaction.FindCutPoint(convertKitMessagesToFantasy(messages), opts.KeepRecentTokens)
if cutPoint == 0 {
cutPoint = len(messages) - 1
if cutPoint < 1 {
@@ -198,7 +198,7 @@ func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts
}
}
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if cutPoint >= 0 && cutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[cutPoint]
@@ -234,7 +234,7 @@ func (m *Kit) persistAndEmitCompaction(
originalTokens, compactedTokens, messagesRemoved int,
readFiles, modifiedFiles []string,
) error {
if _, err := m.treeSession.AppendCompaction(
if _, err := m.session.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
+2
View File
@@ -48,6 +48,8 @@ func setSDKDefaults() {
viper.SetDefault("temperature", 0.7)
viper.SetDefault("top-p", 0.95)
viper.SetDefault("top-k", 40)
viper.SetDefault("frequency-penalty", 0.0)
viper.SetDefault("presence-penalty", 0.0)
viper.SetDefault("stream", true)
viper.SetDefault("thinking-level", "off")
viper.SetDefault("num-gpu-layers", -1)
+34 -11
View File
@@ -227,28 +227,51 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
// Session data
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
return iterBranchMessages(e.kit.treeSession, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if e.kit.session == nil {
return nil
}
// Try to use the legacy iterBranchMessages for backward compatibility
// with the default TreeManager adapter
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
}
})
}
// For custom SessionManagers, use the public interface
branch := e.kit.session.GetCurrentBranch()
var result []extensions.SessionMessage
for _, entry := range branch {
if entry.Type == EntryTypeMessage {
result = append(result, extensions.SessionMessage{
ID: entry.ID,
Role: entry.Role,
Content: entry.Content,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
})
}
return result
}
func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return "", fmt.Errorf("no session available")
}
return e.kit.treeSession.AppendExtensionData(extType, data)
return e.kit.session.AppendExtensionData(extType, data)
}
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return nil
}
entries := e.kit.treeSession.GetExtensionData(extType)
entries := e.kit.session.GetExtensionData(extType)
result := make([]extensions.ExtensionEntry, 0, len(entries))
for _, e := range entries {
result = append(result, extensions.ExtensionEntry{
+152 -46
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
@@ -11,7 +12,6 @@ import (
"time"
"charm.land/fantasy"
charmlog "github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
@@ -39,7 +39,7 @@ type ContextFile struct {
// agents, sessions, and model configurations.
type Kit struct {
agent *agent.Agent
treeSession *session.TreeManager
session SessionManager
modelString string
events *eventBus
autoCompact bool
@@ -48,6 +48,8 @@ type Kit struct {
skills []*skills.Skill
extRunner *extensions.Runner
bufferedLogger *tools.BufferedDebugLogger
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
opts *Options // stored for reload operations (skills, etc.)
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
@@ -112,15 +114,32 @@ func (m *Kit) GetLoadingMessage() string {
}
// GetLoadedServerNames returns the names of successfully loaded MCP servers.
// If MCP servers are still loading in the background, this returns only the
// servers that have completed loading so far.
func (m *Kit) GetLoadedServerNames() []string {
return m.agent.GetLoadedServerNames()
}
// GetMCPToolCount returns the number of tools loaded from external MCP servers.
// If MCP servers are still loading in the background, this returns the count
// of tools loaded so far (may be 0).
func (m *Kit) GetMCPToolCount() int {
return m.agent.GetMCPToolCount()
}
// WaitForMCPTools blocks until background MCP tool loading completes.
// Returns nil if no MCP servers are configured or if loading succeeded.
// Returns the loading error if all servers failed. Safe to call multiple times.
func (m *Kit) WaitForMCPTools() error {
return m.agent.WaitForMCPTools()
}
// MCPToolsReady returns true if MCP tool loading has completed (or was never
// started). This is a non-blocking check useful for UI status display.
func (m *Kit) MCPToolsReady() bool {
return m.agent.MCPToolsReady()
}
// GetExtensionToolCount returns the number of tools registered by extensions.
func (m *Kit) GetExtensionToolCount() int {
return m.agent.GetExtensionToolCount()
@@ -153,27 +172,39 @@ type StructuredMessage struct {
// flattens all content to a single text string, this preserves tool calls,
// tool results, reasoning blocks, and finish markers as distinct typed parts.
func (m *Kit) GetStructuredMessages() []StructuredMessage {
return iterBranchMessages(m.treeSession, func(me *session.MessageEntry, msg message.Message) StructuredMessage {
return StructuredMessage{
ID: me.ID,
ParentID: me.ParentID,
Role: msg.Role,
Parts: msg.Parts,
Model: msg.Model,
Provider: msg.Provider,
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if m.session == nil {
return nil
}
branch := m.session.GetCurrentBranch()
var results []StructuredMessage
for _, entry := range branch {
if entry.Type != EntryTypeMessage {
continue
}
})
results = append(results, StructuredMessage{
ID: entry.ID,
ParentID: entry.ParentID,
Role: MessageRole(entry.Role),
Parts: entry.RawParts,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
return results
}
// iterBranchMessages iterates over the current branch's MessageEntry items,
// converting each to a message.Message and calling fn to build the result.
// Returns nil if there is no tree session. Skips entries that are not
// Returns nil if there is no session. Skips entries that are not
// MessageEntry or that fail conversion.
// Deprecated: Use SessionManager.GetCurrentBranch() directly.
func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.MessageEntry, message.Message) T) []T {
if tm == nil {
return nil
}
branch := tm.GetBranch("")
var results []T
for _, entry := range branch {
@@ -224,6 +255,10 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
config.TopP = &topP
topK := int32(viper.GetInt("top-k"))
config.TopK = &topK
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
config.FrequencyPenalty = &frequencyPenalty
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
config.PresencePenalty = &presencePenalty
if err := m.agent.SetModel(ctx, config); err != nil {
return err
@@ -422,6 +457,17 @@ type Options struct {
Tools []Tool // Custom tool set. If empty, AllTools() is used.
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
// SkipConfig, when true, skips loading .kit.yml configuration files.
// Viper defaults (setSDKDefaults) and environment variables (KIT_*)
// are still applied. Use this for fully programmatic configuration.
SkipConfig bool
// DisableCoreTools, when true, prevents loading any core tools.
// Use with Tools or ExtraTools to provide only custom tools.
// If both DisableCoreTools is true and Tools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// Session configuration
SessionDir string // Base directory for session discovery (default: cwd)
SessionPath string // Open a specific session file by path
@@ -439,8 +485,32 @@ type Options struct {
// Debug enables debug logging for the SDK.
Debug bool
// MCPAuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports (streamable HTTP, SSE) are configured with
// OAuth support. If the server returns a 401, the handler is invoked to
// let the user authorize via browser.
//
// If nil, a [DefaultMCPAuthHandler] is created automatically — opening the
// system browser and listening on a local callback server.
//
// Set to a custom implementation to control the authorization UX (e.g.
// display a URL in a custom UI, redirect to a web app, etc.).
MCPAuthHandler MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading during Kit initialization. The callback receives the server name,
// tool count, and any error. Called from a background goroutine; safe to
// call app.NotifyMCPServerLoaded() from within the callback to display
// real-time progress in the TUI.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// CLI is optional CLI-specific configuration. SDK users leave this nil.
CLI *CLIOptions
// SessionManager allows custom session storage backends.
// If nil (default), Kit uses the built-in file-based TreeManager.
// When provided, SessionPath, Continue, and NoSession options are ignored.
SessionManager SessionManager
}
// CLIOptions holds fields only relevant to the CLI binary. SDK users should
@@ -535,7 +605,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Initialize config (loads config files and env vars).
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
// Check if model is already set, which indicates config was loaded.
if viper.GetString("model") == "" {
// SkipConfig bypasses .kit.yml file loading (viper defaults and env vars still apply).
if !opts.SkipConfig && viper.GetString("model") == "" {
if err := InitConfig(opts.ConfigFile, false); err != nil {
return fmt.Errorf("failed to initialize config: %w", err)
}
@@ -644,17 +715,36 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Pass the pre-built ProviderConfig and scalar viper snapshots so
// SetupAgent doesn't need to re-read viper (which would require the lock).
setupOpts := kitsetup.AgentSetupOptions{
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
DisableCoreTools: opts.DisableCoreTools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
// Set up OAuth handler for remote MCP servers.
// The SDK MCPAuthHandler interface is structurally identical to
// tools.MCPAuthHandler, so any implementation satisfies both.
if opts.MCPAuthHandler != nil {
setupOpts.AuthHandler = opts.MCPAuthHandler
} else {
// Create a default handler that opens the system browser.
defaultHandler, authErr := NewDefaultMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote servers.
log.Printf("WARN Failed to create OAuth handler; remote MCP servers requiring auth will fail: %v", authErr)
} else {
setupOpts.AuthHandler = defaultHandler
}
}
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
@@ -667,16 +757,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, err
}
// Initialize tree session.
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
// Initialize session manager.
var sessionManager SessionManager
if opts.SessionManager != nil {
// Use custom session manager provided by user.
sessionManager = opts.SessionManager
} else {
// DEFAULT: Use built-in TreeManager (existing behavior).
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
}
// Wrap TreeManager in adapter to satisfy SessionManager interface.
sessionManager = NewTreeManagerAdapter(treeSession)
}
k := &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
@@ -685,6 +784,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
@@ -1211,11 +1312,8 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
// Emit step usage event for real-time cost tracking
if viper.GetBool("debug") {
charmlog.Debug("Kit.generate emitting StepUsageEvent",
"input", inputTokens,
"output", outputTokens,
"cacheRead", cacheReadTokens,
"cacheCreate", cacheCreationTokens,
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
)
}
m.events.emit(StepUsageEvent{
@@ -1282,9 +1380,9 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
}
// Persist pre-generation messages to tree session.
// Persist pre-generation messages to session.
for _, msg := range preMessages {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
// Auto-compact if enabled and conversation is near the context limit.
@@ -1292,8 +1390,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
_, _ = m.compactInternal(ctx, m.compactionOpts, "", true) // best-effort, automatic
}
// Build context from the tree so only the current branch is sent.
messages := m.treeSession.GetLLMMessages()
// Build context from the session so only the current branch is sent.
messages, _, _ := m.session.BuildContext()
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
@@ -1316,7 +1414,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
// (pending) message or tool call is discarded.
if result != nil && len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
}
m.events.emit(TurnEndEvent{Error: err})
@@ -1332,7 +1430,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
// GetContextStats() see up-to-date token counts.
if len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
}
@@ -1414,7 +1512,7 @@ func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
// Returns an error if there are no previous messages in the session.
func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
// Verify there is conversation history to follow up on.
if len(m.treeSession.GetLLMMessages()) == 0 {
if len(m.session.GetMessages()) == 0 {
return "", fmt.Errorf("cannot follow up: no previous messages")
}
@@ -1570,10 +1668,12 @@ func (m *Kit) PromptResultWithMessages(ctx context.Context, messages []string) (
return m.runTurn(ctx, promptLabel, messages[len(messages)-1], preMessages)
}
// ClearSession resets the tree session's leaf pointer to the root, starting
// ClearSession resets the session's leaf pointer to the root, starting
// a fresh conversation branch.
func (m *Kit) ClearSession() {
m.treeSession.ResetLeaf()
if m.session != nil {
_ = m.session.Branch("")
}
}
// GetModelString returns the current model string identifier (e.g.,
@@ -1642,8 +1742,14 @@ func (m *Kit) Close() error {
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionShutdown) {
_, _ = m.extRunner.Emit(extensions.SessionShutdownEvent{})
}
if m.treeSession != nil {
_ = m.treeSession.Close()
if m.session != nil {
_ = m.session.Close()
}
// Release the OAuth callback port if we own the handler.
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
_ = closer.Close()
}
return m.agent.Close()
}
// Conversion helpers are defined in adapter.go.
+265
View File
@@ -0,0 +1,265 @@
package kit
import (
"context"
"fmt"
"net"
"net/http"
"os/exec"
"runtime"
"sync"
"time"
)
// MCPAuthHandler handles OAuth authorization for MCP servers.
// Implementations control the user experience — opening a browser, showing a
// prompt, displaying a URL, etc.
//
// The default implementation ([DefaultMCPAuthHandler]) opens the system browser
// and starts a local HTTP callback server to receive the authorization code.
type MCPAuthHandler interface {
// RedirectURI returns the OAuth redirect URI that the callback server
// will listen on. This is called during MCP transport setup — before any
// OAuth errors occur — so the redirect URI can be registered with the
// authorization server.
RedirectURI() string
// HandleAuth is called when an MCP server requires OAuth authorization.
// It receives the server name and an authorization URL that the user must
// visit. The handler must:
// 1. Direct the user to authURL (e.g. open browser, display URL)
// 2. Listen for the OAuth callback on the redirect URI
// 3. Return the full callback URL (with code and state query params)
//
// Return an error to abort the connection to this MCP server.
// The context controls the overall timeout; implementations should
// respect ctx.Done().
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// DefaultMCPAuthHandler opens the system browser and starts a local HTTP
// callback server to receive the OAuth authorization code. It eagerly reserves
// a TCP port on construction so [RedirectURI] is stable for the lifetime of
// the handler.
//
// Create instances with [NewDefaultMCPAuthHandler] (random port) or
// [NewDefaultMCPAuthHandlerWithPort] (explicit port).
type DefaultMCPAuthHandler struct {
listener net.Listener
port int
mu sync.Mutex // guards listener lifecycle
}
// NewDefaultMCPAuthHandler creates a handler that listens on a random
// available port on localhost. The port is reserved immediately so
// [RedirectURI] returns a stable value. Call [DefaultMCPAuthHandler.Close]
// when the handler is no longer needed to release the port.
func NewDefaultMCPAuthHandler() (*DefaultMCPAuthHandler, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to listen for OAuth callback: %w", err)
}
port := listener.Addr().(*net.TCPAddr).Port
return &DefaultMCPAuthHandler{listener: listener, port: port}, nil
}
// NewDefaultMCPAuthHandlerWithPort creates a handler that listens on the
// specified port on localhost. The port is reserved immediately. Pass 0 to
// let the OS pick a free port (equivalent to [NewDefaultMCPAuthHandler]).
// Call [DefaultMCPAuthHandler.Close] when the handler is no longer needed.
func NewDefaultMCPAuthHandlerWithPort(port int) (*DefaultMCPAuthHandler, error) {
addr := fmt.Sprintf("localhost:%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to listen on %s for OAuth callback: %w", addr, err)
}
actualPort := listener.Addr().(*net.TCPAddr).Port
return &DefaultMCPAuthHandler{listener: listener, port: actualPort}, nil
}
// RedirectURI returns the OAuth redirect URI pointing to the local callback
// server. This value is stable for the lifetime of the handler.
func (h *DefaultMCPAuthHandler) RedirectURI() string {
return fmt.Sprintf("http://localhost:%d/oauth/callback", h.port)
}
// Port returns the TCP port the callback server is bound to.
func (h *DefaultMCPAuthHandler) Port() int {
return h.port
}
// HandleAuth opens the system browser to authURL and waits for the OAuth
// callback on the local server. It returns the full callback URL including
// query parameters (code, state, etc.).
//
// If the context has no deadline, a default 2-minute timeout is applied.
// The callback server is started for each HandleAuth call and shut down
// before returning.
func (h *DefaultMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
h.mu.Lock()
listener := h.listener
h.mu.Unlock()
if listener == nil {
return "", fmt.Errorf("OAuth callback handler is closed")
}
// Apply default timeout if the context has no deadline.
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
}
// Channel receives the full callback URL from the HTTP handler.
callbackCh := make(chan string, 1)
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
// Reconstruct the full callback URL as the caller expects it.
fullURL := fmt.Sprintf("http://localhost:%d%s", h.port, r.RequestURI)
// Send the callback URL to the waiting goroutine (non-blocking).
select {
case callbackCh <- fullURL:
default:
}
// Respond with a friendly HTML page so the user knows they can
// close the browser tab.
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = fmt.Fprint(w, oauthSuccessHTML)
})
server := &http.Server{
Handler: mux,
}
// Start serving on the pre-reserved listener. We need to create a new
// listener on the same port because http.Server.Serve takes ownership
// and closes the listener when done. The original listener is kept open
// to reserve the port; we create a second listener via SO_REUSEADDR
// semantics (Go's default on most platforms) or, more reliably, we
// temporarily release and re-acquire.
//
// Strategy: use the held listener directly for Serve. After Serve
// returns (due to Shutdown), re-acquire the listener to keep the port
// reserved for future HandleAuth calls.
h.mu.Lock()
serveListener := h.listener
h.listener = nil // Serve will close it
h.mu.Unlock()
if serveListener == nil {
return "", fmt.Errorf("OAuth callback handler is closed")
}
// Start the HTTP server in a background goroutine.
serverErrCh := make(chan error, 1)
go func() {
err := server.Serve(serveListener)
if err != nil && err != http.ErrServerClosed {
serverErrCh <- err
}
close(serverErrCh)
}()
// Re-acquire the listener after Serve completes (deferred).
defer func() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
_ = server.Shutdown(shutdownCtx)
// Re-reserve the port for future HandleAuth calls.
h.mu.Lock()
defer h.mu.Unlock()
if h.listener == nil {
newListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", h.port))
if err == nil {
h.listener = newListener
}
// If re-listen fails, the handler degrades gracefully — the
// next HandleAuth call will return an error.
}
}()
// Open the system browser.
if err := openBrowser(authURL); err != nil {
// Browser open is best-effort; the user can still navigate manually.
_ = err
}
// Wait for the callback, a server error, or context cancellation.
select {
case url := <-callbackCh:
return url, nil
case err := <-serverErrCh:
return "", fmt.Errorf("OAuth callback server error for %q: %w", serverName, err)
case <-ctx.Done():
return "", fmt.Errorf("OAuth authorization timed out for %q: %w", serverName, ctx.Err())
}
}
// Close releases the reserved port and shuts down the handler. After Close,
// HandleAuth will return an error. Close is safe to call multiple times.
func (h *DefaultMCPAuthHandler) Close() error {
h.mu.Lock()
defer h.mu.Unlock()
if h.listener != nil {
err := h.listener.Close()
h.listener = nil
return err
}
return nil
}
// openBrowser opens the default system browser to the given URL. This is a
// best-effort operation — errors are returned but callers typically ignore
// them since the user can navigate manually.
func openBrowser(url string) error {
switch runtime.GOOS {
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
return exec.Command("open", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// oauthSuccessHTML is the HTML page returned to the browser after a
// successful OAuth callback.
const oauthSuccessHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Authorization Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: #f8f9fa;
color: #333;
}
.container {
text-align: center;
padding: 2rem;
}
h1 { color: #22863a; }
p { color: #586069; margin-top: 0.5rem; }
</style>
</head>
<body>
<div class="container">
<h1>&#10003; Authorization Successful</h1>
<p>You can close this tab and return to the terminal.</p>
</div>
</body>
</html>`
+68
View File
@@ -0,0 +1,68 @@
package kit
import (
"context"
"fmt"
"io"
"os"
)
// CLIMCPAuthHandler wraps a [DefaultMCPAuthHandler] and prints status messages
// to a writer (typically stderr) so the user knows what's happening during
// OAuth authorization. This is the handler used by the CLI/TUI binary.
//
// For TUI integration, set NotifyFunc to route messages through the TUI's
// event system instead of (or in addition to) the writer.
type CLIMCPAuthHandler struct {
inner *DefaultMCPAuthHandler
w io.Writer
// NotifyFunc, when set, is called with status messages instead of writing
// to the writer. This allows the TUI to display system messages in the
// chat stream. If nil, messages are written to w.
NotifyFunc func(serverName, message string)
}
// NewCLIMCPAuthHandler creates a CLI auth handler that prints status messages
// to stderr and delegates the actual OAuth flow to a [DefaultMCPAuthHandler].
func NewCLIMCPAuthHandler() (*CLIMCPAuthHandler, error) {
inner, err := NewDefaultMCPAuthHandler()
if err != nil {
return nil, err
}
return &CLIMCPAuthHandler{inner: inner, w: os.Stderr}, nil
}
// RedirectURI returns the OAuth redirect URI from the inner handler.
func (h *CLIMCPAuthHandler) RedirectURI() string {
return h.inner.RedirectURI()
}
// HandleAuth prints status messages and delegates to the inner handler.
func (h *CLIMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
h.notify(serverName, fmt.Sprintf("🔐 MCP server %q requires authentication. Opening browser...", serverName))
h.notify(serverName, fmt.Sprintf(" If the browser doesn't open, visit:\n %s", authURL))
callbackURL, err := h.inner.HandleAuth(ctx, serverName, authURL)
if err != nil {
h.notify(serverName, fmt.Sprintf("✗ Authentication failed for %q: %v", serverName, err))
return "", err
}
h.notify(serverName, fmt.Sprintf("✓ Authenticated with %q", serverName))
return callbackURL, nil
}
// Close releases the inner handler's resources.
func (h *CLIMCPAuthHandler) Close() error {
return h.inner.Close()
}
// notify sends a message through NotifyFunc if set, otherwise writes to w.
func (h *CLIMCPAuthHandler) notify(serverName, message string) {
if h.NotifyFunc != nil {
h.NotifyFunc(serverName, message)
return
}
_, _ = fmt.Fprintln(h.w, message)
}
+132
View File
@@ -0,0 +1,132 @@
package kit
import (
"time"
)
// SessionManager defines the contract for conversation storage backends.
// Implementations can use files (default), databases, cloud storage, etc.
type SessionManager interface {
// AppendMessage adds a message to the current branch and returns its entry ID.
// The entry ID is used for tree navigation and must be unique within the session.
AppendMessage(msg LLMMessage) (entryID string, err error)
// GetMessages returns all messages on the current branch (from root to leaf),
// including any compaction summaries at the appropriate positions.
GetMessages() []LLMMessage
// BuildContext returns the message history to send to the LLM, applying
// compaction rules and branch summaries as needed.
// Returns: messages, currentProvider, currentModelID
BuildContext() (messages []LLMMessage, provider string, modelID string)
// Branch moves the leaf pointer to the given entry ID, creating a branch point.
// Subsequent AppendMessage calls extend from this new position.
// entryID can be empty to reset to root (new conversation branch).
Branch(entryID string) error
// GetCurrentBranch returns the path from root to current leaf as entry metadata.
// Used for UI display and navigation.
GetCurrentBranch() []BranchEntry
// GetChildren returns direct child entry IDs for a given parent entry.
// Used to display branch points in the conversation tree.
GetChildren(parentID string) []string
// GetEntry returns a specific entry by ID, or nil if not found.
GetEntry(entryID string) *BranchEntry
// GetSessionID returns the unique session identifier (UUID).
GetSessionID() string
// GetSessionName returns the user-defined display name, or empty.
GetSessionName() string
// SetSessionName sets a display name for the session.
SetSessionName(name string) error
// GetCreatedAt returns when the session was created.
GetCreatedAt() time.Time
// IsPersisted returns true if this session writes to durable storage.
IsPersisted() bool
// AppendCompaction adds a compaction entry that summarizes older messages.
// firstKeptEntryID is the ID of the first message to preserve in context.
// readFiles and modifiedFiles track file changes for the compaction summary.
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
// GetLastCompaction returns the most recent compaction entry on the current
// branch, or nil if none exists.
GetLastCompaction() *CompactionEntry
// AppendExtensionData stores custom extension data in the session tree.
// Extensions use this to persist state across restarts.
AppendExtensionData(extType, data string) (string, error)
// GetExtensionData returns all extension data entries of the given type
// on the current branch. If extType is empty, returns all extension data.
GetExtensionData(extType string) []ExtensionDataEntry
// AppendModelChange records a provider/model switch in the session.
AppendModelChange(provider, modelID string) (string, error)
// GetContextEntryIDs returns the entry IDs corresponding to the messages
// returned by BuildContext, in the same order. Used by compaction to
// determine which entries to summarize.
GetContextEntryIDs() []string
// Close releases resources (database connections, file handles, etc.).
Close() error
}
// BranchEntry represents a single node in the conversation tree.
// This is a SDK-friendly struct (not the internal entry types).
type BranchEntry struct {
ID string
ParentID string
Type EntryType // "message", "branch_summary", "model_change", "compaction", "extension_data"
Role string // for messages: "user", "assistant", "system", "tool"
Content string // text content or summary
Model string // model used (for messages and model_change)
Provider string // provider used
Timestamp time.Time
Children []string // child entry IDs (for tree display)
// RawParts contains the full typed content parts for structured access.
// Only populated for message entries.
RawParts []ContentPart
}
// EntryType identifies the kind of entry in the session tree.
type EntryType string
const (
EntryTypeMessage EntryType = "message"
EntryTypeBranchSummary EntryType = "branch_summary"
EntryTypeModelChange EntryType = "model_change"
EntryTypeCompaction EntryType = "compaction"
EntryTypeExtensionData EntryType = "extension_data"
)
// CompactionEntry represents a context compaction/summarization event.
type CompactionEntry struct {
ID string
Summary string
FirstKeptEntryID string
TokensBefore int
TokensAfter int
MessagesRemoved int
ReadFiles []string
ModifiedFiles []string
Timestamp time.Time
}
// ExtensionDataEntry represents custom extension data stored in the session.
type ExtensionDataEntry struct {
ID string
ExtType string
Data string
Timestamp time.Time
}
+111 -80
View File
@@ -8,7 +8,6 @@ import (
"time"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/session"
)
@@ -47,49 +46,73 @@ func OpenTreeSession(path string) (*TreeManager, error) {
// --- Instance methods on Kit ---
// GetSessionManager returns the session manager, or nil if not configured.
func (m *Kit) GetSessionManager() SessionManager {
return m.session
}
// GetTreeSession returns the tree session manager, or nil if not configured.
// Deprecated: Use GetSessionManager instead.
func (m *Kit) GetTreeSession() *TreeManager {
return m.treeSession
// Try to unwrap the adapter if using default implementation
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner
}
return nil
}
// SetSessionManager replaces the session manager on a Kit instance.
func (m *Kit) SetSessionManager(sm SessionManager) {
m.session = sm
}
// SetTreeSession replaces the tree session on a Kit instance. This is used by
// the CLI when it handles session creation externally (e.g. --resume with a
// TUI picker) and needs to inject the result into a Kit-like workflow.
// Deprecated: Use SetSessionManager instead.
func (m *Kit) SetTreeSession(ts *TreeManager) {
m.treeSession = ts
m.session = NewTreeManagerAdapter(ts)
}
// GetSessionPath returns the file path of the active tree session, or empty
// for in-memory sessions or when no tree session is configured.
// GetSessionPath returns the file path of the active session, or empty
// for in-memory sessions or when no file-based session is configured.
func (m *Kit) GetSessionPath() string {
if m.treeSession != nil {
return m.treeSession.GetFilePath()
// Only file-based sessions have a path
// Try to get it from the underlying TreeManager if using default adapter
if m.session == nil {
return ""
}
// Check if it's the default adapter
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner.GetFilePath()
}
return ""
}
// GetSessionID returns the UUID of the active tree session, or empty when no
// tree session is configured.
// GetSessionID returns the UUID of the active session, or empty when no
// session is configured.
func (m *Kit) GetSessionID() string {
if m.treeSession != nil {
return m.treeSession.GetSessionID()
if m.session == nil {
return ""
}
return ""
return m.session.GetSessionID()
}
// Branch moves the tree session's leaf pointer to the given entry ID, creating
// Branch moves the session's leaf pointer to the given entry ID, creating
// a branch point. Subsequent Prompt() calls will extend from the new position.
func (m *Kit) Branch(entryID string) error {
return m.treeSession.Branch(entryID)
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.session.Branch(entryID)
}
// SetSessionName sets a user-defined display name for the active tree session.
// SetSessionName sets a user-defined display name for the active session.
func (m *Kit) SetSessionName(name string) error {
if m.treeSession == nil {
return fmt.Errorf("session naming requires a tree session")
if m.session == nil {
return fmt.Errorf("session naming requires a session")
}
_, err := m.treeSession.AppendSessionInfo(name)
return err
return m.session.SetSessionName(name)
}
// ---------------------------------------------------------------------------
@@ -97,27 +120,27 @@ func (m *Kit) SetSessionName(name string) error {
// ---------------------------------------------------------------------------
// GetTreeNode returns a node by ID with full metadata and children.
// Returns nil if entry not found or no tree session.
// Returns nil if entry not found or no session.
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
entry := m.treeSession.GetEntry(entryID)
entry := m.session.GetEntry(entryID)
if entry == nil {
return nil
}
return m.entryToTreeNode(entry)
return m.branchEntryToTreeNode(entry)
}
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
func (m *Kit) GetCurrentBranch() []TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var nodes []TreeNode
for _, entry := range branch {
node := m.entryToTreeNode(entry)
node := m.branchEntryToTreeNode(&entry)
if node != nil {
nodes = append(nodes, *node)
}
@@ -127,34 +150,34 @@ func (m *Kit) GetCurrentBranch() []TreeNode {
// GetChildren returns direct child IDs of an entry.
func (m *Kit) GetChildren(parentID string) []string {
if m.treeSession == nil {
if m.session == nil {
return nil
}
return m.treeSession.GetChildren(parentID)
return m.session.GetChildren(parentID)
}
// NavigateTo branches/forks the session to the specified entry ID.
// Returns an error if the session is unavailable or the entry ID is not found.
func (m *Kit) NavigateTo(entryID string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.treeSession.Branch(entryID)
return m.session.Branch(entryID)
}
// SummarizeBranch uses the LLM to summarize the conversation between two
// entry IDs. Returns the summary text, or an error if the range is invalid,
// the session is unavailable, or the LLM call fails.
func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
if m.treeSession == nil {
return "", fmt.Errorf("no tree session available")
if m.session == nil {
return "", fmt.Errorf("no session available")
}
// Get the branch and find the range
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var startIdx, endIdx = -1, -1
for i, entry := range branch {
id := m.treeSession.EntryID(entry)
id := entry.ID
if id == fromID {
startIdx = i
}
@@ -170,7 +193,7 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// Build text to summarize
var content strings.Builder
for i := startIdx; i <= endIdx; i++ {
node := m.entryToTreeNode(branch[i])
node := m.branchEntryToTreeNode(&branch[i])
if node != nil && node.Content != "" {
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
}
@@ -195,73 +218,81 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// CollapseBranch replaces a branch range with a summary entry.
// Returns an error if the session is unavailable or the operation fails.
func (m *Kit) CollapseBranch(fromID, toID, summary string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
return err
// Note: This operation is not directly supported by SessionManager interface
// as it requires AppendBranchSummary which is TreeManager-specific.
// For custom SessionManagers, this would need to be implemented differently.
// For now, we try to use the underlying TreeManager if available.
if adapter, ok := m.session.(*treeManagerAdapter); ok {
_, err := adapter.inner.AppendBranchSummary(fromID, summary)
return err
}
return fmt.Errorf("CollapseBranch not supported by custom session manager")
}
// entryToTreeNode converts a session entry to a TreeNode.
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// branchEntryToTreeNode converts a BranchEntry to a TreeNode.
func (m *Kit) branchEntryToTreeNode(entry *BranchEntry) *TreeNode {
if entry == nil {
return nil
}
switch entry.Type {
case EntryTypeMessage:
// Build content from RawParts
var content strings.Builder
for _, p := range msg.Parts {
for _, p := range entry.RawParts {
switch pt := p.(type) {
case message.TextContent:
case TextContent:
content.WriteString(pt.Text)
case message.ReasoningContent:
case ReasoningContent:
content.WriteString(pt.Thinking)
case message.ToolCall:
case ToolCall:
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
case message.ToolResult:
case ToolResult:
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
}
}
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "message",
Role: string(msg.Role),
Role: entry.Role,
Content: content.String(),
Model: msg.Model,
Provider: msg.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.BranchSummaryEntry:
case EntryTypeBranchSummary:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "branch_summary",
Content: e.Summary,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ModelChangeEntry:
case EntryTypeModelChange:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "model_change",
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
Model: e.Provider + "/" + e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ExtensionDataEntry:
case EntryTypeExtensionData:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "extension_data",
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
default:
return nil
+13
View File
@@ -1,6 +1,7 @@
package kit
import (
"fmt"
"os"
"github.com/mark3labs/kit/internal/extensions"
@@ -136,3 +137,15 @@ func (m *Kit) ClearSkillCache() {
defer m.skillCache.mu.Unlock()
m.skillCache.skills = nil
}
// ReloadSkills re-discovers skills from disk, replacing the current set.
// This is called by file watchers when skill files change.
func (m *Kit) ReloadSkills() error {
newSkills, err := loadSkills(m.opts)
if err != nil {
return fmt.Errorf("reloading skills: %w", err)
}
m.skills = newSkills
m.ClearSkillCache()
return nil
}
+73 -2
View File
@@ -85,10 +85,15 @@ host, err := kit.New(ctx, &kit.Options{
SessionPath: "/path/to/session.jsonl", // open specific session file
Continue: true, // resume most recent session for SessionDir
NoSession: true, // ephemeral in-memory session, no disk persistence
SessionManager: myCustomSession, // custom SessionManager implementation (advanced)
// Tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Skills
Skills: []string{"/path/to/skill.md"}, // explicit skill files (empty = auto-discover)
@@ -431,6 +436,72 @@ kit.DeleteSession("/path/to/session.jsonl")
tm, _ := kit.OpenTreeSession("/path/to/session.jsonl") // open for direct access
```
### Custom Session Manager (Advanced)
You can provide a custom session manager to store conversation history in your own backend (database, cloud storage, etc.) instead of the default JSONL files.
```go
// Implement the SessionManager interface
type MyDatabaseSessionManager struct {
db *sql.DB
// ... other fields
}
func (s *MyDatabaseSessionManager) AppendMessage(msg kit.LLMMessage) (string, error) {
// Store message in your database
}
func (s *MyDatabaseSessionManager) GetMessages() []kit.LLMMessage {
// Retrieve messages from your database
}
// ... implement all other SessionManager methods
// Use with Kit
host, _ := kit.New(ctx, &kit.Options{
SessionManager: myCustomSession, // Your custom implementation
Model: "anthropic/claude-sonnet-latest",
})
```
**SessionManager Interface:**
```go
type SessionManager interface {
AppendMessage(msg kit.LLMMessage) (entryID string, err error)
GetMessages() []kit.LLMMessage
BuildContext() (messages []kit.LLMMessage, provider string, modelID string)
Branch(entryID string) error
GetCurrentBranch() []kit.BranchEntry
GetChildren(parentID string) []string
GetEntry(entryID string) *kit.BranchEntry
GetSessionID() string
GetSessionName() string
SetSessionName(name string) error
GetCreatedAt() time.Time
IsPersisted() bool
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
GetLastCompaction() *kit.CompactionEntry
AppendExtensionData(extType, data string) (string, error)
GetExtensionData(extType string) []kit.ExtensionDataEntry
AppendModelChange(provider, modelID string) (string, error)
GetContextEntryIDs() []string
Close() error
}
```
**Use Cases:**
- **PocketBase integration**: Store sessions as PocketBase records
- **Cloud storage**: Persist sessions to S3, GCS, or Azure Blob
- **Multi-user apps**: Store sessions per user in a database
- **Custom retention**: Implement your own session cleanup policies
**Note:** When using a custom SessionManager, the following Options are ignored:
- `SessionPath` - your manager handles its own storage
- `Continue` - your manager handles session selection
- `NoSession` - use an in-memory implementation instead
---
## Model Management
+8 -2
View File
@@ -29,8 +29,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true,
// Tools
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true,
@@ -58,6 +62,8 @@ host, err := kit.New(ctx, &kit.Options{
| `NoSession` | `bool` | `false` | Ephemeral mode (no persistence) |
| `Tools` | `[]Tool` | — | Replace the entire default tool set |
| `ExtraTools` | `[]Tool` | — | Additional tools alongside core/MCP/extension tools |
| `DisableCoreTools` | `bool` | `false` | Use no core tools (0 tools, for chat-only) |
| `SkipConfig` | `bool` | `false` | Skip .kit.yml file loading |
| `AutoCompact` | `bool` | `false` | Auto-compact when near context limit |
| `CompactionOptions` | `*CompactionOptions` | — | Configuration for auto-compaction |
| `Skills` | `[]string` | — | Explicit skill files/dirs to load |
+6
View File
@@ -101,6 +101,12 @@ The `/share` command uploads your session JSONL to GitHub Gist (via the `gh` CLI
/share
```
The shared session includes:
- The **system prompt** that was active during the conversation
- The **model** used (e.g., `anthropic/claude-sonnet-4-5`)
The viewer displays this information in a collapsible "System Prompt" section at the top of the session, with the model shown as a badge in the header.
The viewer is available at `https://go-kit.dev/session/#GIST_ID` and supports all message types including text, reasoning blocks, tool calls, images, and model changes.
You can also load any JSONL session via URL parameter: `https://go-kit.dev/session/?url=https://example.com/session.jsonl`
+145 -3
View File
@@ -901,6 +901,93 @@ a:hover { text-decoration: underline; }
color: var(--text-muted);
}
/* ============================================================
System Prompt Display
============================================================ */
.system-prompt-container {
margin: 16px 0;
border: 1px solid var(--border);
border-radius: var(--radius);
background: var(--surface);
overflow: hidden;
}
.system-prompt-header {
display: flex;
align-items: center;
gap: 10px;
padding: 12px 16px;
cursor: pointer;
user-select: none;
transition: background var(--transition);
}
.system-prompt-header:hover {
background: var(--surface-raised);
}
.system-prompt-icon {
width: 16px;
height: 16px;
color: var(--accent);
flex-shrink: 0;
}
.system-prompt-label {
font-size: 13px;
font-weight: 600;
color: var(--text-secondary);
flex: 1;
}
.system-prompt-chevron {
width: 16px;
height: 16px;
color: var(--text-faint);
transition: transform var(--transition);
flex-shrink: 0;
}
.system-prompt-chevron.expanded {
transform: rotate(180deg);
}
.system-prompt-content {
max-height: 0;
overflow: hidden;
transition: max-height var(--transition);
border-top: 1px solid transparent;
}
.system-prompt-content.expanded {
max-height: 600px;
overflow-y: auto;
border-top-color: var(--border);
}
.system-prompt-text {
margin: 0;
padding: 16px;
font-family: var(--font-mono);
font-size: 12px;
line-height: 1.6;
color: var(--text-secondary);
white-space: pre-wrap;
word-break: break-word;
background: var(--surface-raised);
}
.system-prompt-model {
font-size: 11px;
font-weight: 500;
color: var(--text-muted);
background: var(--surface-raised);
padding: 3px 8px;
border-radius: var(--radius-sm);
border: 1px solid var(--border);
margin-right: 8px;
}
/* ============================================================
Stats Bar
============================================================ */
@@ -1262,9 +1349,10 @@ a:hover { text-decoration: underline; }
// Tree Building — extract active path from root to latest leaf
// ============================================================
function buildActivePath(entries) {
if (entries.length === 0) return { header: null, path: [] };
if (entries.length === 0) return { header: null, path: [], systemPrompt: null };
let header = null;
let systemPrompt = null;
const nodeMap = new Map(); // id -> entry
const childrenMap = new Map(); // parentId -> [entry, ...]
@@ -1273,6 +1361,10 @@ a:hover { text-decoration: underline; }
header = entry;
continue;
}
if (entry.type === 'system_prompt') {
systemPrompt = entry;
continue;
}
if (entry.id) {
nodeMap.set(entry.id, entry);
}
@@ -1299,14 +1391,14 @@ a:hover { text-decoration: underline; }
return path;
}
return { header, path: findActivePath() };
return { header, path: findActivePath(), systemPrompt };
}
// ============================================================
// Rendering — Main entry point
// ============================================================
function renderSession(entries) {
const { header, path } = buildActivePath(entries);
const { header, path, systemPrompt } = buildActivePath(entries);
$conversation.innerHTML = '';
// Update header
@@ -1329,6 +1421,11 @@ a:hover { text-decoration: underline; }
$headerSessionName.textContent = 'Session';
}
// Render system prompt if present (collapsible)
if (systemPrompt && systemPrompt.content) {
renderSystemPrompt(systemPrompt);
}
// Track the current model for assistant messages
let currentModel = '';
let currentProvider = '';
@@ -1933,6 +2030,51 @@ a:hover { text-decoration: underline; }
$conversation.appendChild(el);
}
// ============================================================
// System Prompt Display (collapsible)
// ============================================================
function renderSystemPrompt(entry) {
const el = document.createElement('div');
el.className = 'system-prompt-container fade-in';
const promptId = 'sys-prompt-' + Math.random().toString(36).substr(2, 9);
// Build model badge if model info is available
let modelBadge = '';
if (entry.model || entry.provider) {
const modelText = entry.provider && entry.model
? `${entry.provider}/${entry.model}`
: (entry.model || entry.provider);
modelBadge = `<span class="system-prompt-model">${escapeHtml(modelText)}</span>`;
}
el.innerHTML = `
<div class="system-prompt-header" onclick="toggleSystemPrompt('${promptId}')">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="system-prompt-icon">
<path d="M8 0a8 8 0 1 1 0 16A8 8 0 0 1 8 0ZM4.5 7.5a.5.5 0 0 0 0 1h5.793l-2.147 2.146a.5.5 0 0 0 .707.707l3-3a.5.5 0 0 0 0-.707l-3-3a.5.5 0 1 0-.707.707L10.293 7.5H4.5Z"/>
</svg>
<span class="system-prompt-label">System Prompt</span>
${modelBadge}
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" fill="currentColor" class="system-prompt-chevron" id="${promptId}-chevron">
<path d="M4.427 9.427a.25.25 0 0 0 0 .353l3 3a.25.25 0 0 0 .353 0l3-3a.25.25 0 0 0-.353-.353L8 11.646V4.75a.75.75 0 0 0-1.5 0v6.896L4.78 9.427a.25.25 0 0 0-.353 0Z"/>
</svg>
</div>
<div id="${promptId}" class="system-prompt-content">
<pre class="system-prompt-text">${escapeHtml(entry.content)}</pre>
</div>
`;
$conversation.appendChild(el);
}
window.toggleSystemPrompt = function(id) {
const content = document.getElementById(id);
const chevron = document.getElementById(id + '-chevron');
if (!content) return;
const isExpanded = content.classList.contains('expanded');
content.classList.toggle('expanded');
if (chevron) chevron.classList.toggle('expanded');
};
// ============================================================
// Interactive Handlers (global scope)
// ============================================================