Compare commits

..

14 Commits

Author SHA1 Message Date
Ed Zynda 3ea0db69ea fix(ui): wrap user messages to terminal width
- Add width parameter to UserBlock and apply lipgloss.Wrap() before
  passing content to herald Tip alert
- Subtract 4 from width to account for alert bar prefix and margin
- Pass renderer width from RenderUserMessage to UserBlock
- Mirrors the assistant message wrapping added in e33564c
2026-04-08 15:15:27 +03:00
Ed Zynda 4304a5e899 feat(ui): change steer keybind to Ctrl+X s leader key chord
- Replace single Ctrl+S with Ctrl+X leader prefix followed by "s"
- Add leaderKeyActive flag to AppModel for two-key chord state
- Ctrl+X sets the leader flag; next keypress completes or cancels chord
- Update hint text in input component (adjust width thresholds)
- Update /help command output to reflect new keybind
2026-04-08 15:04:48 +03:00
Ed Zynda 4019c1e4f7 fix(ui): remove character limits from all textarea inputs
- Main message input: 5000 -> unlimited
- Prompt dialog input: 1000 -> unlimited
- Tool approval input: 1000 -> unlimited

Setting CharLimit to 0 disables the limit in Bubble Tea's textarea.
2026-04-08 14:23:34 +03:00
Ed Zynda 30ad7c1d0b feat(sdk): persist session messages incrementally per agent step
- Add StepMessagesHandler callback to agent's GenerateWithLoopAndStreaming
  so callers can persist messages as each step completes
- Wire onStepMessages in Kit.generate() to call session.AppendMessage
  for each step's messages immediately on completion
- Track PersistedMessageCount on GenerateWithLoopResult so runTurn
  skips already-persisted messages in post-generation cleanup
- Tool calls are always persisted as assistant+tool pairs (never orphaned)
- Document concurrency and incremental persistence requirements on
  the SessionManager interface for custom implementations
2026-04-08 14:15:05 +03:00
Ed Zynda e33564c569 fix(ui): wrap assistant messages to terminal width
- ToMarkdown() received a width param but never used it
- Apply lipgloss.Wrap() after herald-md render to break long lines
- Preserves ANSI styles/colors through the wrapping pass
- Fixes overflow for all markdown paths: assistant messages, tool
  bodies, and overlay text
2026-04-08 13:34:33 +03:00
Ed Zynda 5ff28445fd fix(ui): truncate queued and steering message blocks to prevent overflow
- Limit each queued/steering block to 3 visible content lines with ellipsis
- Account for soft-wrapping when counting visual lines
- Truncation is visual only; full text is preserved for scrollback
- Add truncateMessageForBlock helper with wrap-aware line counting
- Add 7 unit tests covering short, exact, overflow, wrapping, and mixed cases
2026-04-08 13:24:26 +03:00
Ed Zynda 13d177e5d0 fix(extensions): use structured logging that respects log levels
Switch from standard log.Printf to charmbracelet/log for extension loading
messages. This ensures DEBUG output only appears when explicitly enabled.

- Remove unconditional WARN log for failed extension loads
- Convert DEBUG loaded extension message to structured log.Debug call
2026-04-08 00:39:21 +03:00
Ed Zynda 3ffc995f27 feat(sdk): add NewTool/NewParallelTool for dependency-free custom tools
- Add ToolOutput struct, TextResult/ErrorResult helpers, and
  ToolCallIDFromContext so SDK consumers can create custom tools
  without importing charm.land/fantasy
- Add NewTool (sequential) and NewParallelTool (concurrent) generic
  constructors with automatic JSON schema generation from struct tags
- Remove dead UpdateUsageFromResponse method and fantasy import from
  internal/ui/cli.go
- Update SDK skill, README, and www/ docs with custom tool examples
  and corrected hook signatures
2026-04-07 22:05:42 +03:00
Ed Zynda b2bd016135 fix(tui): redirect log output to file to prevent TUI corruption
- Add tea.LogToFile in runInteractiveModeBubbleTea to send stdlib log
  output to /tmp/kit/kit.log instead of stderr
- Replace charmbracelet/log with stdlib log in extensions loader,
  runner, watcher, prompts loader, and pkg/kit so all log calls go
  through the redirected stdlib logger
- Leave charmbracelet/log in CLI-only commands (install, acp) and
  acpserver where stderr logging is correct
2026-04-07 21:20:04 +03:00
Ed Zynda 812dedaea2 feat(pkg/kit): add SessionManager interface for custom session backends
Add SessionManager interface to allow pluggable session storage backends.
This enables users to implement custom session managers for databases,
cloud storage, or other persistence mechanisms instead of the default
JSONL file-based TreeManager.

Changes:
- Add SessionManager interface with methods for message storage,
  tree navigation, compaction, and extension data
- Add treeManagerAdapter to wrap existing TreeManager for backward compatibility
- Update Kit struct to use SessionManager interface instead of concrete type
- Add SessionManager option to Options struct
- Update all session-related methods to use interface
- Add documentation for custom SessionManager usage

The default behavior is preserved - when no SessionManager is provided,
Kit automatically uses the TreeManager via the adapter.
2026-04-07 17:41:46 +03:00
Ed Zynda f65b6737f2 feat(sdk): add SkipConfig and DisableCoreTools options
Add two new Options fields for programmatic SDK usage:

- SkipConfig: Skip .kit.yml file loading while still using viper defaults
  and environment variables. Useful for fully programmatic configuration.

- DisableCoreTools: Allow creating agents with 0 tools (chat-only mode) or
  with only custom tools. When true and Tools is empty, no tools are loaded.
  When combined with custom Tools, only those tools are loaded.

Updates documentation in README, pkg/kit/README, skills/kit-sdk/SKILL,
and www/pages/sdk/options.
2026-04-07 17:10:58 +03:00
Ed Zynda 5d45aa196b fix(watcher): remove debug logging that corrupts TUI
Remove charmbracelet/log debug statements from the file watcher that
were writing directly to stderr, corrupting the Bubble Tea terminal UI.

- Remove log.Debug calls for directory operations and file changes
- Remove log.Warn for watcher errors (silently ignore instead)
- Remove the charmbracelet/log import entirely
2026-04-07 16:31:29 +03:00
Ed Zynda debb39f56c fix(ui): show MCP tools in /tools and status bar after async loading
Background MCP tool loading (added in 7e54710) caused tools to not appear
in the UI because tool names and counts were captured at startup before
loading completed. This adds:

- MCPToolsReadyEvent and MCPServerLoadedEvent for progress notifications
- Dynamic GetToolNames/GetMCPToolCount callbacks for live updates
- Per-server status messages as each MCP server finishes loading
- Refresh handlers to update /tools output and status bar when ready
2026-04-07 16:29:09 +03:00
Ed Zynda 7ce6f4fd9e fix(watcher): dynamically watch new subdirectories for skill/prompt reload
- Detect new subdirectory creation in the fsnotify event loop and add
  it to the watcher so files created inside trigger reload events
- Handle cp -r case by checking if new directories already contain
  matching files and scheduling an immediate debounced reload
- Add dirContainsMatchingFiles helper method
- Add tests for both new-subdirectory and copy-with-existing-files cases
2026-04-07 15:01:18 +03:00
36 changed files with 1901 additions and 384 deletions
+28 -1
View File
@@ -531,7 +531,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -540,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
})
```
### Custom Tools
Create custom tools with automatic schema generation — no external dependencies needed:
```go
type SearchInput struct {
Query string `json:"query" description:"Search query"`
}
searchTool := kit.NewTool("search", "Search the codebase",
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
return kit.TextResult("Found: ..."), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
})
```
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
### With Callbacks
```go
+52 -5
View File
@@ -731,6 +731,11 @@ func runNormalMode(ctx context.Context) error {
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
}
// appInstancePtr is used to break the circular dependency between
// kit.New (which needs the OnMCPServerLoaded callback) and app.New
// (which is needed by the callback to send events to the TUI).
var appInstancePtr *app.App
kitOpts := &kit.Options{
Quiet: quietFlag,
Debug: debugMode,
@@ -739,6 +744,14 @@ func runNormalMode(ctx context.Context) error {
SessionPath: sessionPath,
AutoCompact: autoCompactFlag,
MCPAuthHandler: authHandler,
// This callback is called when each MCP server finishes loading.
// We use a closure that captures appInstancePtr which is set after
// app.New() is called below.
OnMCPServerLoaded: func(serverName string, toolCount int, err error) {
if appInstancePtr != nil {
appInstancePtr.NotifyMCPServerLoaded(serverName, toolCount, err)
}
},
CLI: &kit.CLIOptions{
MCPConfig: mcpConfig,
ShowSpinner: true,
@@ -809,6 +822,7 @@ func runNormalMode(ctx context.Context) error {
}
appInstance := app.New(appOpts, messages)
appInstancePtr = appInstance // Wire up the MCP server loaded callback.
defer appInstance.Close()
// Wire OAuth handler to route messages through the TUI once it's running.
@@ -1680,6 +1694,25 @@ func runNormalMode(ctx context.Context) error {
return extensionCommandsForUI(kitInstance)
}
// Build dynamic tool name and MCP tool count providers. These are called
// by the TUI when MCPToolsReadyEvent fires to refresh the /tools list
// and startup info bar after background MCP tool loading completes.
getToolNames := func() []string {
return kitInstance.GetToolNames()
}
getMCPToolCount := func() int {
return kitInstance.GetMCPToolCount()
}
// Start a goroutine that waits for background MCP tool loading to
// complete and notifies the TUI so it can refresh tool names and counts.
if len(mcpConfig.MCPServers) > 0 {
go func() {
_ = kitInstance.WaitForMCPTools()
appInstance.NotifyMCPToolsReady()
}()
}
// Build model switching callbacks for the /model command.
setModelForUI := func(modelString string) error {
err := kitInstance.SetModel(context.Background(), modelString)
@@ -1807,7 +1840,7 @@ func runNormalMode(ctx context.Context) error {
// Check if running in non-interactive mode
if positionalPrompt != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
}
// Quiet mode is not allowed in interactive mode
@@ -1815,7 +1848,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet requires a prompt")
}
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
}
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
@@ -1828,7 +1861,7 @@ func runNormalMode(ctx context.Context) error {
//
// When --no-exit is set, after the prompt completes the interactive BubbleTea
// TUI is started so the user can continue the conversation.
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
// Expand @file references in the prompt before sending to the agent.
if cwd, err := os.Getwd(); err == nil {
prompt = ui.ProcessFileAttachments(prompt, cwd)
@@ -1871,7 +1904,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
// If --no-exit was requested, hand off to the interactive TUI.
if noExit {
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
}
return nil
@@ -1969,7 +2002,19 @@ func writeJSONError(err error) {
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
// background goroutines (watchers, extension handlers, SDK internals)
// will visually corrupt the terminal.
logDir := filepath.Join(os.TempDir(), "kit")
_ = os.MkdirAll(logDir, 0o700)
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
if logErr == nil {
defer func() { _ = logFile.Close() }()
}
// Determine terminal size; fall back gracefully.
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil || termWidth == 0 {
@@ -1988,6 +2033,8 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
Height: termHeight,
ServerNames: serverNames,
ToolNames: toolNames,
GetToolNames: getToolNames,
GetMCPToolCount: getMCPToolCount,
MCPToolCount: mcpToolCount,
ExtensionToolCount: extensionToolCount,
UsageTracker: usageTracker,
+55 -5
View File
@@ -35,6 +35,11 @@ type AgentConfig struct {
// CodingTools or tools with a custom WorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper is an optional function that wraps the combined tool list
// before it is passed to the LLM agent. Used by the extensions system
// to intercept tool calls/results.
@@ -43,6 +48,11 @@ type AgentConfig struct {
// ExtraTools are additional tools to include alongside core and MCP tools.
// Used by extensions to register custom tools.
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). The callback receives the server
// name, tool count, and any error. Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// ToolCallHandler is a function type for handling tool calls as they happen.
@@ -79,6 +89,14 @@ type ReasoningCompleteHandler func()
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
type ToolOutputHandler = core.ToolOutputCallback
// StepMessagesHandler is a function type for persisting messages after each
// complete step in a multi-step agent turn. The handler receives the messages
// produced by the step (typically an assistant message with tool calls followed
// by a tool-role message with results, or a final assistant message with text).
// This enables incremental session persistence so that progress is saved as
// it happens rather than only at the end of the turn.
type StepMessagesHandler func(stepMessages []fantasy.Message)
// StepUsageHandler is a function type for handling token usage after each
// complete step in a multi-step agent turn. This enables real-time cost
// tracking during long-running tool-calling conversations.
@@ -131,6 +149,11 @@ type GenerateWithLoopResult struct {
TotalUsage fantasy.Usage
// StopReason is the LLM provider's finish reason for the final response.
StopReason string
// PersistedMessageCount is the number of new messages (beyond the original
// input) that were already persisted incrementally via OnStepMessages during
// generation. The caller should skip these when doing post-generation
// persistence to avoid duplicates.
PersistedMessageCount int
}
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
@@ -148,8 +171,16 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Register core tools (direct AgentTool implementations, no MCP overhead).
// Use caller-provided tools if set, otherwise default to all core tools.
coreTools := agentConfig.CoreTools
if len(coreTools) == 0 {
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
var coreTools []fantasy.AgentTool
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
// Explicitly zero tools - chat-only mode
coreTools = nil
} else if len(agentConfig.CoreTools) > 0 {
// Custom tools provided - use them
coreTools = agentConfig.CoreTools
} else {
// Default: load all core tools
coreTools = core.AllTools()
}
@@ -208,6 +239,10 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
// Set per-server loaded callback if provided.
if agentConfig.OnMCPServerLoaded != nil {
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
}
a.toolManager = toolManager
a.mcpReady = make(chan struct{})
@@ -355,7 +390,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
}
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
@@ -368,6 +403,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onReasoningDelta ReasoningDeltaHandler,
onReasoningComplete ReasoningCompleteHandler,
onToolOutput ToolOutputHandler,
onStepMessages StepMessagesHandler,
onStepUsage StepUsageHandler,
) (*GenerateWithLoopResult, error) {
@@ -407,6 +443,10 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// when it returns an error, but the OnStepFinish callback fires
// for every step that completed before the error occurred.
var completedStepMessages []fantasy.Message
// persistedCount tracks how many new messages (beyond the original
// input) were persisted incrementally via onStepMessages, so the
// caller can skip them during post-generation persistence.
var persistedCount int
// Use the streaming agent
streamCall := fantasy.AgentStreamCall{
@@ -492,6 +532,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// persisted even if a later step is cancelled.
completedStepMessages = append(completedStepMessages, step.Messages...)
// Persist step messages incrementally so progress is saved
// as it happens rather than only at the end of the turn.
if onStepMessages != nil && len(step.Messages) > 0 {
onStepMessages(step.Messages)
persistedCount += len(step.Messages)
}
if ctx.Err() != nil {
return ctx.Err()
}
@@ -570,7 +617,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
partialMessages = append(partialMessages, messages...)
partialMessages = append(partialMessages, completedStepMessages...)
return &GenerateWithLoopResult{
ConversationMessages: partialMessages,
ConversationMessages: partialMessages,
PersistedMessageCount: persistedCount,
}, err
}
return nil, err
@@ -585,7 +633,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onResponse(result.Response.Content.Text())
}
return convertAgentResult(result, messages), nil
r := convertAgentResult(result, messages)
r.PersistedMessageCount = persistedCount
return r, nil
}
// Non-streaming path with no callbacks — use the simpler Generate call.
+19 -10
View File
@@ -41,10 +41,17 @@ type AgentCreationOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used.
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper wraps the combined tool list before agent creation.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include (e.g. from extensions).
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// CreateAgent creates an agent with optional spinner for Ollama models.
@@ -52,16 +59,18 @@ type AgentCreationOptions struct {
// Returns the created agent or an error if creation fails.
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
agentConfig := &AgentConfig{
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
var agent *Agent
+28
View File
@@ -1010,6 +1010,34 @@ func (a *App) NotifyContentReload() {
}
}
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
// tool names and MCP tool count from provider callbacks. Called when background
// MCP tool loading completes. In non-interactive mode this is a no-op.
func (a *App) NotifyMCPToolsReady() {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPToolsReadyEvent{})
}
}
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
// display a system message when a single MCP server finishes loading. Called
// per server as background MCP tool loading progresses.
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPServerLoadedEvent{
ServerName: serverName,
ToolCount: toolCount,
Error: err,
})
}
}
// SendEvent sends a tea.Msg to the registered program. Safe to call from
// any goroutine. No-op when no program is registered.
//
+14
View File
@@ -172,6 +172,20 @@ type WidgetUpdateEvent struct{}
// its autocomplete entries and internal state from the provider callbacks.
type ContentReloadEvent struct{}
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
// The TUI refreshes its tool names and MCP tool count from provider callbacks
// so that /tools and the startup info bar reflect the loaded MCP tools.
type MCPToolsReadyEvent struct{}
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
// (successfully or with error). The TUI displays a system message so users
// see real-time progress as each server initializes.
type MCPServerLoadedEvent struct {
ServerName string
ToolCount int
Error error // nil on success
}
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
// pre-fill the input editor with text. The TUI handles this by setting the
// textarea content and moving the cursor to the end.
+1 -6
View File
@@ -34,15 +34,10 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
for _, p := range paths {
ext, err := loadSingleExtension(p)
if err != nil {
log.Warn("skipping extension", "path", p, "err", err)
continue
}
loaded = append(loaded, *ext)
log.Debug("loaded extension", "path", p,
"handlers", countHandlers(ext),
"tools", len(ext.Tools),
"commands", len(ext.Commands),
"tool_renderers", len(ext.ToolRenderers))
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
}
return loaded, nil
}
+3 -8
View File
@@ -2,12 +2,12 @@ package extensions
import (
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
"github.com/charmbracelet/log"
"github.com/spf13/viper"
)
@@ -370,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
log.Warn("extension handler error",
"path", ext.Path,
"event", event.Type(),
"err", err)
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
continue
}
if result == nil {
@@ -707,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
safeInvoke := func(h func(string)) {
defer func() {
if rec := recover(); rec != nil {
log.Warn("custom event handler panicked",
"event", name,
"err", fmt.Sprintf("%v", rec))
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
}
}()
h(data)
+6 -6
View File
@@ -3,13 +3,13 @@ package extensions
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/fsnotify/fsnotify"
)
@@ -39,7 +39,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
for _, dir := range dirs {
// Watch the directory itself.
if err := fsw.Add(dir); err != nil {
log.Debug("watcher: skipping directory", "dir", dir, "err", err)
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
continue
}
@@ -52,7 +52,7 @@ func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Debug("watcher: skipping subdirectory", "dir", subdir, "err", err)
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
}
}
}
@@ -101,7 +101,7 @@ func (w *Watcher) Start(ctx context.Context) {
continue
}
log.Debug("watcher: file changed", "file", event.Name, "op", event.Op)
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
// Debounce: reset timer on each event.
if timer != nil {
@@ -113,14 +113,14 @@ func (w *Watcher) Start(ctx context.Context) {
case <-timerC:
timerC = nil
timer = nil
log.Debug("watcher: reloading extensions")
log.Printf("DEBUG watcher: reloading extensions")
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Warn("watcher: error", "err", err)
log.Printf("WARN watcher: error: %v", err)
}
}
}
+22 -13
View File
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ExtraTools are additional tools added alongside core, MCP, and extension
// tools. They do not replace the defaults — they extend them.
ExtraTools []fantasy.AgentTool
@@ -61,6 +65,9 @@ type AgentSetupOptions struct {
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
AuthHandler tools.MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// AgentSetupResult bundles the created agent and any debug logger so the caller
@@ -183,19 +190,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
}
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: maxSteps,
StreamingEnabled: streamingEnabled,
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: maxSteps,
StreamingEnabled: streamingEnabled,
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)
+2 -6
View File
@@ -2,11 +2,10 @@ package prompts
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
)
// LoadOptions configures how templates are discovered and loaded.
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
DroppedPath: tpl.FilePath,
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
})
log.Debug("template collision",
"name", tpl.Name,
"dropped", tpl.FilePath,
"kept", existing.FilePath)
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
} else {
tpl.Source = source
seen[tpl.Name] = tpl
+23 -7
View File
@@ -29,6 +29,10 @@ type MCPToolManager struct {
config *config.Config
debug bool
debugLogger DebugLogger
// onServerLoaded, if non-nil, is called when each server finishes loading.
// Called with server name, tool count, and error (nil on success).
onServerLoaded func(serverName string, toolCount int, err error)
}
// toolMapping stores the mapping between prefixed tool names and their original details
@@ -76,6 +80,13 @@ func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
}
}
// SetOnServerLoaded sets the callback that's invoked when each MCP server finishes
// loading. The callback receives the server name, tool count, and any error.
// Call this before LoadTools to receive per-server notifications.
func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount int, err error)) {
m.onServerLoaded = cb
}
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
@@ -108,8 +119,12 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
wg.Add(1)
go func(name string, sc config.MCPServerConfig) {
defer wg.Done()
err := m.loadServerTools(ctx, name, sc)
count, err := m.loadServerTools(ctx, name, sc)
results <- serverResult{name: name, err: err}
// Notify callback if set (for real-time UI updates).
if m.onServerLoaded != nil {
m.onServerLoaded(name, count, err)
}
}(serverName, serverConfig)
}
@@ -137,14 +152,15 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
// loadServerTools loads tools from a single MCP server.
// Thread-safe: may be called concurrently for different servers.
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
// Returns the number of tools loaded from this server, or -1 on error.
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (int, error) {
// Add debug logging
m.debugLogConnectionInfo(serverName, serverConfig)
// Get connection from pool
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
if err != nil {
return fmt.Errorf("failed to get connection from pool: %v", err)
return -1, fmt.Errorf("failed to get connection from pool: %v", err)
}
// Get tools from this server
@@ -152,7 +168,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
if err != nil {
// Handle connection error
m.connectionPool.HandleConnectionError(serverName, err)
return fmt.Errorf("failed to list tools: %v", err)
return -1, fmt.Errorf("failed to list tools: %v", err)
}
// Create name set for allowed tools
@@ -185,7 +201,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Convert MCP InputSchema to map[string]any for fantasy ToolInfo
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
// Fix for JSON Schema draft-07 vs draft-04 compatibility
@@ -194,7 +210,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Parse into map[string]any for fantasy's parameters format
var schemaMap map[string]any
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Extract properties and required from the schema
@@ -249,7 +265,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
m.tools = append(m.tools, localTools...)
m.mu.Unlock()
return nil
return len(localTools), nil
}
// GetTools returns all loaded tools as fantasy AgentTools from all configured MCP servers.
-28
View File
@@ -5,7 +5,6 @@ import (
"os"
"time"
"charm.land/fantasy"
"charm.land/lipgloss/v2"
"golang.org/x/term"
@@ -173,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
}
// UpdateUsageFromResponse records token usage using metadata from the fantasy
// response. Only actual API-reported tokens are used for cost tracking.
// If the provider doesn't report token counts, no usage is recorded.
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
if c.usageTracker == nil {
return
}
usage := response.Usage
inputTokens := int(usage.InputTokens)
outputTokens := int(usage.OutputTokens)
// Only use actual API-reported tokens for cost tracking.
// We intentionally do NOT estimate tokens - estimation is inaccurate
// and should never be used for cost calculations.
if inputTokens > 0 {
cacheReadTokens := int(usage.CacheReadTokens)
cacheWriteTokens := int(usage.CacheCreationTokens)
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
// Per-response usage is a single API call, so it represents the
// actual context window fill level.
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
}
// If inputTokens is 0, the provider didn't report usage - we skip recording
// rather than estimating, to ensure cost accuracy.
}
// DisplayUsageAfterResponse renders and displays token usage information immediately
// following an AI response. This provides real-time feedback about the cost and
// token consumption of each interaction.
+7 -7
View File
@@ -69,7 +69,7 @@ type InputComponent struct {
hideHint bool
// agentBusy indicates the agent is currently working. When true, the
// hint text shows steering shortcut (Ctrl+S) instead of submit.
// hint text shows steering shortcut (Ctrl+X s) instead of submit.
agentBusy bool
// pendingImages holds clipboard images attached to the next submission.
@@ -109,7 +109,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
ta.Placeholder = "Type your message..."
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 5000
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(3) // Default to 3 lines like huh
ta.Focus()
@@ -514,12 +514,12 @@ func (s *InputComponent) View() tea.View {
availableHintWidth := s.width - 3
if s.agentBusy {
// When the agent is working, show steering shortcut.
if availableHintWidth >= 55 {
hint = "enter queue • ctrl+s steer • esc esc cancel"
} else if availableHintWidth >= 35 {
hint = "↵ queue • ^S steer • esc×2 cancel"
if availableHintWidth >= 60 {
hint = "enter queue • ctrl+x s steer • esc esc cancel"
} else if availableHintWidth >= 40 {
hint = "↵ queue • ^X s steer • esc×2 cancel"
} else {
hint = "^S steer"
hint = "^X s steer"
}
} else if availableHintWidth >= 67 {
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
+1 -1
View File
@@ -152,7 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
// RenderUserMessage renders a user's input message using herald Tip alert
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
rendered := render.UserBlock(content, r.ty, style.GetTheme())
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
return UIMessage{
Type: UserMessage,
+206 -73
View File
@@ -281,6 +281,16 @@ type AppModelOptions struct {
// ToolNames holds available tool names for the /tools command.
ToolNames []string
// GetToolNames, if non-nil, returns the current tool names. Called on
// MCPToolsReadyEvent to refresh the tool list after background MCP tool
// loading completes. May be nil if dynamic tool refresh is not needed.
GetToolNames func() []string
// GetMCPToolCount, if non-nil, returns the current MCP tool count.
// Called on MCPToolsReadyEvent to refresh the startup info bar.
// May be nil if dynamic tool refresh is not needed.
GetMCPToolCount func() int
// UsageTracker provides token usage statistics for /usage and /reset-usage.
// May be nil if usage tracking is unavailable for the current model.
UsageTracker *UsageTracker
@@ -467,7 +477,7 @@ type AppModel struct {
queuedMessages []string
// steeringMessages stores the text of prompts that were sent as steer
// messages (injected mid-turn via Ctrl+S). Rendered with a "STEERING"
// messages (injected mid-turn via Ctrl+X s). Rendered with a "STEERING"
// badge above the input. Cleared when the steer is consumed.
steeringMessages []string
@@ -488,6 +498,11 @@ type AppModel struct {
// A second ESC within 2 seconds will cancel the current step.
canceling bool
// leaderKeyActive tracks whether the Ctrl+X leader key prefix has been
// pressed. The next keypress is interpreted as a chord suffix (e.g. "s"
// for steer). Cleared on any subsequent keypress.
leaderKeyActive bool
// providerName is the LLM provider for the startup message.
providerName string
@@ -495,8 +510,12 @@ type AppModel struct {
loadingMessage string
// serverNames, toolNames are used by /servers and /tools commands.
serverNames []string
toolNames []string
serverNames []string
toolNames []string
getToolNames func() []string // dynamic tool name provider (for MCP refresh)
// getMCPToolCount returns the current MCP tool count dynamically.
getMCPToolCount func() int
// usageTracker provides token usage stats for /usage and /reset-usage.
// May be nil when usage tracking is unavailable.
@@ -722,18 +741,20 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
rdr := mr
m := &AppModel{
state: stateInput,
appCtrl: appCtrl,
renderer: rdr,
modelName: opts.ModelName,
providerName: opts.ProviderName,
loadingMessage: opts.LoadingMessage,
serverNames: opts.ServerNames,
toolNames: opts.ToolNames,
usageTracker: opts.UsageTracker,
cwd: opts.Cwd,
width: width,
height: height,
state: stateInput,
appCtrl: appCtrl,
renderer: rdr,
modelName: opts.ModelName,
providerName: opts.ProviderName,
loadingMessage: opts.LoadingMessage,
serverNames: opts.ServerNames,
toolNames: opts.ToolNames,
getToolNames: opts.GetToolNames,
getMCPToolCount: opts.GetMCPToolCount,
usageTracker: opts.UsageTracker,
cwd: opts.Cwd,
width: width,
height: height,
}
// Store extension commands for dispatch.
@@ -1252,6 +1273,71 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(cmds...)
}
// ── Leader key chord handling (Ctrl+X prefix) ──────────────
// If the leader key was previously pressed, the current key
// completes the chord. We consume it regardless of match so
// the prefix doesn't leak to child components.
if m.leaderKeyActive {
m.leaderKeyActive = false
switch msg.String() {
case "s":
// Ctrl+X s → Steer: inject the current input as a steering
// message into the running agent turn.
if m.state == stateWorking && m.appCtrl != nil {
var text string
if ic, ok := m.input.(*InputComponent); ok {
text = strings.TrimSpace(ic.textarea.Value())
}
if text != "" {
// Clear the input, collect pending images, and push to history.
var images []uicore.ImageAttachment
if ic, ok := m.input.(*InputComponent); ok {
ic.pushHistory(text)
ic.textarea.SetValue("")
images = ic.ClearPendingImages()
}
// Preprocess @file references.
processedText := text
if m.cwd != "" {
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
}
// Convert image attachments to kit.LLMFilePart for the app layer.
var fileParts []kit.LLMFilePart
for _, img := range images {
fileParts = append(fileParts, kit.LLMFilePart{
Data: img.Data,
MediaType: img.MediaType,
})
}
// Build display text (include image count if any).
displayText := text
if len(images) > 0 {
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
}
// Inject the steer message.
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
if sLen > 0 {
m.steeringMessages = append(m.steeringMessages, displayText)
m.layoutDirty = true
} else {
// Started immediately (agent was idle).
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
if m.state != stateWorking {
m.state = stateWorking
}
}
}
}
}
// Chord consumed — don't propagate to children.
return m, tea.Batch(cmds...)
}
switch msg.String() {
case "esc":
if m.state == stateWorking {
@@ -1270,61 +1356,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
// In other states pass ESC through to children below.
case "ctrl+s":
// Steer: inject the current input as a steering message into the
// running agent turn. Only active during stateWorking — in input
// state, Ctrl+S is passed through to children (no-op by default).
if m.state == stateWorking && m.appCtrl != nil {
var text string
if ic, ok := m.input.(*InputComponent); ok {
text = strings.TrimSpace(ic.textarea.Value())
}
if text != "" {
// Clear the input, collect pending images, and push to history.
var images []uicore.ImageAttachment
if ic, ok := m.input.(*InputComponent); ok {
ic.pushHistory(text)
ic.textarea.SetValue("")
images = ic.ClearPendingImages()
}
// Preprocess @file references.
processedText := text
if m.cwd != "" {
processedText = fileutil.ProcessFileAttachments(text, m.cwd)
}
// Convert image attachments to kit.LLMFilePart for the app layer.
var fileParts []kit.LLMFilePart
for _, img := range images {
fileParts = append(fileParts, kit.LLMFilePart{
Data: img.Data,
MediaType: img.MediaType,
})
}
// Build display text (include image count if any).
displayText := text
if len(images) > 0 {
displayText = fmt.Sprintf("%s\n[%d image(s) attached]", text, len(images))
}
// Inject the steer message.
sLen := m.appCtrl.SteerWithFiles(processedText, fileParts)
if sLen > 0 {
m.steeringMessages = append(m.steeringMessages, displayText)
m.layoutDirty = true
} else {
// Started immediately (agent was idle).
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
m.flushStreamAndPendingUserMessages()
if m.state != stateWorking {
m.state = stateWorking
}
}
}
return m, tea.Batch(cmds...)
}
case "ctrl+x":
// Activate leader key prefix — the next keypress completes the chord.
m.leaderKeyActive = true
return m, tea.Batch(cmds...)
}
// Route key events to the focused child. Check for editor
@@ -1843,6 +1878,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.refreshSkillItems()
m.printSystemMessage("Prompts and skills reloaded.")
case app.MCPToolsReadyEvent:
// Background MCP tool loading completed — refresh tool names and count.
m.refreshToolNames()
m.refreshMCPToolCount()
case app.MCPServerLoadedEvent:
// A single MCP server finished loading — display a system message.
if msg.Error != nil {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' failed to load: %v", msg.ServerName, msg.Error))
} else if msg.ToolCount > 0 {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' loaded with %d tools", msg.ServerName, msg.ToolCount))
} else {
m.printSystemMessage(fmt.Sprintf("MCP server '%s' loaded (no tools)", msg.ServerName))
}
case app.EditorTextSetEvent:
// Extension wants to pre-fill the input editor with text.
if ic, ok := m.input.(*InputComponent); ok {
@@ -2431,22 +2481,34 @@ func (m *AppModel) renderHeaderFooter(getter func() *WidgetData) string {
return renderContentBlock(data.Text, m.width, opts...)
}
// maxQueuedMessageLines is the maximum number of visible content lines
// rendered for each queued or steering message block. Messages exceeding
// this limit are truncated with an ellipsis to prevent large pastes from
// overflowing the screen and squeezing the stream region to zero.
const maxQueuedMessageLines = 3
// renderQueuedMessages renders queued and steering prompts as styled content
// blocks with badges, anchored between the separator and input. Steering
// messages use a distinct "STEERING" badge to differentiate from queued ones.
// Long messages are visually truncated to maxQueuedMessageLines.
func (m *AppModel) renderQueuedMessages() string {
if len(m.queuedMessages) == 0 && len(m.steeringMessages) == 0 {
return ""
}
theme := style.GetTheme()
// Available content width inside the block: container minus border (1)
// minus left padding (2). Used to estimate line wrapping for truncation.
contentWidth := max(m.width-3, 10)
var blocks []string
// Render steering messages first (higher priority).
if len(m.steeringMessages) > 0 {
badge := style.CreateBadge("STEERING", theme.Warning)
for _, msg := range m.steeringMessages {
content := msg + "\n" + badge
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
content := display + "\n" + badge
rendered := renderContentBlock(
content,
m.width,
@@ -2461,7 +2523,8 @@ func (m *AppModel) renderQueuedMessages() string {
if len(m.queuedMessages) > 0 {
badge := style.CreateBadge("QUEUED", theme.Accent)
for _, msg := range m.queuedMessages {
content := msg + "\n" + badge
display := truncateMessageForBlock(msg, maxQueuedMessageLines, contentWidth)
content := display + "\n" + badge
rendered := renderContentBlock(
content,
m.width,
@@ -2475,6 +2538,58 @@ func (m *AppModel) renderQueuedMessages() string {
return strings.Join(blocks, "\n")
}
// truncateMessageForBlock truncates a message to at most maxLines visible
// lines, accounting for soft-wrapping at the given width. If the message is
// truncated, the last visible line is replaced with an ellipsis ("…").
func truncateMessageForBlock(msg string, maxLines, width int) string {
if width <= 0 {
width = 1
}
lines := strings.Split(msg, "\n")
// Count visible lines (each hard line may wrap into multiple visual lines).
var kept []string
visibleCount := 0
truncated := false
for _, line := range lines {
// Calculate how many visual lines this hard line occupies.
lineWidth := lipgloss.Width(line)
wrapped := 1
if lineWidth > width {
wrapped = (lineWidth + width - 1) / width // ceil division
}
if visibleCount+wrapped > maxLines {
// This line would exceed the limit. Keep a partial if we
// still have room for at least one more visual line.
remaining := maxLines - visibleCount
if remaining > 0 {
// Truncate the line to fit the remaining visual lines.
runes := []rune(line)
maxRunes := remaining * width
if maxRunes < len(runes) {
kept = append(kept, string(runes[:maxRunes]))
} else {
kept = append(kept, line)
}
}
truncated = true
break
}
kept = append(kept, line)
visibleCount += wrapped
}
if !truncated {
return msg
}
return strings.Join(kept, "\n") + "…"
}
// --------------------------------------------------------------------------
// Print helpers — add content to ScrollList
// --------------------------------------------------------------------------
@@ -2773,6 +2888,24 @@ func (m *AppModel) refreshSkillItems() {
m.skillItems = m.getSkillItems()
}
// refreshToolNames reloads tool names from the provider callback.
// Called on MCPToolsReadyEvent when background MCP tool loading completes.
func (m *AppModel) refreshToolNames() {
if m.getToolNames == nil {
return
}
m.toolNames = m.getToolNames()
}
// refreshMCPToolCount reloads the MCP tool count from the provider callback.
// Called on MCPToolsReadyEvent when background MCP tool loading completes.
func (m *AppModel) refreshMCPToolCount() {
if m.getMCPToolCount == nil {
return
}
m.mcpToolCount = m.getMCPToolCount()
}
// printHelpMessage renders the help text listing all available slash commands.
func (m *AppModel) printHelpMessage() {
help := "## Available Commands\n\n" +
@@ -2827,7 +2960,7 @@ func (m *AppModel) printHelpMessage() {
"**Keys:**\n" +
"- `Ctrl+C`: Exit at any time\n" +
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
"- `Ctrl+S`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
"- `Ctrl+X s`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
"You can also just type your message to chat with the AI assistant."
m.printSystemMessage(help)
+105
View File
@@ -2,6 +2,7 @@ package ui
import (
"errors"
"strings"
"testing"
tea "charm.land/bubbletea/v2"
@@ -892,3 +893,107 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
t.Fatalf("expected Run('queued prompt') called, got %v", ctrl.runCalls)
}
}
// --------------------------------------------------------------------------
// truncateMessageForBlock
// --------------------------------------------------------------------------
// TestTruncateMessageForBlock_shortMessage verifies that short messages are
// returned unchanged.
func TestTruncateMessageForBlock_shortMessage(t *testing.T) {
msg := "hello world"
got := truncateMessageForBlock(msg, 3, 80)
if got != msg {
t.Fatalf("expected unchanged message, got %q", got)
}
}
// TestTruncateMessageForBlock_exactLines verifies that a message with exactly
// maxLines hard lines is returned unchanged.
func TestTruncateMessageForBlock_exactLines(t *testing.T) {
msg := "line1\nline2\nline3"
got := truncateMessageForBlock(msg, 3, 80)
if got != msg {
t.Fatalf("expected unchanged message, got %q", got)
}
}
// TestTruncateMessageForBlock_tooManyLines verifies that messages exceeding
// maxLines are truncated with an ellipsis.
func TestTruncateMessageForBlock_tooManyLines(t *testing.T) {
msg := "line1\nline2\nline3\nline4\nline5"
got := truncateMessageForBlock(msg, 3, 80)
want := "line1\nline2\nline3…"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
// TestTruncateMessageForBlock_longWrappingLine verifies that a single long
// line that would wrap beyond maxLines is truncated.
func TestTruncateMessageForBlock_longWrappingLine(t *testing.T) {
// 100 chars at width 20 = 5 visual lines, exceeds maxLines=3
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
got := truncateMessageForBlock(msg, 3, 20)
// Should be truncated to 3*20=60 runes + "…"
if len([]rune(got)) != 61 { // 60 runes + "…"
t.Fatalf("expected 61 runes (60 + ellipsis), got %d runes: %q", len([]rune(got)), got)
}
if got[len(got)-3:] != "…" { // "…" is 3 bytes in UTF-8
t.Fatal("expected trailing ellipsis")
}
}
// TestTruncateMessageForBlock_emptyMessage verifies that empty messages are
// returned unchanged.
func TestTruncateMessageForBlock_emptyMessage(t *testing.T) {
got := truncateMessageForBlock("", 3, 80)
if got != "" {
t.Fatalf("expected empty string, got %q", got)
}
}
// TestTruncateMessageForBlock_mixedWrapAndHardLines verifies truncation when
// some hard lines wrap and the total exceeds maxLines.
func TestTruncateMessageForBlock_mixedWrapAndHardLines(t *testing.T) {
// First line: 40 chars at width 20 = 2 visual lines
// Second line: "short" = 1 visual line (total: 3, exactly at limit)
// Third line: would exceed
msg := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort\nextra"
got := truncateMessageForBlock(msg, 3, 20)
want := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nshort…"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
// TestRenderQueuedMessages_truncatesLongMessages verifies that the rendered
// queued message view truncates long messages instead of showing them in full.
func TestRenderQueuedMessages_truncatesLongMessages(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.width = 80
// Queue a very long message (20 lines).
var b strings.Builder
for i := range 20 {
if i > 0 {
b.WriteByte('\n')
}
b.WriteString("This is a long line of text for testing purposes")
}
m.queuedMessages = []string{b.String()}
rendered := m.renderQueuedMessages()
if rendered == "" {
t.Fatal("expected non-empty rendered output")
}
// The full message would be ~20+ lines. With truncation to 3 content
// lines + badge + padding, it should be much shorter.
lines := len(strings.Split(rendered, "\n"))
// 3 content lines + 1 badge + 2 padding + border overhead ≈ ~7 lines max
if lines > 10 {
t.Fatalf("expected truncated output to be ≤10 lines, got %d lines", lines)
}
}
+1 -1
View File
@@ -78,7 +78,7 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
ta.Placeholder = placeholder
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 1000
ta.CharLimit = 0
ta.SetWidth(width - 12) // account for border + padding
ta.SetHeight(1)
ta.Focus()
+9 -1
View File
@@ -14,11 +14,19 @@ import (
)
// UserBlock renders a user message with herald Tip styling.
func UserBlock(content string, ty *herald.Typography, theme style.Theme) string {
// The width parameter controls line wrapping so long messages don't overflow.
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
// Wrap content before passing to herald Alert so long lines break
// inside the alert box. Subtract 4 to account for the alert bar
// prefix ("│ ") and a small margin.
if width > 4 {
content = lipgloss.Wrap(content, width-4, "")
}
rendered := ty.Tip(content)
return styleMarginBottom(theme, rendered)
}
+5 -3
View File
@@ -85,11 +85,13 @@ func GetMarkdownTypography() *herald.Typography {
return ty
}
// ToMarkdown renders markdown content using herald-md.
// The width parameter is currently unused as herald handles wrapping
// based on terminal width internally.
// ToMarkdown renders markdown content using herald-md and wraps the result
// to the given width so that long lines do not overflow the terminal.
func ToMarkdown(content string, width int) string {
ty := GetMarkdownTypography()
rendered := heraldmd.Render(ty, []byte(content))
if width > 0 {
rendered = lipgloss.Wrap(rendered, width, "")
}
return rendered
}
+1 -1
View File
@@ -23,7 +23,7 @@ func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInp
ta := textarea.New()
ta.Placeholder = ""
ta.ShowLineNumbers = false
ta.CharLimit = 1000
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(4) // Default to 3 lines like huh
ta.Focus()
+38 -9
View File
@@ -13,7 +13,6 @@ import (
"sync"
"time"
"github.com/charmbracelet/log"
"github.com/fsnotify/fsnotify"
)
@@ -63,7 +62,6 @@ func New(opts Options) (*ContentWatcher, error) {
for _, dir := range opts.Dirs {
if err := fsw.Add(dir); err != nil {
log.Debug("watcher: skipping directory", "label", opts.Label, "dir", dir, "err", err)
continue
}
@@ -75,9 +73,7 @@ func New(opts Options) (*ContentWatcher, error) {
for _, entry := range entries {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Debug("watcher: skipping subdirectory", "label", opts.Label, "dir", subdir, "err", err)
}
_ = fsw.Add(subdir)
}
}
}
@@ -122,6 +118,26 @@ func (w *ContentWatcher) Start(ctx context.Context) {
return
}
// When a new subdirectory is created, start watching it so
// that files added inside (e.g. new-skill/SKILL.md) trigger
// reload events. Also schedule a reload in case the directory
// was created with matching files already inside.
if event.Op&fsnotify.Create != 0 {
if info, err := os.Stat(event.Name); err == nil && info.IsDir() {
if addErr := w.watcher.Add(event.Name); addErr == nil {
// Check if the new directory already contains matching files.
if w.dirContainsMatchingFiles(event.Name) {
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
}
}
continue
}
}
// Only care about files matching our extensions.
if !w.matchesExtension(event.Name) {
continue
@@ -132,8 +148,6 @@ func (w *ContentWatcher) Start(ctx context.Context) {
continue
}
log.Debug("watcher: file changed", "label", w.label, "file", event.Name, "op", event.Op)
// Debounce: reset timer on each event.
if timer != nil {
timer.Stop()
@@ -144,14 +158,13 @@ func (w *ContentWatcher) Start(ctx context.Context) {
case <-timerC:
timerC = nil
timer = nil
log.Debug("watcher: reloading", "label", w.label)
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Warn("watcher: error", "label", w.label, "err", err)
_ = err
}
}
}
@@ -182,6 +195,22 @@ func (w *ContentWatcher) matchesExtension(name string) bool {
return false
}
// dirContainsMatchingFiles returns true if the directory contains at least
// one file matching the watched extensions. Used to detect cases where a
// directory is created with files already inside (e.g. cp -r).
func (w *ContentWatcher) dirContainsMatchingFiles(dir string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if !entry.IsDir() && w.matchesExtension(entry.Name()) {
return true
}
}
return false
}
// CollectDirs returns the directories to watch for a given set of standard
// directories and extra paths. Directories are deduplicated by absolute path
// and verified to exist. For explicit file paths, the parent directory is
+82
View File
@@ -190,6 +190,88 @@ func TestContentWatcher_WatchesSubdirectories(t *testing.T) {
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectory(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Wait for watcher to be ready.
time.Sleep(100 * time.Millisecond)
// Create a NEW subdirectory after the watcher started (the bug scenario).
subdir := filepath.Join(dir, "new-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
// Give fsnotify time to pick up the new directory.
time.Sleep(100 * time.Millisecond)
// Write a matching file inside the new subdirectory.
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# New Skill"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for file in new subdirectory, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectoryWithExistingFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Create a subdirectory with a matching file already inside (simulates cp -r).
subdir := filepath.Join(dir, "copied-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Copied"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(300 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for copied subdirectory with files, got %d", got)
}
_ = w.Close()
}
func TestCollectDirs_Deduplicates(t *testing.T) {
dir := t.TempDir()
+24 -2
View File
@@ -68,8 +68,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -172,6 +176,24 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
- `GetSessionID()` - Get session UUID
- `Close()` - Clean up resources
### Options
Key `Options` fields for SDK usage:
| Field | Description |
|-------|-------------|
| `Model` | Override model (e.g., "anthropic/claude-sonnet-4-5-20250929") |
| `SystemPrompt` | Override system prompt |
| `ConfigFile` | Load specific config file (empty = search defaults) |
| `SkipConfig` | Skip `.kit.yml` loading (defaults + env vars still apply) |
| `Tools` | Replace core tools with custom set |
| `ExtraTools` | Add tools alongside defaults |
| `DisableCoreTools` | Use no core tools (0 tools, for chat-only) |
| `NoSession` | Ephemeral mode (no session persistence) |
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `Debug` | Enable debug logging |
## Environment Variables
All CLI environment variables work with the SDK:
+231
View File
@@ -0,0 +1,231 @@
package kit
import (
"strings"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/session"
)
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
inner *session.TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
// AppendMessage implements SessionManager.
func (a *treeManagerAdapter) AppendMessage(msg LLMMessage) (string, error) {
// LLMMessage is just an alias for fantasy.Message, so no conversion needed
return a.inner.AppendLLMMessage(msg)
}
// GetMessages implements SessionManager.
func (a *treeManagerAdapter) GetMessages() []LLMMessage {
// LLMMessage is just an alias for fantasy.Message
return a.inner.GetLLMMessages()
}
// BuildContext implements SessionManager.
func (a *treeManagerAdapter) BuildContext() ([]LLMMessage, string, string) {
msgs, provider, modelID := a.inner.BuildContext()
return msgs, provider, modelID
}
// Branch implements SessionManager.
func (a *treeManagerAdapter) Branch(entryID string) error {
return a.inner.Branch(entryID)
}
// GetCurrentBranch implements SessionManager.
func (a *treeManagerAdapter) GetCurrentBranch() []BranchEntry {
branch := a.inner.GetBranch("")
var result []BranchEntry
for _, entry := range branch {
be := a.convertEntry(entry)
if be != nil {
result = append(result, *be)
}
}
return result
}
// GetChildren implements SessionManager.
func (a *treeManagerAdapter) GetChildren(parentID string) []string {
return a.inner.GetChildren(parentID)
}
// GetEntry implements SessionManager.
func (a *treeManagerAdapter) GetEntry(entryID string) *BranchEntry {
entry := a.inner.GetEntry(entryID)
if entry == nil {
return nil
}
return a.convertEntry(entry)
}
// GetSessionID implements SessionManager.
func (a *treeManagerAdapter) GetSessionID() string {
return a.inner.GetSessionID()
}
// GetSessionName implements SessionManager.
func (a *treeManagerAdapter) GetSessionName() string {
return a.inner.GetSessionName()
}
// SetSessionName implements SessionManager.
func (a *treeManagerAdapter) SetSessionName(name string) error {
_, err := a.inner.AppendSessionInfo(name)
return err
}
// GetCreatedAt implements SessionManager.
func (a *treeManagerAdapter) GetCreatedAt() time.Time {
return a.inner.GetHeader().Timestamp
}
// IsPersisted implements SessionManager.
func (a *treeManagerAdapter) IsPersisted() bool {
return a.inner.IsPersisted()
}
// AppendCompaction implements SessionManager.
func (a *treeManagerAdapter) AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
return a.inner.AppendCompaction(summary, firstKeptEntryID,
tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
}
// GetLastCompaction implements SessionManager.
func (a *treeManagerAdapter) GetLastCompaction() *CompactionEntry {
c := a.inner.GetLastCompaction()
if c == nil {
return nil
}
return &CompactionEntry{
ID: c.ID,
Summary: c.Summary,
FirstKeptEntryID: c.FirstKeptEntryID,
TokensBefore: c.TokensBefore,
TokensAfter: c.TokensAfter,
MessagesRemoved: c.MessagesRemoved,
ReadFiles: c.ReadFiles,
ModifiedFiles: c.ModifiedFiles,
Timestamp: c.Timestamp,
}
}
// AppendExtensionData implements SessionManager.
func (a *treeManagerAdapter) AppendExtensionData(extType, data string) (string, error) {
return a.inner.AppendExtensionData(extType, data)
}
// GetExtensionData implements SessionManager.
func (a *treeManagerAdapter) GetExtensionData(extType string) []ExtensionDataEntry {
entries := a.inner.GetExtensionData(extType)
var result []ExtensionDataEntry
for _, e := range entries {
result = append(result, ExtensionDataEntry{
ID: e.ID,
ExtType: e.ExtType,
Data: e.Data,
Timestamp: e.Timestamp,
})
}
return result
}
// AppendModelChange implements SessionManager.
func (a *treeManagerAdapter) AppendModelChange(provider, modelID string) (string, error) {
return a.inner.AppendModelChange(provider, modelID)
}
// GetContextEntryIDs implements SessionManager.
func (a *treeManagerAdapter) GetContextEntryIDs() []string {
return a.inner.GetContextEntryIDs()
}
// Close implements SessionManager.
func (a *treeManagerAdapter) Close() error {
return a.inner.Close()
}
// Helper: Convert internal entry types to BranchEntry
func (a *treeManagerAdapter) convertEntry(entry any) *BranchEntry {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// Build content text from parts
var content strings.Builder
for _, part := range msg.Parts {
if textPart, ok := part.(TextContent); ok {
content.WriteString(textPart.Text)
}
}
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeMessage,
Role: string(msg.Role),
Content: content.String(),
Model: e.Model,
Provider: e.Provider,
Timestamp: e.Timestamp,
RawParts: msg.Parts,
}
case *session.BranchSummaryEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeBranchSummary,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ModelChangeEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeModelChange,
Content: "Model changed to " + e.Provider + "/" + e.ModelID,
Model: e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp,
}
case *session.CompactionEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeCompaction,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ExtensionDataEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeExtensionData,
Content: "Extension data: " + e.ExtType,
Timestamp: e.Timestamp,
}
default:
return nil
}
}
// convertKitMessagesToFantasy converts kit LLM messages to fantasy messages.
// Since LLMMessage is an alias for fantasy.Message, this is a no-op.
func convertKitMessagesToFantasy(msgs []LLMMessage) []fantasy.Message {
// LLMMessage is just an alias for fantasy.Message, so we can type convert
return msgs
}
+12 -12
View File
@@ -21,9 +21,9 @@ type ContextStats struct {
const defaultReserveTokens = 16384
// EstimateContextTokens returns the estimated token count of the current
// conversation based on tree session messages.
// conversation based on session messages.
func (m *Kit) EstimateContextTokens() int {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
return compaction.EstimateMessageTokens(messages)
}
@@ -42,8 +42,8 @@ func (m *Kit) ShouldCompact() bool {
reserveTokens = m.compactionOpts.ReserveTokens
}
messages := m.treeSession.GetLLMMessages()
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
messages := m.session.GetMessages()
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
}
// GetContextStats returns current context usage statistics including
@@ -55,7 +55,7 @@ func (m *Kit) ShouldCompact() bool {
// because it includes system prompts, tool definitions, and other overhead
// that the heuristic cannot account for.
func (m *Kit) GetContextStats() ContextStats {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
// Prefer the real API-reported input token count when available.
m.lastInputTokensMu.RLock()
@@ -114,7 +114,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
}
}
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
if len(messages) < 2 {
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
}
@@ -145,7 +145,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Carry forward file tracking from previous compaction.
var prev *compaction.PreviousCompaction
if lastCompaction := m.treeSession.GetLastCompaction(); lastCompaction != nil {
if lastCompaction := m.session.GetLastCompaction(); lastCompaction != nil {
prev = &compaction.PreviousCompaction{
ReadFiles: lastCompaction.ReadFiles,
ModifiedFiles: lastCompaction.ModifiedFiles,
@@ -171,7 +171,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Non-destructive: append a CompactionEntry to the session tree instead
// of clearing and rewriting messages.
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if result.CutPoint >= 0 && result.CutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[result.CutPoint]
@@ -188,9 +188,9 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// custom summary. It still determines the cut point and persists a
// CompactionEntry.
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
originalTokens := compaction.EstimateMessageTokens(messages)
originalTokens := compaction.EstimateMessageTokens(convertKitMessagesToFantasy(messages))
cutPoint := compaction.FindCutPoint(messages, opts.KeepRecentTokens)
cutPoint := compaction.FindCutPoint(convertKitMessagesToFantasy(messages), opts.KeepRecentTokens)
if cutPoint == 0 {
cutPoint = len(messages) - 1
if cutPoint < 1 {
@@ -198,7 +198,7 @@ func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts
}
}
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if cutPoint >= 0 && cutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[cutPoint]
@@ -234,7 +234,7 @@ func (m *Kit) persistAndEmitCompaction(
originalTokens, compactedTokens, messagesRemoved int,
readFiles, modifiedFiles []string,
) error {
if _, err := m.treeSession.AppendCompaction(
if _, err := m.session.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
+34 -11
View File
@@ -227,28 +227,51 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
// Session data
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
return iterBranchMessages(e.kit.treeSession, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if e.kit.session == nil {
return nil
}
// Try to use the legacy iterBranchMessages for backward compatibility
// with the default TreeManager adapter
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
}
})
}
// For custom SessionManagers, use the public interface
branch := e.kit.session.GetCurrentBranch()
var result []extensions.SessionMessage
for _, entry := range branch {
if entry.Type == EntryTypeMessage {
result = append(result, extensions.SessionMessage{
ID: entry.ID,
Role: entry.Role,
Content: entry.Content,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
})
}
return result
}
func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return "", fmt.Errorf("no session available")
}
return e.kit.treeSession.AppendExtensionData(extType, data)
return e.kit.session.AppendExtensionData(extType, data)
}
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return nil
}
entries := e.kit.treeSession.GetExtensionData(extType)
entries := e.kit.session.GetExtensionData(extType)
result := make([]extensions.ExtensionEntry, 0, len(entries))
for _, e := range entries {
result = append(result, extensions.ExtensionEntry{
+135 -63
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
@@ -11,7 +12,6 @@ import (
"time"
"charm.land/fantasy"
charmlog "github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
@@ -39,7 +39,7 @@ type ContextFile struct {
// agents, sessions, and model configurations.
type Kit struct {
agent *agent.Agent
treeSession *session.TreeManager
session SessionManager
modelString string
events *eventBus
autoCompact bool
@@ -172,27 +172,39 @@ type StructuredMessage struct {
// flattens all content to a single text string, this preserves tool calls,
// tool results, reasoning blocks, and finish markers as distinct typed parts.
func (m *Kit) GetStructuredMessages() []StructuredMessage {
return iterBranchMessages(m.treeSession, func(me *session.MessageEntry, msg message.Message) StructuredMessage {
return StructuredMessage{
ID: me.ID,
ParentID: me.ParentID,
Role: msg.Role,
Parts: msg.Parts,
Model: msg.Model,
Provider: msg.Provider,
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if m.session == nil {
return nil
}
branch := m.session.GetCurrentBranch()
var results []StructuredMessage
for _, entry := range branch {
if entry.Type != EntryTypeMessage {
continue
}
})
results = append(results, StructuredMessage{
ID: entry.ID,
ParentID: entry.ParentID,
Role: MessageRole(entry.Role),
Parts: entry.RawParts,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
return results
}
// iterBranchMessages iterates over the current branch's MessageEntry items,
// converting each to a message.Message and calling fn to build the result.
// Returns nil if there is no tree session. Skips entries that are not
// Returns nil if there is no session. Skips entries that are not
// MessageEntry or that fail conversion.
// Deprecated: Use SessionManager.GetCurrentBranch() directly.
func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.MessageEntry, message.Message) T) []T {
if tm == nil {
return nil
}
branch := tm.GetBranch("")
var results []T
for _, entry := range branch {
@@ -445,6 +457,17 @@ type Options struct {
Tools []Tool // Custom tool set. If empty, AllTools() is used.
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
// SkipConfig, when true, skips loading .kit.yml configuration files.
// Viper defaults (setSDKDefaults) and environment variables (KIT_*)
// are still applied. Use this for fully programmatic configuration.
SkipConfig bool
// DisableCoreTools, when true, prevents loading any core tools.
// Use with Tools or ExtraTools to provide only custom tools.
// If both DisableCoreTools is true and Tools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// Session configuration
SessionDir string // Base directory for session discovery (default: cwd)
SessionPath string // Open a specific session file by path
@@ -474,8 +497,20 @@ type Options struct {
// display a URL in a custom UI, redirect to a web app, etc.).
MCPAuthHandler MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading during Kit initialization. The callback receives the server name,
// tool count, and any error. Called from a background goroutine; safe to
// call app.NotifyMCPServerLoaded() from within the callback to display
// real-time progress in the TUI.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// CLI is optional CLI-specific configuration. SDK users leave this nil.
CLI *CLIOptions
// SessionManager allows custom session storage backends.
// If nil (default), Kit uses the built-in file-based TreeManager.
// When provided, SessionPath, Continue, and NoSession options are ignored.
SessionManager SessionManager
}
// CLIOptions holds fields only relevant to the CLI binary. SDK users should
@@ -570,7 +605,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Initialize config (loads config files and env vars).
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
// Check if model is already set, which indicates config was loaded.
if viper.GetString("model") == "" {
// SkipConfig bypasses .kit.yml file loading (viper defaults and env vars still apply).
if !opts.SkipConfig && viper.GetString("model") == "" {
if err := InitConfig(opts.ConfigFile, false); err != nil {
return fmt.Errorf("failed to initialize config: %w", err)
}
@@ -679,16 +715,18 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// Pass the pre-built ProviderConfig and scalar viper snapshots so
// SetupAgent doesn't need to re-read viper (which would require the lock).
setupOpts := kitsetup.AgentSetupOptions{
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
DisableCoreTools: opts.DisableCoreTools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
// Set up OAuth handler for remote MCP servers.
@@ -701,7 +739,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
defaultHandler, authErr := NewDefaultMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote servers.
charmlog.Warn("Failed to create OAuth handler; remote MCP servers requiring auth will fail", "error", authErr)
log.Printf("WARN Failed to create OAuth handler; remote MCP servers requiring auth will fail: %v", authErr)
} else {
setupOpts.AuthHandler = defaultHandler
}
@@ -719,16 +757,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, err
}
// Initialize tree session.
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
// Initialize session manager.
var sessionManager SessionManager
if opts.SessionManager != nil {
// Use custom session manager provided by user.
sessionManager = opts.SessionManager
} else {
// DEFAULT: Use built-in TreeManager (existing behavior).
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
}
// Wrap TreeManager in adapter to satisfy SessionManager interface.
sessionManager = NewTreeManagerAdapter(treeSession)
}
k := &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
@@ -1262,14 +1309,22 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
IsStderr: isStderr,
})
},
// Persist step messages incrementally so that progress survives
// crashes and long-running turns don't lose work. Each step's
// messages are persisted as a unit: for tool-calling steps this is
// the assistant message (with tool_use parts) + tool-role message
// (with tool_result parts) as a pair; for the final step it's the
// assistant text/reasoning message alone.
func(stepMessages []fantasy.Message) {
for _, msg := range stepMessages {
_, _ = m.session.AppendMessage(msg)
}
},
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
// Emit step usage event for real-time cost tracking
if viper.GetBool("debug") {
charmlog.Debug("Kit.generate emitting StepUsageEvent",
"input", inputTokens,
"output", outputTokens,
"cacheRead", cacheReadTokens,
"cacheCreate", cacheCreationTokens,
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
)
}
m.events.emit(StepUsageEvent{
@@ -1287,11 +1342,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
// 2. Persist pre-generation messages to the tree session.
// 3. Build context from the tree (walks leaf-to-root for current branch).
// 4. Emit turn/message start events.
// 5. Run generation.
// 6. Emit turn/message end events.
// 7. Persist post-generation messages (tool calls, results, assistant).
// 5. Run generation (messages are persisted incrementally per step).
// 6. Persist any remaining messages not covered by incremental persistence.
// 7. Emit turn/message end events.
// 8. Run AfterTurn hooks.
//
// During generation, each completed step's messages are persisted immediately
// via the onStepMessages callback. Tool calls are always persisted as
// call/response pairs (assistant + tool messages together). Reasoning and
// text-only assistant messages are persisted as soon as their step completes.
// This ensures long-running turns don't lose progress on crash or cancellation.
//
// promptLabel is the human-readable label emitted in TurnStartEvent.Prompt.
// prompt is the raw user text passed to BeforeTurn hooks.
func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, preMessages []fantasy.Message) (*TurnResult, error) {
@@ -1336,9 +1397,9 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
}
// Persist pre-generation messages to tree session.
// Persist pre-generation messages to session.
for _, msg := range preMessages {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
// Auto-compact if enabled and conversation is near the context limit.
@@ -1346,8 +1407,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
_, _ = m.compactInternal(ctx, m.compactionOpts, "", true) // best-effort, automatic
}
// Build context from the tree so only the current branch is sent.
messages := m.treeSession.GetLLMMessages()
// Build context from the session so only the current branch is sent.
messages, _, _ := m.session.BuildContext()
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
@@ -1361,16 +1422,18 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
result, err := m.generate(ctx, messages)
if err != nil {
// Persist any messages from completed steps (tool call/result
// pairs) so partial progress is not lost. The agent layer only
// includes fully-paired tool_use + tool_result messages in
// completedStepMessages, so there are no orphaned entries that
// would break subsequent API requests. The user message and any
// completed work remain in the session; only the in-progress
// (pending) message or tool call is discarded.
if result != nil && len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
// Persist any messages from completed steps that were NOT already
// persisted incrementally by the onStepMessages callback. The agent
// layer only includes fully-paired tool_use + tool_result messages
// in completedStepMessages, so there are no orphaned entries that
// would break subsequent API requests.
if result != nil {
newMessages := result.ConversationMessages[sentCount:]
alreadyPersisted := result.PersistedMessageCount
if alreadyPersisted < len(newMessages) {
for _, msg := range newMessages[alreadyPersisted:] {
_, _ = m.session.AppendMessage(msg)
}
}
}
m.events.emit(TurnEndEvent{Error: err})
@@ -1381,12 +1444,17 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
responseText := result.FinalResponse.Content.Text()
// Persist new messages (tool calls, tool results, assistant response)
// BEFORE emitting events so that extension handlers calling
// GetContextStats() see up-to-date token counts.
// Persist any new messages that were NOT already persisted incrementally
// by the onStepMessages callback during generation. This handles the
// non-streaming path (where onStepMessages is not called) and any edge
// cases where the final response messages weren't covered by step callbacks.
if len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
newMessages := result.ConversationMessages[sentCount:]
alreadyPersisted := result.PersistedMessageCount
if alreadyPersisted < len(newMessages) {
for _, msg := range newMessages[alreadyPersisted:] {
_, _ = m.session.AppendMessage(msg)
}
}
}
@@ -1468,7 +1536,7 @@ func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
// Returns an error if there are no previous messages in the session.
func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
// Verify there is conversation history to follow up on.
if len(m.treeSession.GetLLMMessages()) == 0 {
if len(m.session.GetMessages()) == 0 {
return "", fmt.Errorf("cannot follow up: no previous messages")
}
@@ -1624,10 +1692,12 @@ func (m *Kit) PromptResultWithMessages(ctx context.Context, messages []string) (
return m.runTurn(ctx, promptLabel, messages[len(messages)-1], preMessages)
}
// ClearSession resets the tree session's leaf pointer to the root, starting
// ClearSession resets the session's leaf pointer to the root, starting
// a fresh conversation branch.
func (m *Kit) ClearSession() {
m.treeSession.ResetLeaf()
if m.session != nil {
_ = m.session.Branch("")
}
}
// GetModelString returns the current model string identifier (e.g.,
@@ -1696,8 +1766,8 @@ func (m *Kit) Close() error {
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionShutdown) {
_, _ = m.extRunner.Emit(extensions.SessionShutdownEvent{})
}
if m.treeSession != nil {
_ = m.treeSession.Close()
if m.session != nil {
_ = m.session.Close()
}
// Release the OAuth callback port if we own the handler.
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
@@ -1705,3 +1775,5 @@ func (m *Kit) Close() error {
}
return m.agent.Close()
}
// Conversion helpers are defined in adapter.go.
+144
View File
@@ -0,0 +1,144 @@
package kit
import (
"time"
)
// SessionManager defines the contract for conversation storage backends.
// Implementations can use files (default), databases, cloud storage, etc.
//
// Implementations must be safe for concurrent use. During generation,
// AppendMessage is called incrementally from the agent's step-completion
// callback while read methods (GetMessages, GetCurrentBranch, etc.) may be
// called concurrently from the UI or extension goroutines.
type SessionManager interface {
// AppendMessage adds a message to the current branch and returns its entry ID.
// The entry ID is used for tree navigation and must be unique within the session.
//
// During generation, AppendMessage is called incrementally after each
// completed agent step rather than in a batch at the end of the turn.
// For tool-calling steps, the assistant message (containing tool_use parts)
// and the tool-role message (containing tool_result parts) are appended
// together as a pair. This ensures the session never contains an orphaned
// tool call without its result, which would break subsequent LLM requests.
AppendMessage(msg LLMMessage) (entryID string, err error)
// GetMessages returns all messages on the current branch (from root to leaf),
// including any compaction summaries at the appropriate positions.
GetMessages() []LLMMessage
// BuildContext returns the message history to send to the LLM, applying
// compaction rules and branch summaries as needed.
// Returns: messages, currentProvider, currentModelID
BuildContext() (messages []LLMMessage, provider string, modelID string)
// Branch moves the leaf pointer to the given entry ID, creating a branch point.
// Subsequent AppendMessage calls extend from this new position.
// entryID can be empty to reset to root (new conversation branch).
Branch(entryID string) error
// GetCurrentBranch returns the path from root to current leaf as entry metadata.
// Used for UI display and navigation.
GetCurrentBranch() []BranchEntry
// GetChildren returns direct child entry IDs for a given parent entry.
// Used to display branch points in the conversation tree.
GetChildren(parentID string) []string
// GetEntry returns a specific entry by ID, or nil if not found.
GetEntry(entryID string) *BranchEntry
// GetSessionID returns the unique session identifier (UUID).
GetSessionID() string
// GetSessionName returns the user-defined display name, or empty.
GetSessionName() string
// SetSessionName sets a display name for the session.
SetSessionName(name string) error
// GetCreatedAt returns when the session was created.
GetCreatedAt() time.Time
// IsPersisted returns true if this session writes to durable storage.
IsPersisted() bool
// AppendCompaction adds a compaction entry that summarizes older messages.
// firstKeptEntryID is the ID of the first message to preserve in context.
// readFiles and modifiedFiles track file changes for the compaction summary.
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
// GetLastCompaction returns the most recent compaction entry on the current
// branch, or nil if none exists.
GetLastCompaction() *CompactionEntry
// AppendExtensionData stores custom extension data in the session tree.
// Extensions use this to persist state across restarts.
AppendExtensionData(extType, data string) (string, error)
// GetExtensionData returns all extension data entries of the given type
// on the current branch. If extType is empty, returns all extension data.
GetExtensionData(extType string) []ExtensionDataEntry
// AppendModelChange records a provider/model switch in the session.
AppendModelChange(provider, modelID string) (string, error)
// GetContextEntryIDs returns the entry IDs corresponding to the messages
// returned by BuildContext, in the same order. Used by compaction to
// determine which entries to summarize.
GetContextEntryIDs() []string
// Close releases resources (database connections, file handles, etc.).
Close() error
}
// BranchEntry represents a single node in the conversation tree.
// This is a SDK-friendly struct (not the internal entry types).
type BranchEntry struct {
ID string
ParentID string
Type EntryType // "message", "branch_summary", "model_change", "compaction", "extension_data"
Role string // for messages: "user", "assistant", "system", "tool"
Content string // text content or summary
Model string // model used (for messages and model_change)
Provider string // provider used
Timestamp time.Time
Children []string // child entry IDs (for tree display)
// RawParts contains the full typed content parts for structured access.
// Only populated for message entries.
RawParts []ContentPart
}
// EntryType identifies the kind of entry in the session tree.
type EntryType string
const (
EntryTypeMessage EntryType = "message"
EntryTypeBranchSummary EntryType = "branch_summary"
EntryTypeModelChange EntryType = "model_change"
EntryTypeCompaction EntryType = "compaction"
EntryTypeExtensionData EntryType = "extension_data"
)
// CompactionEntry represents a context compaction/summarization event.
type CompactionEntry struct {
ID string
Summary string
FirstKeptEntryID string
TokensBefore int
TokensAfter int
MessagesRemoved int
ReadFiles []string
ModifiedFiles []string
Timestamp time.Time
}
// ExtensionDataEntry represents custom extension data stored in the session.
type ExtensionDataEntry struct {
ID string
ExtType string
Data string
Timestamp time.Time
}
+111 -80
View File
@@ -8,7 +8,6 @@ import (
"time"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/session"
)
@@ -47,49 +46,73 @@ func OpenTreeSession(path string) (*TreeManager, error) {
// --- Instance methods on Kit ---
// GetSessionManager returns the session manager, or nil if not configured.
func (m *Kit) GetSessionManager() SessionManager {
return m.session
}
// GetTreeSession returns the tree session manager, or nil if not configured.
// Deprecated: Use GetSessionManager instead.
func (m *Kit) GetTreeSession() *TreeManager {
return m.treeSession
// Try to unwrap the adapter if using default implementation
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner
}
return nil
}
// SetSessionManager replaces the session manager on a Kit instance.
func (m *Kit) SetSessionManager(sm SessionManager) {
m.session = sm
}
// SetTreeSession replaces the tree session on a Kit instance. This is used by
// the CLI when it handles session creation externally (e.g. --resume with a
// TUI picker) and needs to inject the result into a Kit-like workflow.
// Deprecated: Use SetSessionManager instead.
func (m *Kit) SetTreeSession(ts *TreeManager) {
m.treeSession = ts
m.session = NewTreeManagerAdapter(ts)
}
// GetSessionPath returns the file path of the active tree session, or empty
// for in-memory sessions or when no tree session is configured.
// GetSessionPath returns the file path of the active session, or empty
// for in-memory sessions or when no file-based session is configured.
func (m *Kit) GetSessionPath() string {
if m.treeSession != nil {
return m.treeSession.GetFilePath()
// Only file-based sessions have a path
// Try to get it from the underlying TreeManager if using default adapter
if m.session == nil {
return ""
}
// Check if it's the default adapter
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner.GetFilePath()
}
return ""
}
// GetSessionID returns the UUID of the active tree session, or empty when no
// tree session is configured.
// GetSessionID returns the UUID of the active session, or empty when no
// session is configured.
func (m *Kit) GetSessionID() string {
if m.treeSession != nil {
return m.treeSession.GetSessionID()
if m.session == nil {
return ""
}
return ""
return m.session.GetSessionID()
}
// Branch moves the tree session's leaf pointer to the given entry ID, creating
// Branch moves the session's leaf pointer to the given entry ID, creating
// a branch point. Subsequent Prompt() calls will extend from the new position.
func (m *Kit) Branch(entryID string) error {
return m.treeSession.Branch(entryID)
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.session.Branch(entryID)
}
// SetSessionName sets a user-defined display name for the active tree session.
// SetSessionName sets a user-defined display name for the active session.
func (m *Kit) SetSessionName(name string) error {
if m.treeSession == nil {
return fmt.Errorf("session naming requires a tree session")
if m.session == nil {
return fmt.Errorf("session naming requires a session")
}
_, err := m.treeSession.AppendSessionInfo(name)
return err
return m.session.SetSessionName(name)
}
// ---------------------------------------------------------------------------
@@ -97,27 +120,27 @@ func (m *Kit) SetSessionName(name string) error {
// ---------------------------------------------------------------------------
// GetTreeNode returns a node by ID with full metadata and children.
// Returns nil if entry not found or no tree session.
// Returns nil if entry not found or no session.
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
entry := m.treeSession.GetEntry(entryID)
entry := m.session.GetEntry(entryID)
if entry == nil {
return nil
}
return m.entryToTreeNode(entry)
return m.branchEntryToTreeNode(entry)
}
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
func (m *Kit) GetCurrentBranch() []TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var nodes []TreeNode
for _, entry := range branch {
node := m.entryToTreeNode(entry)
node := m.branchEntryToTreeNode(&entry)
if node != nil {
nodes = append(nodes, *node)
}
@@ -127,34 +150,34 @@ func (m *Kit) GetCurrentBranch() []TreeNode {
// GetChildren returns direct child IDs of an entry.
func (m *Kit) GetChildren(parentID string) []string {
if m.treeSession == nil {
if m.session == nil {
return nil
}
return m.treeSession.GetChildren(parentID)
return m.session.GetChildren(parentID)
}
// NavigateTo branches/forks the session to the specified entry ID.
// Returns an error if the session is unavailable or the entry ID is not found.
func (m *Kit) NavigateTo(entryID string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.treeSession.Branch(entryID)
return m.session.Branch(entryID)
}
// SummarizeBranch uses the LLM to summarize the conversation between two
// entry IDs. Returns the summary text, or an error if the range is invalid,
// the session is unavailable, or the LLM call fails.
func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
if m.treeSession == nil {
return "", fmt.Errorf("no tree session available")
if m.session == nil {
return "", fmt.Errorf("no session available")
}
// Get the branch and find the range
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var startIdx, endIdx = -1, -1
for i, entry := range branch {
id := m.treeSession.EntryID(entry)
id := entry.ID
if id == fromID {
startIdx = i
}
@@ -170,7 +193,7 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// Build text to summarize
var content strings.Builder
for i := startIdx; i <= endIdx; i++ {
node := m.entryToTreeNode(branch[i])
node := m.branchEntryToTreeNode(&branch[i])
if node != nil && node.Content != "" {
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
}
@@ -195,73 +218,81 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// CollapseBranch replaces a branch range with a summary entry.
// Returns an error if the session is unavailable or the operation fails.
func (m *Kit) CollapseBranch(fromID, toID, summary string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
return err
// Note: This operation is not directly supported by SessionManager interface
// as it requires AppendBranchSummary which is TreeManager-specific.
// For custom SessionManagers, this would need to be implemented differently.
// For now, we try to use the underlying TreeManager if available.
if adapter, ok := m.session.(*treeManagerAdapter); ok {
_, err := adapter.inner.AppendBranchSummary(fromID, summary)
return err
}
return fmt.Errorf("CollapseBranch not supported by custom session manager")
}
// entryToTreeNode converts a session entry to a TreeNode.
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// branchEntryToTreeNode converts a BranchEntry to a TreeNode.
func (m *Kit) branchEntryToTreeNode(entry *BranchEntry) *TreeNode {
if entry == nil {
return nil
}
switch entry.Type {
case EntryTypeMessage:
// Build content from RawParts
var content strings.Builder
for _, p := range msg.Parts {
for _, p := range entry.RawParts {
switch pt := p.(type) {
case message.TextContent:
case TextContent:
content.WriteString(pt.Text)
case message.ReasoningContent:
case ReasoningContent:
content.WriteString(pt.Thinking)
case message.ToolCall:
case ToolCall:
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
case message.ToolResult:
case ToolResult:
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
}
}
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "message",
Role: string(msg.Role),
Role: entry.Role,
Content: content.String(),
Model: msg.Model,
Provider: msg.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.BranchSummaryEntry:
case EntryTypeBranchSummary:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "branch_summary",
Content: e.Summary,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ModelChangeEntry:
case EntryTypeModelChange:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "model_change",
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
Model: e.Provider + "/" + e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ExtensionDataEntry:
case EntryTypeExtensionData:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "extension_data",
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
default:
return nil
+119
View File
@@ -1,6 +1,8 @@
package kit
import (
"context"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/core"
@@ -16,6 +18,123 @@ type ToolOption = core.ToolOption
// If empty, os.Getwd() is used at execution time.
var WithWorkDir = core.WithWorkDir
// --- Custom tool creation ---
// ToolOutput is the return value from custom tool handlers created with
// [NewTool] or [NewParallelTool]. It provides a dependency-free way to
// return results without importing the underlying LLM framework.
type ToolOutput struct {
// Content is the text content returned to the LLM.
Content string
// IsError, when true, signals to the LLM that the tool call failed.
IsError bool
// Data contains optional binary data (images, audio, etc.).
Data []byte
// MediaType is the MIME type for binary Data (e.g. "image/png").
MediaType string
// Metadata is optional opaque metadata attached to the response.
// It is not sent to the LLM but may be consumed by hooks or the UI.
Metadata any
}
// TextResult creates a successful text [ToolOutput].
func TextResult(content string) ToolOutput {
return ToolOutput{Content: content}
}
// ErrorResult creates an error [ToolOutput]. The LLM will see the content
// as a tool error, allowing it to retry or adjust its approach.
func ErrorResult(content string) ToolOutput {
return ToolOutput{Content: content, IsError: true}
}
// toolCallIDKey is the context key for the tool call ID.
type toolCallIDKey struct{}
// ToolCallIDFromContext extracts the tool call ID from the context.
// The call ID is set automatically by [NewTool] and [NewParallelTool]
// before invoking the handler. Returns an empty string if no ID is present.
func ToolCallIDFromContext(ctx context.Context) string {
s, _ := ctx.Value(toolCallIDKey{}).(string)
return s
}
// NewTool creates a custom [Tool] with automatic JSON schema generation from
// the TInput struct type. The handler receives a typed input (deserialized
// from the LLM's JSON arguments) and returns a [ToolResult].
//
// Struct tags on TInput control the generated schema:
//
// json:"name" → parameter name
// description:"..." → parameter description shown to the LLM
// enum:"a,b,c" → restrict valid values
// omitempty → marks the parameter as optional
//
// The tool call ID is injected into the context and can be retrieved with
// [ToolCallIDFromContext].
//
// Example:
//
// type WeatherInput struct {
// City string `json:"city" description:"City name"`
// }
//
// tool := kit.NewTool("get_weather", "Get weather for a city",
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
// return kit.TextResult("72°F, sunny in " + input.City), nil
// },
// )
func NewTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// NewParallelTool is like [NewTool] but marks the tool as safe for concurrent
// execution alongside other tools. Use this when the tool has no side effects
// or when concurrent calls are safe.
func NewParallelTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewParallelAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// --- Individual tool constructors ---
// NewReadTool creates a file-reading tool.
+119
View File
@@ -0,0 +1,119 @@
package kit_test
import (
"context"
"testing"
kit "github.com/mark3labs/kit/pkg/kit"
)
// TestNewTool_BasicTextResult verifies that NewTool creates a working tool
// that returns text content via ToolOutput.
func TestNewTool_BasicTextResult(t *testing.T) {
type Input struct {
Name string `json:"name"`
}
tool := kit.NewTool("greet", "Greet someone",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("hello " + input.Name), nil
},
)
info := tool.Info()
if info.Name != "greet" {
t.Errorf("Info().Name = %q, want %q", info.Name, "greet")
}
if info.Description != "Greet someone" {
t.Errorf("Info().Description = %q, want %q", info.Description, "Greet someone")
}
if info.Parallel {
t.Error("NewTool should not mark tool as parallel")
}
}
// TestNewParallelTool_MarkedParallel verifies that NewParallelTool marks the
// tool as safe for concurrent execution.
func TestNewParallelTool_MarkedParallel(t *testing.T) {
type Input struct {
Query string `json:"query"`
}
tool := kit.NewParallelTool("search", "Search for things",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("found: " + input.Query), nil
},
)
info := tool.Info()
if info.Name != "search" {
t.Errorf("Info().Name = %q, want %q", info.Name, "search")
}
if !info.Parallel {
t.Error("NewParallelTool should mark tool as parallel")
}
}
// TestTextResult verifies the TextResult convenience constructor.
func TestTextResult(t *testing.T) {
r := kit.TextResult("ok")
if r.Content != "ok" {
t.Errorf("Content = %q, want %q", r.Content, "ok")
}
if r.IsError {
t.Error("TextResult should not set IsError")
}
}
// TestErrorResult verifies the ErrorResult convenience constructor.
func TestErrorResult(t *testing.T) {
r := kit.ErrorResult("bad input")
if r.Content != "bad input" {
t.Errorf("Content = %q, want %q", r.Content, "bad input")
}
if !r.IsError {
t.Error("ErrorResult should set IsError")
}
}
// TestToolCallIDFromContext verifies round-trip context injection.
func TestToolCallIDFromContext(t *testing.T) {
// Empty context returns empty string.
if id := kit.ToolCallIDFromContext(context.Background()); id != "" {
t.Errorf("expected empty string from bare context, got %q", id)
}
}
// TestToolOutput_Metadata verifies that metadata can be set on ToolOutput.
func TestToolOutput_Metadata(t *testing.T) {
r := kit.ToolOutput{
Content: "data",
Metadata: map[string]string{"key": "value"},
}
if r.Metadata == nil {
t.Error("expected non-nil Metadata")
}
m, ok := r.Metadata.(map[string]string)
if !ok {
t.Fatalf("expected map[string]string, got %T", r.Metadata)
}
if m["key"] != "value" {
t.Errorf("Metadata[key] = %q, want %q", m["key"], "value")
}
}
// TestToolOutput_BinaryData verifies that binary data fields work correctly.
func TestToolOutput_BinaryData(t *testing.T) {
data := []byte{0x89, 0x50, 0x4E, 0x47}
r := kit.ToolOutput{
Content: "image result",
Data: data,
MediaType: "image/png",
}
if len(r.Data) != 4 {
t.Errorf("Data len = %d, want 4", len(r.Data))
}
if r.MediaType != "image/png" {
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
}
}
+144 -2
View File
@@ -85,10 +85,15 @@ host, err := kit.New(ctx, &kit.Options{
SessionPath: "/path/to/session.jsonl", // open specific session file
Continue: true, // resume most recent session for SessionDir
NoSession: true, // ephemeral in-memory session, no disk persistence
SessionManager: myCustomSession, // custom SessionManager implementation (advanced)
// Tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
Tools: []kit.Tool{kit.NewBashTool()}, // REPLACES entire default tool set
ExtraTools: []kit.Tool{myTool}, // ADDS alongside core/MCP/extension tools
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Skills
Skills: []string{"/path/to/skill.md"}, // explicit skill files (empty = auto-discover)
@@ -342,6 +347,77 @@ Lower values run first. Within the same priority, registration order applies. Fi
## Tools
### Creating custom tools
Use `kit.NewTool` to create custom tools. The JSON schema is auto-generated from the input struct — no external dependencies required:
```go
type WeatherInput struct {
City string `json:"city" description:"City name, e.g. 'San Francisco'"`
}
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
// Your logic here (API calls, database lookups, etc.)
return kit.TextResult("72°F, sunny in " + input.City), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{weatherTool},
})
```
**Struct tags** control the generated schema:
| Tag | Purpose | Example |
|-----|---------|---------|
| `json:"name"` | Parameter name | `json:"city"` |
| `description:"..."` | Description shown to the LLM | `description:"City name"` |
| `enum:"a,b,c"` | Restrict valid values | `enum:"json,text,csv"` |
| `omitempty` | Marks parameter as optional | `json:"limit,omitempty"` |
**Return helpers:**
| Function | Description |
|----------|-------------|
| `kit.TextResult(content)` | Successful text result |
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
**ToolOutput fields** (for advanced use):
```go
kit.ToolOutput{
Content: "result text", // text returned to the LLM
IsError: false, // true = LLM sees this as an error
Data: pngBytes, // optional binary data (images, audio)
MediaType: "image/png", // MIME type for binary Data
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
}
```
**Parallel tools** — mark as safe for concurrent execution:
```go
searchTool := kit.NewParallelTool("search", "Search the web",
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
return kit.TextResult("results..."), nil
},
)
```
**Tool call ID** — available in context for logging/tracing:
```go
tool := kit.NewTool("my_tool", "...",
func(ctx context.Context, input MyInput) (kit.ToolOutput, error) {
callID := kit.ToolCallIDFromContext(ctx) // correlation ID from the LLM
log.Printf("[%s] my_tool called", callID)
return kit.TextResult("ok"), nil
},
)
```
### Built-in tool constructors
```go
@@ -431,6 +507,72 @@ kit.DeleteSession("/path/to/session.jsonl")
tm, _ := kit.OpenTreeSession("/path/to/session.jsonl") // open for direct access
```
### Custom Session Manager (Advanced)
You can provide a custom session manager to store conversation history in your own backend (database, cloud storage, etc.) instead of the default JSONL files.
```go
// Implement the SessionManager interface
type MyDatabaseSessionManager struct {
db *sql.DB
// ... other fields
}
func (s *MyDatabaseSessionManager) AppendMessage(msg kit.LLMMessage) (string, error) {
// Store message in your database
}
func (s *MyDatabaseSessionManager) GetMessages() []kit.LLMMessage {
// Retrieve messages from your database
}
// ... implement all other SessionManager methods
// Use with Kit
host, _ := kit.New(ctx, &kit.Options{
SessionManager: myCustomSession, // Your custom implementation
Model: "anthropic/claude-sonnet-latest",
})
```
**SessionManager Interface:**
```go
type SessionManager interface {
AppendMessage(msg kit.LLMMessage) (entryID string, err error)
GetMessages() []kit.LLMMessage
BuildContext() (messages []kit.LLMMessage, provider string, modelID string)
Branch(entryID string) error
GetCurrentBranch() []kit.BranchEntry
GetChildren(parentID string) []string
GetEntry(entryID string) *kit.BranchEntry
GetSessionID() string
GetSessionName() string
SetSessionName(name string) error
GetCreatedAt() time.Time
IsPersisted() bool
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
GetLastCompaction() *kit.CompactionEntry
AppendExtensionData(extType, data string) (string, error)
GetExtensionData(extType string) []kit.ExtensionDataEntry
AppendModelChange(provider, modelID string) (string, error)
GetContextEntryIDs() []string
Close() error
}
```
**Use Cases:**
- **PocketBase integration**: Store sessions as PocketBase records
- **Cloud storage**: Persist sessions to S3, GCS, or Azure Blob
- **Multi-user apps**: Store sessions per user in a database
- **Custom retention**: Implement your own session cleanup policies
**Note:** When using a custom SessionManager, the following Options are ignored:
- `SessionPath` - your manager handles its own storage
- `Continue` - your manager handles session selection
- `NoSession` - use an in-memory implementation instead
---
## Model Management
+49 -21
View File
@@ -7,17 +7,16 @@ description: Monitor tool calls and streaming output with the Kit Go SDK.
## Event-based monitoring
For more granular control, use the event subscription API:
Subscribe to events for real-time monitoring. Each method returns an unsubscribe function:
```go
// Subscribe returns an unsubscribe function
unsub := host.OnToolCall(func(event kit.ToolCallEvent) {
fmt.Printf("Tool: %s, Args: %s\n", event.Name, event.Args)
fmt.Printf("Tool: %s, Args: %s\n", event.ToolName, event.ToolArgs)
})
defer unsub()
unsub2 := host.OnToolResult(func(event kit.ToolResultEvent) {
fmt.Printf("Result: %s (error: %v)\n", event.Name, event.IsError)
fmt.Printf("Result: %s (error: %v)\n", event.ToolName, event.IsError)
})
defer unsub2()
@@ -44,33 +43,62 @@ defer unsub6()
## Hook system
Hooks allow you to intercept and modify behavior. Unlike events, hooks can modify or cancel operations:
Hooks can **modify or cancel** operations. Unlike events (read-only), hooks are read-write interceptors.
### BeforeToolCall — block tool execution
```go
// Intercept tool calls before execution
host.OnBeforeToolCall(0, func(ctx context.Context, name string, args string) (string, error) {
if name == "bash" {
log.Println("Bash command:", args)
host.OnBeforeToolCall(kit.HookPriorityNormal, func(h kit.BeforeToolCallHook) *kit.BeforeToolCallResult {
// h.ToolCallID, h.ToolName, h.ToolArgs
if h.ToolName == "bash" && strings.Contains(h.ToolArgs, "rm -rf") {
return &kit.BeforeToolCallResult{Block: true, Reason: "dangerous command"}
}
return args, nil // return modified args or error to cancel
return nil // allow
})
```
// Process results after tool execution
host.OnAfterToolResult(0, func(ctx context.Context, name string, result string) (string, error) {
return result, nil
})
### AfterToolResult — modify tool output
// Before/after each agent turn
host.OnBeforeTurn(0, func(ctx context.Context) error {
return nil
})
host.OnAfterTurn(0, func(ctx context.Context) error {
```go
host.OnAfterToolResult(kit.HookPriorityNormal, func(h kit.AfterToolResultHook) *kit.AfterToolResultResult {
// h.ToolCallID, h.ToolName, h.ToolArgs, h.Result, h.IsError
if h.ToolName == "read" {
filtered := redactSecrets(h.Result)
return &kit.AfterToolResultResult{Result: &filtered}
}
return nil
})
```
The first argument is a priority (lower = runs first).
### BeforeTurn — modify prompt, inject messages
```go
host.OnBeforeTurn(kit.HookPriorityNormal, func(h kit.BeforeTurnHook) *kit.BeforeTurnResult {
// h.Prompt
newPrompt := h.Prompt + "\nAlways respond in JSON."
return &kit.BeforeTurnResult{Prompt: &newPrompt}
// Also available: SystemPrompt *string, InjectText *string
})
```
### AfterTurn — observation only
```go
host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
// h.Response, h.Error
log.Printf("Turn completed: %d chars", len(h.Response))
})
```
### Hook priorities
```go
kit.HookPriorityHigh = 0 // runs first
kit.HookPriorityNormal = 50 // default
kit.HookPriorityLow = 100 // runs last
```
Lower values run first. First non-nil result wins.
## Subagent event monitoring
+33 -2
View File
@@ -29,8 +29,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true,
// Tools
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true,
@@ -58,7 +62,34 @@ host, err := kit.New(ctx, &kit.Options{
| `NoSession` | `bool` | `false` | Ephemeral mode (no persistence) |
| `Tools` | `[]Tool` | — | Replace the entire default tool set |
| `ExtraTools` | `[]Tool` | — | Additional tools alongside core/MCP/extension tools |
| `DisableCoreTools` | `bool` | `false` | Use no core tools (0 tools, for chat-only) |
| `SkipConfig` | `bool` | `false` | Skip .kit.yml file loading |
| `AutoCompact` | `bool` | `false` | Auto-compact when near context limit |
| `CompactionOptions` | `*CompactionOptions` | — | Configuration for auto-compaction |
| `Skills` | `[]string` | — | Explicit skill files/dirs to load |
| `SkillsDir` | `string` | — | Override default skills directory |
## Tool configuration
**`Tools`** replaces ALL default tools (core + MCP + extension). **`ExtraTools`** adds tools alongside the defaults. Use `Tools` to restrict capabilities; use `ExtraTools` to extend them.
Create custom tools with `kit.NewTool` — no external dependencies needed:
```go
type LookupInput struct {
ID string `json:"id" description:"Record ID to look up"`
}
lookupTool := kit.NewTool("lookup", "Look up a record by ID",
func(ctx context.Context, input LookupInput) (kit.ToolOutput, error) {
record := db.Find(input.ID)
return kit.TextResult(record.String()), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{lookupTool},
})
```
See [Overview](/sdk/overview#custom-tools) for full custom tool documentation.
+38
View File
@@ -68,6 +68,44 @@ The SDK provides several prompt variants:
| `Steer(ctx, instruction)` | System-level steering without user message |
| `FollowUp(ctx, text)` | Continue without new user input |
## Custom tools
Create custom tools with `kit.NewTool`. The JSON schema is auto-generated from the input struct — no external dependencies required:
```go
type WeatherInput struct {
City string `json:"city" description:"City name"`
}
weatherTool := kit.NewTool("get_weather", "Get current weather for a city",
func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
return kit.TextResult("72°F, sunny in " + input.City), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{weatherTool},
})
```
Struct tags control the schema:
- `json:"name"` — parameter name
- `description:"..."` — description shown to the LLM
- `enum:"a,b,c"` — restrict valid values
- `omitempty` — marks the parameter as optional
Return values:
| Helper | Description |
|--------|-------------|
| `kit.TextResult(s)` | Successful text result |
| `kit.ErrorResult(s)` | Error result (LLM sees it as a tool error) |
For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing.
## Event system
Subscribe to events for monitoring: