Compare commits

...

60 Commits

Author SHA1 Message Date
Ed Zynda 3ffc995f27 feat(sdk): add NewTool/NewParallelTool for dependency-free custom tools
- Add ToolOutput struct, TextResult/ErrorResult helpers, and
  ToolCallIDFromContext so SDK consumers can create custom tools
  without importing charm.land/fantasy
- Add NewTool (sequential) and NewParallelTool (concurrent) generic
  constructors with automatic JSON schema generation from struct tags
- Remove dead UpdateUsageFromResponse method and fantasy import from
  internal/ui/cli.go
- Update SDK skill, README, and www/ docs with custom tool examples
  and corrected hook signatures
2026-04-07 22:05:42 +03:00
Ed Zynda b2bd016135 fix(tui): redirect log output to file to prevent TUI corruption
- Add tea.LogToFile in runInteractiveModeBubbleTea to send stdlib log
  output to /tmp/kit/kit.log instead of stderr
- Replace charmbracelet/log with stdlib log in extensions loader,
  runner, watcher, prompts loader, and pkg/kit so all log calls go
  through the redirected stdlib logger
- Leave charmbracelet/log in CLI-only commands (install, acp) and
  acpserver where stderr logging is correct
2026-04-07 21:20:04 +03:00
Ed Zynda 812dedaea2 feat(pkg/kit): add SessionManager interface for custom session backends
Add SessionManager interface to allow pluggable session storage backends.
This enables users to implement custom session managers for databases,
cloud storage, or other persistence mechanisms instead of the default
JSONL file-based TreeManager.

Changes:
- Add SessionManager interface with methods for message storage,
  tree navigation, compaction, and extension data
- Add treeManagerAdapter to wrap existing TreeManager for backward compatibility
- Update Kit struct to use SessionManager interface instead of concrete type
- Add SessionManager option to Options struct
- Update all session-related methods to use interface
- Add documentation for custom SessionManager usage

The default behavior is preserved - when no SessionManager is provided,
Kit automatically uses the TreeManager via the adapter.
2026-04-07 17:41:46 +03:00
Ed Zynda f65b6737f2 feat(sdk): add SkipConfig and DisableCoreTools options
Add two new Options fields for programmatic SDK usage:

- SkipConfig: Skip .kit.yml file loading while still using viper defaults
  and environment variables. Useful for fully programmatic configuration.

- DisableCoreTools: Allow creating agents with 0 tools (chat-only mode) or
  with only custom tools. When true and Tools is empty, no tools are loaded.
  When combined with custom Tools, only those tools are loaded.

Updates documentation in README, pkg/kit/README, skills/kit-sdk/SKILL,
and www/pages/sdk/options.
2026-04-07 17:10:58 +03:00
Ed Zynda 5d45aa196b fix(watcher): remove debug logging that corrupts TUI
Remove charmbracelet/log debug statements from the file watcher that
were writing directly to stderr, corrupting the Bubble Tea terminal UI.

- Remove log.Debug calls for directory operations and file changes
- Remove log.Warn for watcher errors (silently ignore instead)
- Remove the charmbracelet/log import entirely
2026-04-07 16:31:29 +03:00
Ed Zynda debb39f56c fix(ui): show MCP tools in /tools and status bar after async loading
Background MCP tool loading (added in 7e54710) caused tools to not appear
in the UI because tool names and counts were captured at startup before
loading completed. This adds:

- MCPToolsReadyEvent and MCPServerLoadedEvent for progress notifications
- Dynamic GetToolNames/GetMCPToolCount callbacks for live updates
- Per-server status messages as each MCP server finishes loading
- Refresh handlers to update /tools output and status bar when ready
2026-04-07 16:29:09 +03:00
Ed Zynda 7ce6f4fd9e fix(watcher): dynamically watch new subdirectories for skill/prompt reload
- Detect new subdirectory creation in the fsnotify event loop and add
  it to the watcher so files created inside trigger reload events
- Handle cp -r case by checking if new directories already contain
  matching files and scheduling an immediate debounced reload
- Add dirContainsMatchingFiles helper method
- Add tests for both new-subdirectory and copy-with-existing-files cases
2026-04-07 15:01:18 +03:00
Ed Zynda c2f2bdb3d3 feat: auto-reload custom prompts and skills on file change
- Add internal/watcher package with general-purpose ContentWatcher
  using fsnotify, configurable file extensions, and debouncing
- Add ContentReloadEvent and App.NotifyContentReload() for TUI signaling
- Add GetPromptTemplates/GetSkillItems callback fields on AppModelOptions
  following the existing GetExtensionCommands lazy-provider pattern
- Add Kit.ReloadSkills() to re-discover skills from disk
- Wire fsnotify watcher for .kit/prompts/, .kit/skills/, .agents/skills/,
  and global config directories, triggering on .md/.txt changes
- TUI refreshes autocomplete entries and skill list on reload
2026-04-07 14:09:59 +03:00
Ed Zynda 201d14804e fix(ui): prevent double-rendered messages after reasoning-only responses
- Always fire onResponse callback even when response text is empty so
  ResponseCompleteEvent reaches the TUI and resets the StreamComponent
- Check for existing StreamingMessageItem in flushStreamAndPendingUserMessages
  before creating a new StyledMessageItem to avoid duplicate content
- Mark trailing StreamingMessageItem complete on StepComplete, StepCancelled,
  and StepError to freeze live timers and prevent dangling streaming state
2026-04-07 13:52:30 +03:00
Ed Zynda 7e54710d4a perf(agent): load MCP tools asynchronously to speed up startup
Load MCP server tools in the background so the UI appears immediately
instead of blocking until all servers connect. The first LLM call
automatically waits for tools to be ready before proceeding.

Key changes:
- NewAgent() starts MCP loading in a background goroutine and returns
  immediately with core/extension tools only
- GenerateWithLoop() calls ensureMCPTools() to lazily wait and rebuild
  the fantasy agent with full tool set before first LLM call
- Parallelize LoadTools() across all configured MCP servers
- Add WaitForMCPTools() and MCPToolsReady() for status checking
- Refactor SetModel/SetExtraTools to use shared rebuildFantasyAgent()
- Expose async MCP status methods in public SDK
2026-04-07 13:36:10 +03:00
Ed Zynda 88870be4d2 feat: add frequency-penalty and presence-penalty parameters
- Add --frequency-penalty and --presence-penalty CLI flags (0.0-2.0)
- Wire through config, viper, ProviderConfig, and fantasy agent options
- Support in config file, env vars (KIT_FREQUENCY_PENALTY), and SDK
- Pass to Ollama via options map (frequency_penalty, presence_penalty)
- Apply on both initial agent creation and runtime model swap
2026-04-06 10:52:33 +03:00
Ed Zynda 46bf809715 chore(models): update embedded models.json from models.dev
- Providers: 97 -> 109 (+12 new)
- Models: 3039 -> 4156 (+1117 new)
- New providers: alibaba-coding-plan, alibaba-coding-plan-cn, clarifai,
  dinference, drun, llmgateway, perplexity-agent, tencent-coding-plan,
  the-grid-ai, xiaomi-token-plan-ams, xiaomi-token-plan-cn,
  xiaomi-token-plan-sgp
2026-04-06 09:50:43 +03:00
Ed Zynda e19e9642a2 feat(session): include system prompt and model in shared sessions
Add SystemPromptEntry type to capture system prompt, model, and provider
when sharing sessions via /share command. The entry is inserted into the
JSONL after the header and displayed in the web viewer as a collapsible
section with a model badge.

- Add SystemPromptEntry with Content, Model, and Provider fields
- Capture current system prompt and model at share time
- Display in web viewer with collapsible UI and model badge
- Update documentation for /share command
2026-04-04 19:33:02 +03:00
Ed Zynda 32675b8b35 chore(deps): update all go module dependencies
- mcp-go v0.46.0 → v0.47.0
- herald v0.11.0 → v0.13.0
- herald-md v0.2.0 → v0.3.0
- smithy-go v1.24.2 → v1.24.3
- otel v1.42.0 → v1.43.0
- googleapis/gax-go v2.20.0 → v2.21.0
- google.golang.org/api v0.273.1 → v0.274.0
- runewidth v0.0.21 → v0.0.22
- azure-sdk-internal v1.11.2 → v1.12.0
- various aws-sdk-go-v2 sub-modules patched
2026-04-04 18:11:56 +03:00
Ed Zynda aecce001ee feat(mcp): add OAuth support for remote MCP servers
- Add MCPAuthHandler interface at SDK level (pkg/kit/) so all consumers
  (CLI, TUI, SDK embedders) control the OAuth UX through one interface
- Default handler opens system browser + local callback server with PKCE
- CLIMCPAuthHandler wraps default with status messages (stderr pre-TUI,
  system messages via TUI event system once running)
- Always enable OAuth on remote transports (streamable HTTP, SSE) when
  handler is configured; harmless for servers that don't need it
- Dynamic client registration when no client ID is pre-configured
- File-based TokenStore persists tokens to ~/.config/.kit/mcp_tokens.json
  keyed by server URL so users don't re-auth on restart
- Catch OAuthAuthorizationRequiredError at connection init (startup) and
  tool execution (mid-session token expiry), run auth flow, retry once
- Fix error wrapping (%v -> %w) in connection pool so errors.As can
  unwrap through the chain to find OAuth errors
- Thread AuthHandler through MCPToolManager -> AgentConfig ->
  AgentCreationOptions -> AgentSetupOptions -> kit.Options
2026-04-04 17:41:57 +03:00
Ed Zynda 32d73171fd fix(extensions): write manifest Include in single pass and preserve on update
- InstallWithInclude wrote manifest twice via two different code paths,
  with the first write missing Include; unify into shared install() method
  that writes the manifest once with all fields including Include
- Update() now reads the existing manifest entry to preserve Include and
  Installed timestamp instead of constructing a fresh entry from scratch
2026-04-04 17:19:00 +03:00
Ed Zynda 265fd2ec0c fix(extensions): skip _test.go files and non-extension examples/ subdirs
- Filter out _test.go files in findExtensionsInDir, findExtensionsInRepo,
  and ScanForExtensions to prevent Yaegi from loading test files
- Narrow examples/ traversal so only recognized extension directories
  (extensions/, ext/, *-ext/, *-extensions/) are scanned, not arbitrary
  subdirs like examples/sdk/ that import pkg/kit
2026-04-04 16:44:13 +03:00
Ed Zynda efebf2eba6 fix(kit-telegram): add typing indicator and config fallback to global path
- Send sendChatAction("typing") every 4s while agent is processing,
  started on AgentStart and stopped on AgentEnd/SessionShutdown
- configPath() now checks project-local .kit/ first, then falls back
  to ~/.config/kit/kit-telegram.json for cross-project portability
2026-04-04 16:33:08 +03:00
Ed Zynda f7b655ae33 feat(extensions): add Abort, IsIdle, Compact, SendMultimodalMessage, GetSessionUsage to Context
- ctx.Abort(): cancel current agent turn and clear queue without
  injecting a new message (App.Abort + App.IsBusy methods)
- ctx.IsIdle(): check whether the agent is currently processing
- ctx.Compact(CompactConfig): trigger async context compaction with
  OnComplete/OnError callbacks (App.CompactAsync method)
- ctx.SendMultimodalMessage(text, []FilePart): send text+image messages
  to the agent, bridging ext.FilePart to fantasy.FilePart via RunWithFiles
- ctx.GetSessionUsage() SessionUsage: expose aggregated session token
  usage and cost from the UsageTracker

New types: CompactConfig, FilePart, SessionUsage
Wired in both context setups in cmd/root.go with nil-guard defaults
in runner.go and Yaegi symbol exports in symbols.go
2026-04-04 15:01:02 +03:00
Ed Zynda 35982b41ad fix(pkg): transparently handle <think> tags in stream
Move reasoning tag detection from the provider and UI layers into the agent layer. This prevents raw XML tags from leaking into text streams while ensuring structured reasoning events are emitted correctly for all callers.
2026-04-03 13:49:12 +03:00
Ed Zynda 788e3b71fd feat(config): per-model baseUrl and apiKey for custom models
- Add `baseUrl` and `apiKey` fields to CustomModelConfig (config and models packages)
- Store them on ModelInfo so they travel through the registry
- createCustomProvider resolves URL/key from model definition first,
  falling back to global --provider-url / --provider-api-key
- Fix registry initialisation: call ReloadGlobalRegistry() in InitConfig()
  so customModels from config are visible on startup (not just at init time)
- Include custom provider in GetLLMProviders() so custom models appear
  in the /model selector
- Hide the built-in custom/custom stub from the selector when user-defined
  custom models are present
2026-04-03 12:37:14 +03:00
Ed Zynda 3496bc2684 feat(ui): add bordered container and improved styling to session selector
- Add full-width bordered container with rounded border and primary color
- Add max height constraint to prevent terminal overflow
- Improve selection highlighting with inverted colors matching PopupList style
- Change cursor indicator from › to > for consistency
- Add separator lines between header, content, and footer
- Add footer showing current filter mode
2026-04-02 17:20:55 +03:00
Ed Zynda 997c7d15ff fix: include pasted images in steering messages
Steering messages (Ctrl+S during agent work) now carry file attachments
just like queued messages do. Previously, pasted images were silently
dropped when steering.

Changes:
- Add SteerMessage struct with Text and Files fields
- Update steer channel from chan string to chan SteerMessage
- Add SteerWithFiles methods through the stack (UI, app, SDK)
- Update PrepareStep to include files in injected user messages
2026-04-02 17:19:34 +03:00
Ed Zynda 83246e47d5 feat(ui): add bordered container and improved styling to tree selector
- Add full-width bordered container with rounded border and primary color
- Add max height constraint to prevent terminal overflow
- Improve selection highlighting with inverted colors matching PopupList style
- Change cursor indicator from › to > for consistency
- Use MutedBorder for tree lines and Success color for active marker
- Update search display format to match PopupList (
2026-04-02 17:18:16 +03:00
Ed Zynda 50e7b78c33 fix(ui): strip herald CodeBlock padding to fix mouse selection off-by-one
Herald's codeBlockWithLineNumbers() hardcodes PaddingTop(1) and
PaddingBottom(1), adding invisible blank lines with background color
above and below the code content. These padding lines occupy line
indices in the rendered item but are visually indistinguishable from
empty space, causing mouse click coordinates to map to the wrong
content line (consistently 1 row off in tool output blocks).

Strip the padding lines after CodeBlock rendering since the Compose
separator above and Figure caption below already provide adequate
visual spacing.
2026-04-02 16:49:44 +03:00
Ed Zynda b937af3056 refactor(ui): use herald Figure component for grep tool output
Add dedicated renderGrepBody function for the grep tool, replacing the
previous behavior of routing it through renderBashBody. The grep tool now:

- Shows a caption with total match count (e.g., '8 matches' or '1 match')
- Displays truncation info when matches exceed maxLsLines
- Uses consistent Figure component styling with ls, read, find, and bash tools
- Uses 'match/matches' terminology appropriate for grep results
2026-04-02 16:12:48 +03:00
Ed Zynda a5e995c750 refactor(ui): use herald Figure component for find tool output
Add dedicated renderFindBody function for the find tool, replacing the
previous behavior of routing it through renderBashBody. The find tool now:

- Shows a caption with total result count (e.g., '12 results')
- Displays truncation info when results exceed maxLsLines
- Uses consistent Figure component styling with ls, read, and bash tools
2026-04-02 16:11:49 +03:00
Ed Zynda e95e08a699 refactor(ui): use herald Figure component for ls tool output
Apply the same Figure component pattern to the ls tool for consistency
with read and bash tools. The caption now appears below the directory
listing and shows the count of hidden entries when truncated.
2026-04-02 16:10:00 +03:00
Ed Zynda bcaf92f62a refactor(ui): use herald Figure component for read and bash tool output
Replace inline truncation hints and exit code labels with herald's
Figure component. Captions now appear below content and show:

- read: filename • lines X-Y of Z • offset=N to continue
- bash: N more lines • exit code N

This provides consistent visual grouping and cleaner metadata
display for tool output blocks.
2026-04-02 16:09:17 +03:00
Ed Zynda ead4afbfe6 fix(subagent): prevent instant failure from already-dead parent contexts
- Replace detachedWithCancel (goroutine-based) with context.WithoutCancel
  + valuesContext; the old goroutine would fire immediately if the parent
  was already cancelled/deadline-exceeded, causing 'failed after 0s'
- Kit.Subagent() pre-flight: if the incoming ctx is already done, reset
  to context.Background() before applying the subagent timeout
- Both Subagent() error paths now return a non-nil *SubagentResult with
  Elapsed set, so the tool response always shows accurate timing
- Narrow viperInitMu scope in Kit.New(): snapshot viper state + call
  BuildProviderConfig under the lock, then release before SetupAgent /
  MCP loading; parallel subagent spawns no longer serialise on viper I/O
- AgentSetupOptions gains ProviderConfig + scalar fields so SetupAgent
  can skip viper reads when a pre-built config is supplied
- Add subagent_test.go covering the fixed context detachment behaviour
2026-04-02 15:54:47 +03:00
Ed Zynda 685aaf207f feat(extensions): add hot-reload with file watching and /reload-ext command
- Add fsnotify-based file watcher that auto-reloads extensions on .go
  file changes in autoloaded dirs with 300ms debounce
- Add /reload-ext built-in command (alias /re) for manual reload
- Add Agent.SetExtraTools() so extension tools update on reload
  instead of being baked in at agent creation time
- Run reload async via tea.Cmd to avoid prog.Send() deadlock when
  extension handlers call ctx.Print() during SessionStart/Shutdown
- Wire watcher lifecycle into cmd/root.go with graceful shutdown
2026-04-02 15:41:54 +03:00
Ed Zynda 76ff6c9639 style(ui): segment KITT scanner LEDs and center logo text
- Break scanner bar into individual LED segments with single-space gaps
- Center KIT text over the scanner bar (13-space indent for all lines)
- Maintain original 46-char total width for the scanner bar
2026-04-02 15:11:01 +03:00
Ed Zynda 1cf24ee5de fix(core): return error when read tool is used on a directory
- Return an error response guiding the agent to use ls instead
- Remove unused readDirectory helper function
2026-04-02 14:45:33 +03:00
Ed Zynda c9637090fa feat(subagent): return early error for invalid model instead of silent fallback
- Add ValidateModelString() to ModelsRegistry for format, provider,
  and model name validation with typo suggestions
- Validate model in Kit.Subagent() before expensive Kit.New() setup
- Remove silent fallback to parent model on creation failure
- Error propagates as tool result so calling agent can self-correct
- Add registry_test.go covering format, provider, and suggestion cases
2026-04-02 14:45:03 +03:00
Ed Zynda 0ff0ff42ab fix(ui): wrap tool error output in caution alert block
Prevent tool error text from spilling into the surrounding layout
by rendering it inside a herald Caution alert container.
2026-04-02 14:39:29 +03:00
Ed Zynda a4fb32ff2b feat(ui): add reusable PopupList and render /model as overlay
- Add PopupList: generic themed popup with fuzzy search, scrolling,
  keyboard navigation, and centered overlay rendering
- Refactor ModelSelectorComponent to delegate to PopupList instead
  of implementing its own full-screen rendering and input handling
- Render /model selector as a centered overlay on top of the chat
  view instead of replacing the entire screen
- PopupList accepts a pluggable FilterFunc for domain-specific
  fuzzy matching (model selector wires its own scoring)
- Add 11 tests for PopupList covering navigation, search, selection,
  cancellation, filtering, rendering, and edge cases
2026-04-02 14:39:21 +03:00
Ed Zynda 7d2f078111 fix(ui): freeze reasoning counter when last token is processed
- Wire fantasy's OnReasoningEnd callback through the full event chain:
  agent → SDK (ReasoningCompleteEvent) → app → TUI
- Freeze reasoning duration in both StreamComponent and
  StreamingMessageItem as soon as reasoning ends, not when the
  next assistant text chunk arrives
- Fix accent color on duration label in render.ReasoningBlock to
  match the live streaming style (VeryMuted prefix + Accent duration)
2026-04-02 14:18:42 +03:00
Ed Zynda b0b66941ab fix(extensions): batch go-edit-lint per turn and fix OnAgentEnd StopReason docs
- Refactor go-edit-lint to collect edited .go files during the agent
  turn via OnToolResult, then run gopls + golangci-lint once in
  OnAgentEnd instead of after every individual edit/write call
- Use ctx.SendMessage() to inject diagnostics as a follow-up prompt
  when issues are found, replacing the old tool-result rewriting
- Show a green 'all clean' block when no issues are detected
- Fix StopReason docs in skills/kit-extensions/SKILL.md: the value is
  'error' on failure, 'completed' when the LLM returns empty, or the
  raw provider value (e.g. 'stop', 'end_turn') passed through — not
  the previously documented 'completed'/'cancelled'/'error' enum
2026-04-02 14:04:41 +03:00
Ed Zynda cbb7387a72 fix(test): add return after t.Fatal to silence SA5011 nil-deref warnings
- internal/ui/model_test.go: bashItem nil check
- pkg/extensions/test/harness_test.go: footer and result nil checks
2026-04-01 21:24:02 +03:00
Ed Zynda 19430b0ecb chore(ui): remove dead toast and clipboard code
Remove 8 unused exports from clipboard package:
- CopyToClipboardWithMessage, IsClipboardSupported
- ToastMsg, ToastType, ToastInfo, ToastSuccess, ToastWarning, ToastError

These were remnants of a toast notification feature that was never
wired up. No callers exist anywhere in the codebase.
2026-04-01 21:11:00 +03:00
Ed Zynda 8e3cfeede5 fix(ui): correct mouse selection Y-offset for reasoning blocks
The getItemAndLineAtY() method was using item.Height() which returns 0
for reasoning blocks (StreamingMessageItem with role='reasoning') because
their render cache is intentionally never populated (they include a live
duration timer).

This caused all items below a reasoning block to have incorrect Y
coordinates — clicking on the reasoning text would highlight the
assistant text below it instead.

Two fixes:
1. getItemAndLineAtY() now uses renderedHeight() which calls Render()
   and counts lines — matching exactly what View() does. This is the
   single source of truth for item height during hit-testing.

2. StreamingMessageItem.Height() now falls back to Render(0) when
   cachedRender is empty, fixing the same issue for other callers
   (GotoBottom, ScrollBy, clampOffset, etc.).
2026-04-01 18:15:04 +03:00
Ed Zynda 4fa5775974 feat(ui): implement character-level mouse text selection and copy
Implement crush-style mouse text selection with character-level precision,
replacing the previously disabled stub implementation.

Architecture:
- New selection package (internal/ui/selection/) handles all coordinate
  math, word boundary detection, and cell-level ANSI text manipulation
- ScrollList upgraded with proper mouse down/drag/up flow supporting
  single click (character drag), double click (word), triple click (line)
- Model.go wires BubbleTea mouse events through to ScrollList with
  proper viewport Y-offset adjustment for the scrollback area

Key features:
- Character-level selection using ultraviolet ScreenBuffer for ANSI-aware
  cell parsing — correctly handles styled text, emoji, CJK wide chars
- Word selection via UAX#29 Unicode segmentation (clipperhouse/uax29)
- Display-width-aware columns via clipperhouse/displaywidth (not bytes)
- Dual clipboard: OSC 52 (remote terminals) + native (atotto/clipboard)
- Multi-click detection with 400ms threshold and 2px tolerance
- Mouse event throttling via existing MouseModeCellMotion
- Selection cleared on any keypress for clean UX

Dependencies (all already indirect in go.mod):
- github.com/charmbracelet/ultraviolet (ScreenBuffer, cell manipulation)
- github.com/charmbracelet/x/ansi (ANSI strip, StringWidth)
- github.com/clipperhouse/displaywidth (grapheme display width)
- github.com/clipperhouse/uax29/v2 (Unicode word segmentation)
2026-04-01 18:05:48 +03:00
Ed Zynda 4e7d823ee4 feat(ui): make /fork create new session file matching Pi behavior
- Add ForkToNewSession method to create new session with history up to target
- Add NewTreeSelectorForFork showing only user messages (flat list)
- Update performFork to create and switch to new session file
- Update /fork command description in docs and help text

Previously /fork just branched within the same session file like /tree.
Now /fork creates a completely new session file with parent_session reference,
matching Pi's behavior exactly.
2026-04-01 16:10:55 +03:00
Ed Zynda 7a16c76adc fix(ui): trim whitespace when loading session messages to prevent empty blocks
When loading session history, some assistant messages contain text parts
with only whitespace (e.g., single space ' '). These were being rendered
as empty message blocks, causing extra vertical spacing in the UI.

Fix by trimming whitespace from message content before checking if it's
non-empty in renderSessionHistory().

Changes:
- Apply strings.TrimSpace() to user message content before rendering
- Apply strings.TrimSpace() to assistant message content before rendering

This prevents empty/whitespace-only message blocks from being added to
the scrollback when resuming sessions.
2026-04-01 15:11:42 +03:00
Ed Zynda 70a21ee73a refactor(ui): extract shared message rendering functions
Extract pure rendering functions into internal/ui/render/blocks.go
to eliminate code duplication between streaming and historical
message rendering paths.

Changes:
- Create render package with UserBlock, AssistantBlock, ReasoningBlock,
  SystemBlock, ErrorBlock, and ToolBlock functions
- Update MessageRenderer methods to use shared render functions
- Update StreamingMessageItem to use shared render functions
- Reduce ~77 lines of duplicated code across message_items.go and messages.go

All existing tests pass, no functional changes.
2026-04-01 14:59:27 +03:00
Ed Zynda 28d2de8f39 Phase 1: Reorganize UI leaf utilities into subpackages
Moved leaf utility files to subpackages for better organization:
- events.go -> core/ (core message types)
- clipboard.go -> clipboard/ (clipboard operations)
- commands.go -> commands/ (slash commands)
- file_processor.go -> fileutil/ (file attachment processing)
- preferences.go -> prefs/ (theme/model preferences)
- enhanced_styles.go, styles.go, themes.go -> style/ (theming system)

Added exports.go to re-export commonly used types for backward
compatibility. External importers can still use ui.XXX without
changes.

All tests pass, basic smoke test successful.
2026-04-01 13:54:10 +03:00
Ed Zynda 7f192ae850 feat(ui): improve slash command popup contrast with full-width backgrounds
- Change border from MutedBorder to Primary for visibility
- Add full-width background styles for all popup items
- Use inverse colors for selected item (primary bg, background fg)
- Add background to scroll indicators and footer
- Add bottom margin for visual depth/shadow effect
2026-04-01 13:35:20 +03:00
Ed Zynda 9f6746ded9 fix(ui): re-enable auto-scroll on new message submission
Auto-scroll was being disabled when users manually scrolled (mouse wheel,
PgUp, etc.) but never re-enabled. Now it reactivates when submitting a
new message so the conversation view jumps to the bottom to show the
latest content.
2026-04-01 13:29:40 +03:00
Ed Zynda 7514d3a0ff chore(deps): update go and npm dependencies
- github.com/indaco/herald v0.10.0 → v0.11.0
- github.com/indaco/herald-md v0.1.0 → v0.2.0
- google.golang.org/api v0.273.0 → v0.273.1
- google.golang.org/genai v1.52.0 → v1.52.1
- google.golang.org/grpc v1.79.3 → v1.80.0
- gonum.org/v1/gonum v0.16.0 → v0.17.0
- add npm and www package-lock.json files
2026-04-01 13:24:36 +03:00
Ed Zynda c83281a52b docs: add feature-request prompt for GitHub feature requests
Add a dedicated /feature-request prompt that guides users through creating
well-formed feature requests using the GitHub feature_request template.

The prompt focuses on:
- Problem-first description
- Clear motivation and use cases
- Optional proposed implementation
- Conventional commit-style titles (feat: ...)

Usage: /feature-request <description of the feature>
2026-04-01 13:22:14 +03:00
Ed Zynda 4515bb92c2 docs: update file-issue prompt to use GitHub issue templates
The file-issue prompt now references the structured GitHub issue templates
(bug_report, feature_request, documentation) and guides users to use the
--template flag with gh issue create for consistent issue formatting.
2026-04-01 13:21:20 +03:00
Ed Zynda e326b84204 chore: add GitHub issue templates and file-issue prompt
Add structured GitHub issue templates for:
- Bug reports (with reproduction steps, code, component)
- Feature requests (with motivation and proposed implementation)
- Documentation issues

Also add a /file-issue kit prompt for quickly filing issues from the TUI.

The templates enforce conventional commit-style titles and include
checklists to ensure issues are well-formed before submission.
2026-04-01 13:20:43 +03:00
Ed Zynda 1b93049b8e fix(ui): remove j/k navigation from fuzzy selectors
Remove 'j' and 'k' keybindings from model, session, and tree selectors
to allow typing those characters for fuzzy filtering. Navigation now
uses only arrow keys (↑/↓) which matches the existing help text.
2026-04-01 13:11:44 +03:00
Ed Zynda 4912449dda fix(ui): render selectors in alt screen buffer
Fix /resume, /model, and /tree selectors to render in the alternate
screen buffer instead of terminal scrollback. All three selector
components now set AltScreen=true on their tea.View returns.
2026-04-01 13:09:23 +03:00
Ed Zynda b70cce4f34 refactor(ui): remove pre-alt-screen dead code and boilerplate
- Remove scrollbackBuf, appendScrollback(), drainScrollback() and all
  call sites — the entire terminal scrollback pipeline was dead code
  since the alt screen migration
- Remove StreamComponent.render(), renderCache, renderDirty,
  scrollbackFlushedLines, viewContent(), and ConsumeOverflow() body —
  rendering is now handled by StreamingMessageItem in the ScrollList
- Remove SetHeight and ConsumeOverflow from streamComponentIface since
  height is managed by ScrollList and overflow is a no-op
- Remove redundant AltScreen/MouseMode/ReportFocus/KeyboardEnhancements
  boilerplate from 6 child View() methods — parent already sets these
- Convert two orphan appendScrollback calls (extension default text,
  shell command output) to proper ScrollList message items
- Update ~30 stale comments referencing tea.Println and scrollback buffer
2026-04-01 01:13:19 +03:00
Ed Zynda 4c566836b2 refactor(ui): move startup banner into ScrollList, fix /resume rendering
- Render ASCII logo and startup info exclusively in the ScrollList
  instead of printing to stdout/terminal scrollback
- Remove PrintStartupInfo() and move kitBanner() to ui.KitBanner()
- Fix separator spacing: use single pre-rendered item with embedded
  blank lines to avoid left-border artifacts on spacing rows
- Rewrite renderSessionHistory() to populate ScrollList with proper
  MessageItems instead of legacy appendScrollback() calls
- Clear m.messages on /clear, /new, and /resume so the ScrollList
  resets correctly when switching sessions
- Add pendingGotoBottom flag to defer scroll-to-bottom until after
  distributeHeight() recalculates the correct viewport height
- Fix pre-existing test failures: initialize scrollList in test helper,
  update 5 tests from tea.Println assertions to ScrollList checks
2026-04-01 00:39:32 +03:00
Ed Zynda bb3261883a Add visual separator after startup info in ScrollList
Added a horizontal rule (────) with blank lines above and below to
visually separate the startup info from the conversation history.

The separator uses theme.Border color and spans 80 characters, providing
a clear visual break between startup messages and the chat content.

This makes it easier to distinguish where the conversation starts when
scrolling back through history.
2026-03-31 19:07:56 +03:00
Ed Zynda 512d0f16ce Show startup info in ScrollList (alt screen mode)
Added AddStartupMessageToScrollList() method that renders startup info
(model, context, skills, extensions, MCP tools) and extension startup
messages as system messages in the ScrollList.

This ensures startup info is visible and scrollable in alt screen mode,
rather than being printed before BubbleTea starts and becoming hidden
when alt screen takes over.

Changes:
- AppModelOptions: Added StartupExtensionMessages field
- AppModel: Store and render startup messages in Init()
- AddStartupMessageToScrollList(): Renders startup info + extension messages
- cmd/root.go: Pass startupExtensionMessages to NewAppModel

The startup info now appears at the top of conversation history and can
be scrolled back to at any time.
2026-03-31 19:03:21 +03:00
Ed Zynda 8159431ce4 Prevent scrolling past bottom of content in ScrollList
Enhanced clampOffset() to detect when the viewport has scrolled past the
bottom of the content (would show empty space) and automatically reposition
to show the last line of content at the bottom of the viewport.

This prevents the 'floating' effect where multiple PgDn or scroll down
operations would push content off the top while showing blank space below.

The clamping logic:
1. Calculates total content height
2. If content fits in viewport, forces position to top
3. Otherwise, checks if remaining content < viewport height
4. If so, repositions to show exactly the last line at viewport bottom

Also updated clampOffset to use rendered height calculation (handles
non-cached items like reasoning blocks) instead of cached Height().
2026-03-31 18:56:18 +03:00
Ed Zynda 9f9f265fb3 Fix autoscroll for streaming messages (iteratr pattern)
Root cause: GotoBottom() was calculating heights using Height() which returns
0 for non-cached items. Reasoning blocks never cache renders due to live
duration updates, causing incorrect scroll calculations during reasoning →
assistant transitions.

Fix: Calculate heights directly from rendered strings instead of relying on
cached Height() values. This ensures accurate scroll positioning for all
message types.

Changes:
- ScrollList.GotoBottom(): Render items and calculate height from string
- ScrollList.AtBottom(): Same pattern for bottom detection
- appendStreamingChunk(): Call GotoBottom() directly for existing messages
- refreshContent(): Remove redundant GotoBottom() (handled by SetItems)

Tested with 'explore this repo' prompt - autoscroll now works correctly
throughout reasoning and assistant streaming phases.
2026-03-31 18:53:18 +03:00
110 changed files with 16194 additions and 2716 deletions
+79
View File
@@ -0,0 +1,79 @@
name: Bug Report
description: Report a bug or issue with Kit
title: "fix: "
labels: ["bug"]
body:
- type: textarea
id: description
attributes:
label: Bug Description
description: What happened? What did you expect to happen?
placeholder: |
The BorderColor field in ToolRenderConfig is documented but never applied
during tool rendering. I expected the tool block to render with my custom
color, but it uses the default styling instead.
validations:
required: true
- type: textarea
id: reproduction
attributes:
label: Steps to Reproduce
description: Provide clear steps to reproduce the issue
placeholder: |
1. Create an extension with `api.RegisterToolRenderer(ext.ToolRenderConfig{...})`
2. Set `BorderColor: "#89b4fa"` in the config
3. Run a tool that uses this renderer
4. Observe the border color is not applied
render: markdown
validations:
required: true
- type: textarea
id: code
attributes:
label: Relevant Code / Configuration
description: Paste any code, configuration, or error messages
placeholder: |
```go
api.RegisterToolRenderer(ext.ToolRenderConfig{
ToolName: "bash",
DisplayName: "Shell",
BorderColor: "#a6e3a1", // This is ignored!
Background: "#1e1e2e", // This is ignored!
})
```
render: go
- type: input
id: component
attributes:
label: Affected Component
description: Which part of Kit is affected?
placeholder: e.g., extensions, ui, tool rendering, session management
- type: input
id: version
attributes:
label: Kit Version
description: What version of Kit are you running?
placeholder: e.g., v0.1.0, commit hash, or "main"
- type: textarea
id: context
attributes:
label: Additional Context
description: Any other context, proposed fixes, or related issues
placeholder: |
The issue appears to be in `internal/ui/messages.go:RenderToolMessage()`
which ignores the BorderColor and Background fields from ToolRendererData.
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I've searched existing issues and this hasn't been reported yet
required: true
- label: I've tested with the latest version of Kit
required: false
+11
View File
@@ -0,0 +1,11 @@
blank_issues_enabled: false
contact_links:
- name: Kit Documentation
url: https://github.com/mark3labs/kit/tree/main/www/pages
about: Check the documentation before filing an issue
- name: Extension Examples
url: https://github.com/mark3labs/kit/tree/main/examples/extensions
about: See working extension examples for reference
- name: Discussions
url: https://github.com/mark3labs/kit/discussions
about: For questions, ideas, or general discussion
+40
View File
@@ -0,0 +1,40 @@
name: Documentation Issue
description: Report missing, incorrect, or unclear documentation
title: "docs: "
labels: ["documentation"]
body:
- type: textarea
id: description
attributes:
label: Documentation Issue
description: What's wrong or missing in the documentation?
placeholder: |
The ToolRenderConfig documentation mentions BorderColor and Background fields,
but the code doesn't actually use them. The docs should either be updated
to reflect reality, or the bug should be fixed.
validations:
required: true
- type: input
id: location
attributes:
label: Documentation Location
description: Where is the affected documentation?
placeholder: e.g., README.md, examples/extensions/tool-renderer-demo.go, pkg/kit docs
- type: textarea
id: suggestion
attributes:
label: Suggested Improvement
description: How should the documentation be improved?
placeholder: |
Add a note that BorderColor and Background are not yet implemented,
or fix the bug and document the correct behavior.
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I've checked that this documentation issue still exists in the latest version
required: true
@@ -0,0 +1,64 @@
name: Feature Request
description: Suggest a new feature or enhancement for Kit
title: "feat: "
labels: ["enhancement"]
body:
- type: textarea
id: description
attributes:
label: Feature Description
description: What would you like to see added or changed?
placeholder: |
I'd like to be able to customize the border color of tool result blocks
dynamically based on the tool type or result status.
validations:
required: true
- type: textarea
id: motivation
attributes:
label: Motivation / Use Case
description: Why is this feature needed? What problem does it solve?
placeholder: |
When running multiple tools in sequence, it's hard to visually distinguish
between file reads (blue), shell commands (green), and errors (red)
without custom border colors.
validations:
required: true
- type: textarea
id: proposed
attributes:
label: Proposed Implementation
description: How do you think this should work? (optional)
placeholder: |
Extend `ToolRenderConfig` to accept a function that receives the tool
result and returns a color based on the content:
```go
BorderColorFunc: func(result string, isError bool) string {
if isError {
return "#f38ba8"
}
return "#89b4fa"
}
```
render: go
- type: checkboxes
id: alternatives
attributes:
label: Alternatives Considered
options:
- label: I've considered workarounds or alternative approaches
required: false
- type: checkboxes
id: terms
attributes:
label: Checklist
options:
- label: I've searched existing issues and this hasn't been requested yet
required: true
- label: This feature aligns with Kit's design philosophy (TUI-first, extension-based)
required: false
+74 -34
View File
@@ -28,11 +28,15 @@ type lintResult struct {
Err error
}
// Package-level state: set of .go files edited during the current agent turn.
var editedFiles map[string]bool
func Init(api ext.API) {
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint on Go file edits")
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
})
// Track edited .go files — don't lint yet.
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
if e.IsError || !isEditOrWrite(e.ToolName) {
return nil
@@ -43,30 +47,72 @@ func Init(api ext.API) {
return nil
}
report := runGoDiagnostics(ctx.CWD, absPath)
// Check if there are issues and add explicit prompt for the LLM to react
goplsIssues, lintIssues := countIssues(report)
hasIssues := goplsIssues > 0 || lintIssues > 0
var enhanced string
if hasIssues {
enhanced = e.Content + "\n\n" + report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review the issues above and fix them before proceeding."
} else {
enhanced = e.Content + "\n\n" + report
if editedFiles == nil {
editedFiles = make(map[string]bool)
}
editedFiles[absPath] = true
return nil
})
// After the agent turn ends, lint all collected files.
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
if len(editedFiles) == 0 {
return
}
// Show TUI message block for diagnostics visibility (only if there are issues)
// Snapshot and reset immediately so the next turn starts clean.
files := editedFiles
editedFiles = nil
// Skip lint on errored turns.
if e.StopReason == "error" {
return
}
// Collect unique directories and file list for gopls.
var allGoplsOutput []string
for absPath := range files {
res := runGopls(ctx.CWD, absPath)
formatted := formatToolResult(res, "")
if formatted != "" {
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
}
}
lintRes := runGolangCILint(ctx.CWD, "./...")
goplsSection := "No diagnostics."
if len(allGoplsOutput) > 0 {
goplsSection = strings.Join(allGoplsOutput, "\n\n")
}
lintSection := formatToolResult(lintRes, "No lint issues.")
// Build file list for the report header.
var fileNames []string
for absPath := range files {
fileNames = append(fileNames, filepath.Base(absPath))
}
report := fmt.Sprintf(
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
strings.Join(fileNames, ", "),
goplsSection,
lintSection,
)
goplsIssues, lintIssues := countIssues(report)
hasIssues := goplsIssues > 0 || lintIssues > 0
if hasIssues {
// Show TUI block so the user sees it too.
var msgLines []string
msgLines = append(msgLines, fmt.Sprintf("File: %s", filepath.Base(absPath)))
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
if goplsIssues > 0 {
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
}
if lintIssues > 0 {
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
}
msgLines = append(msgLines, "", "⚠️ Please fix these issues before proceeding.")
borderColor := "#f9e2af" // yellow
if goplsIssues > 0 && lintIssues > 0 {
@@ -78,9 +124,16 @@ func Init(api ext.API) {
BorderColor: borderColor,
Subtitle: "go-edit-lint",
})
}
return &ext.ToolResultResult{Content: &enhanced}
// Inject a follow-up message so the agent fixes the issues.
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
} else {
ctx.PrintBlock(ext.PrintBlockOpts{
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
BorderColor: "#a6e3a1",
Subtitle: "go-edit-lint",
})
}
})
}
@@ -106,18 +159,6 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
return absPath, true
}
func runGoDiagnostics(cwd, absPath string) string {
gopls := runGopls(cwd, absPath)
lint := runGolangCILint(cwd, "./...")
return fmt.Sprintf(
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
filepath.Base(absPath),
formatToolResult(gopls, "No diagnostics."),
formatToolResult(lint, "No lint issues."),
)
}
func runGopls(cwd, absPath string) lintResult {
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
defer cancel()
@@ -178,7 +219,9 @@ func formatToolResult(res lintResult, emptyFallback string) string {
out := strings.TrimSpace(res.Output)
if out == "" {
if res.Err == nil {
lines = append(lines, emptyFallback)
if emptyFallback != "" {
lines = append(lines, emptyFallback)
}
}
} else {
lines = append(lines, out)
@@ -197,17 +240,15 @@ func truncate(s string, max int) string {
}
func countIssues(report string) (goplsCount, lintCount int) {
// Extract gopls section
goplsStart := strings.Index(report, "[gopls]")
lintStart := strings.Index(report, "[golangci-lint]")
endTag := strings.Index(report, "</go_diagnostics>")
if goplsStart != -1 && lintStart != -1 {
goplsSection := report[goplsStart:lintStart]
// Count non-empty lines excluding the header and "No diagnostics." message
for _, line := range strings.Split(goplsSection, "\n") {
line = strings.TrimSpace(line)
if line != "" && line != "[gopls]" && line != "No diagnostics." {
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
goplsCount++
}
}
@@ -215,7 +256,6 @@ func countIssues(report string) (goplsCount, lintCount int) {
if lintStart != -1 && endTag != -1 {
lintSection := report[lintStart:endTag]
// Count non-empty lines excluding the header and "No lint issues." message
for _, line := range strings.Split(lintSection, "\n") {
line = strings.TrimSpace(line)
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
+86
View File
@@ -0,0 +1,86 @@
---
description: Create a feature request using the GitHub template
---
Create a feature request for the Kit repository. The user wants to request: $@
## Feature Request Template
This prompt uses the `feature_request` GitHub template which requires:
| Field | Required | Purpose |
|-------|----------|---------|
| **Feature Description** | Yes | What should be added or changed |
| **Motivation / Use Case** | Yes | Why is this needed? What problem does it solve? |
| **Proposed Implementation** | No | How do you think this should work? |
## Steps
1. **Understand the request** from `$@`
- What capability is missing?
- What would the ideal behavior look like?
2. **Ask clarifying questions** if needed:
- "What problem does this solve for you?"
- "How would you expect this to work?"
- "Are there similar features in other tools you use?"
3. **Craft the title** using conventional format:
- `feat: <short description>`
- Lowercase, imperative mood, ≤72 chars
- Good examples:
- `feat: add keyboard shortcut for clearing input`
- `feat: support custom themes per extension`
- `feat: add fuzzy matching to model selector`
- Bad examples:
- `Feature request: can we have...` (too vague)
- `It would be nice if...` (not imperative)
4. **Build the body** with the template fields:
**Feature Description:**
- Clear statement of what to add/change
- Be specific about the behavior
- Include UI/UX details if relevant
**Motivation / Use Case:**
- What problem does this solve?
- Current workaround (if any) and why it's insufficient
- Who benefits from this feature?
**Proposed Implementation** (optional but helpful):
- High-level approach
- API changes if applicable
- Example usage code
5. **Create the issue**:
```bash
gh issue create --template feature_request --title "feat: ..." --body "..."
```
6. **Confirm success**:
- Show the issue URL and number
- Mention it was created with the feature_request template
## Guidelines
- Focus on the *problem* first, then the solution
- Include concrete examples of how the feature would be used
- Consider edge cases and mention them
- If proposing API changes, show before/after code
- Check if similar features exist in related tools (mention them for reference)
- Align with Kit's philosophy: TUI-first, extension-based, keyboard-driven
## Example
User: `/feature-request I want to be able to customize tool border colors dynamically`
You:
1. Title: `feat: dynamic border colors for tool results based on status`
2. Body:
- **Feature Description**: Allow `ToolRenderConfig` to accept a function that determines border color based on tool result content or status, enabling dynamic visual feedback.
- **Motivation**: When running multiple tools, it's hard to distinguish file reads (blue), shell commands (green), and errors (red) without custom colors per result.
- **Proposed Implementation**: Add `BorderColorFunc` callback that receives `(result string, isError bool)` and returns a color string.
3. Execute: `gh issue create --template feature_request --title "feat: ..." --body "..."`
4. Confirm: Created issue #43 using feature_request template
+100
View File
@@ -0,0 +1,100 @@
---
description: File a GitHub issue using the appropriate template
---
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
## Issue Templates Available
This repository has structured issue templates. You MUST use the appropriate template:
| Type | Template | Use For |
|------|----------|---------|
| `bug` | `bug_report` | Something is broken, not working as expected |
| `feat` | `feature_request` | New feature, enhancement, improvement |
| `docs` | `documentation` | Missing, incorrect, or unclear documentation |
## Steps
1. **Determine the issue type** from `$@`:
- Bug → use `--template bug_report`
- Feature → use `--template feature_request`
- Documentation → use `--template documentation`
2. **Ask clarifying questions** if critical info is missing:
- For bugs: "What were you doing when this happened?" (reproduction steps)
- For features: "What problem does this solve?" (motivation)
- For docs: "Where did you look for this information?" (location)
3. **Craft the title** using conventional format:
- `<type>: <short description>`
- Lowercase, imperative mood, ≤72 chars
- Examples:
- `fix: ToolRenderConfig BorderColor ignored during rendering`
- `feat: add keyboard shortcut for clearing input`
- `docs: clarify extension widget lifecycle`
4. **File the issue** using the template:
```bash
# For bugs
gh issue create --template bug_report --title "fix: ..." --body "..."
# For features
gh issue create --template feature_request --title "feat: ..." --body "..."
# For documentation
gh issue create --template documentation --title "docs: ..." --body "..."
```
The template will guide the user through the required fields. You need to provide:
- **Bug reports**: Description, reproduction steps, expected vs actual behavior
- **Feature requests**: Description, motivation/use case, optional proposed implementation
- **Documentation**: Description, location of docs, suggested improvement
5. **Confirm success** by showing:
- The issue URL
- The issue number
- Which template was used
## Template Field Guide
### Bug Report (`bug_report`)
Required fields in the body:
- **Bug Description** - what happened vs expected
- **Steps to Reproduce** - numbered list to recreate the bug
- **Relevant Code** - code snippets, configuration, error messages
- **Component** - which part of Kit (ui, extensions, session, etc.)
- **Version** - Kit version or commit hash
### Feature Request (`feature_request`)
Required fields in the body:
- **Feature Description** - what to add/change
- **Motivation / Use Case** - why this is needed
- **Proposed Implementation** - how it could work (optional)
### Documentation (`documentation`)
Required fields in the body:
- **Documentation Issue** - what's wrong or missing
- **Documentation Location** - file or URL where docs exist
- **Suggested Improvement** - how to fix the docs
## Guidelines
- ALWAYS use `--template <name>` instead of bare `gh issue create`
- Include file paths and line numbers when you know them
- Use triple backticks for code blocks
- Keep the body factual - avoid speculation unless in "Proposed Fix" section
- If you're unsure about technical details, say so in the issue
- For UI bugs, describe what you see vs what you expect
- For API bugs, include the relevant struct/function names
## Example Usage
User: `/file-issue The ToolRenderConfig BorderColor field is documented but never used in rendering`
You:
1. Determine this is a **bug** (documented field doesn't work)
2. Use `--template bug_report`
3. Gather: reproduction steps (register renderer with BorderColor), expected (custom color), actual (default color)
4. Create issue with title `fix: ToolRenderConfig BorderColor and Background fields are ignored`
5. Confirm: Created issue #42 using bug_report template
+80
View File
@@ -0,0 +1,80 @@
# Autoscroll Fix - Final Summary
## Root Cause
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
### The Problem
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
2. The `Height()` method returns `0` when `cachedRender == ""`
3. `GotoBottom()` was calling:
```go
itemHeight := item.Height() // Returns 0 for reasoning
if itemHeight == 0 {
item.Render(s.width) // Renders but doesn't cache (reasoning)
itemHeight = item.Height() // Still returns 0!
}
```
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
## The Solution
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
```go
// OLD: item.Height() which checks cached render
itemHeight := item.Height()
if itemHeight == 0 {
item.Render(s.width)
itemHeight = item.Height() // Still might be 0!
}
// NEW: Calculate from rendered string directly
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
```
This works for **all** items regardless of whether they cache their render or not.
## Files Changed
### `internal/ui/scrolllist.go`
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
### `internal/ui/model.go`
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
## Testing Results
✅ **Test prompt**: "explore this repo"
**Before fix**:
- Autoscroll stopped after reasoning block completed
- Viewport stuck showing end of reasoning ("Thought for 203ms")
- Assistant response streamed off-screen below
**After fix**:
- Autoscroll works throughout reasoning block
- Autoscroll continues during reasoning → assistant transition
- Viewport stays at bottom showing latest assistant content
- Final position shows end of response (build commands section)
## Behavior Verified
1. ✅ Streaming text auto-scrolls to bottom
2. ✅ Works across reasoning → assistant transition
3. ✅ Manual scroll up (PgUp) disables autoscroll
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
5. ✅ Accurate positioning with no offset errors
## Performance Note
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
- `Render()` is already optimized with caching for non-reasoning items
- `GotoBottom()` is only called during content updates (not every frame)
- Reasoning blocks need to render anyway for live duration updates
- This matches iteratr's approach of ensuring items are rendered before height calculations
+29 -2
View File
@@ -477,7 +477,7 @@ During an interactive session, use these slash commands:
| `/import <path>` | Import and switch to a session from a JSONL file |
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
| `/tree` | Navigate the session tree |
| `/fork` | Branch from an earlier message |
| `/fork` | Fork to new session from an earlier message |
| `/new` | Start a fresh session |
## Go SDK
@@ -531,7 +531,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
Tools: []kit.Tool{...}, // Replace default tool set entirely
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -540,6 +545,28 @@ host, err := kit.New(ctx, &kit.Options{
})
```
### Custom Tools
Create custom tools with automatic schema generation — no external dependencies needed:
```go
type SearchInput struct {
Query string `json:"query" description:"Search query"`
}
searchTool := kit.NewTool("search", "Search the codebase",
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
return kit.TextResult("Found: ..."), nil
},
)
host, _ := kit.New(ctx, &kit.Options{
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
})
```
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
### With Callbacks
```go
+312 -94
View File
@@ -7,10 +7,10 @@ import (
"image/color"
"log"
"os"
"path/filepath"
"strings"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/config"
@@ -18,6 +18,8 @@ import (
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/prompts"
"github.com/mark3labs/kit/internal/ui"
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/watcher"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -48,12 +50,14 @@ var (
noSessionFlag bool // --no-session: ephemeral mode, no persistence
// Model generation parameters
maxTokens int
temperature float32
topP float32
topK int32
stopSequences []string
thinkingLevel string
maxTokens int
temperature float32
topP float32
topK int32
frequencyPenalty float32
presencePenalty float32
stopSequences []string
thinkingLevel string
// Ollama-specific parameters
numGPU int32
@@ -154,6 +158,9 @@ func InitConfig() {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
// Rebuild the model registry now that viper has the config loaded,
// so customModels defined in the config file are picked up.
models.ReloadGlobalRegistry()
}
// LoadConfigWithEnvSubstitution loads a config file with environment variable
@@ -217,29 +224,10 @@ func configToUiTheme(cfg config.Theme) ui.Theme {
}
}
// kitBanner returns the KIT ASCII art title with KITT scanner lights,
// rendered with a KITT red gradient.
// kitBanner returns the KIT ASCII art title with KITT scanner lights.
// Delegates to ui.KitBanner() which owns the logo rendering.
func kitBanner() string {
kittDark := lipgloss.Color("#8B0000")
kittBright := lipgloss.Color("#FF2200")
lines := []string{
" ██╗ ██╗ ██╗ ████████╗",
" ██║ ██╔╝ ██║ ╚══██╔══╝",
" █████╔╝ ██║ ██║",
" ██╔═██╗ ██║ ██║",
" ██║ ██╗ ██║ ██║",
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
" ░░░░░░▒▒▒▒▒▓▓▓▓███████████████▓▓▓▓▒▒▒▒▒░░░░░░",
}
var result strings.Builder
for i, line := range lines {
if i > 0 {
result.WriteString("\n")
}
result.WriteString(ui.ApplyGradient(line, kittDark, kittBright))
}
return result.String()
return ui.KitBanner()
}
func init() {
@@ -307,6 +295,8 @@ func init() {
flags.Float32Var(&temperature, "temperature", 0.7, "controls randomness in responses (0.0-1.0)")
flags.Float32Var(&topP, "top-p", 0.95, "controls diversity via nucleus sampling (0.0-1.0)")
flags.Int32Var(&topK, "top-k", 40, "controls diversity by limiting top K tokens to sample from")
flags.Float32Var(&frequencyPenalty, "frequency-penalty", 0.0, "penalizes tokens based on frequency of appearance (0.0-2.0)")
flags.Float32Var(&presencePenalty, "presence-penalty", 0.0, "penalizes tokens based on whether they have appeared (0.0-2.0)")
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, minimal, low, medium, high")
@@ -329,6 +319,8 @@ func init() {
_ = viper.BindPFlag("temperature", rootCmd.PersistentFlags().Lookup("temperature"))
_ = viper.BindPFlag("top-p", rootCmd.PersistentFlags().Lookup("top-p"))
_ = viper.BindPFlag("top-k", rootCmd.PersistentFlags().Lookup("top-k"))
_ = viper.BindPFlag("frequency-penalty", rootCmd.PersistentFlags().Lookup("frequency-penalty"))
_ = viper.BindPFlag("presence-penalty", rootCmd.PersistentFlags().Lookup("presence-penalty"))
_ = viper.BindPFlag("stop-sequences", rootCmd.PersistentFlags().Lookup("stop-sequences"))
_ = viper.BindPFlag("thinking-level", rootCmd.PersistentFlags().Lookup("thinking-level"))
_ = viper.BindPFlag("num-gpu-layers", rootCmd.PersistentFlags().Lookup("num-gpu-layers"))
@@ -406,21 +398,21 @@ func runKit(ctx context.Context) error {
}
// extensionCommandsForUI converts extension-registered CommandDefs into the
// ui.ExtensionCommand type used by the interactive TUI. Command names are
// commands.ExtensionCommand type used by the interactive TUI. Command names are
// normalised to start with "/" so they integrate with the slash-command
// autocomplete and dispatch pipeline.
func extensionCommandsForUI(k *kit.Kit) []ui.ExtensionCommand {
func extensionCommandsForUI(k *kit.Kit) []commands.ExtensionCommand {
defs := k.Extensions().Commands()
if len(defs) == 0 {
return nil
}
cmds := make([]ui.ExtensionCommand, 0, len(defs))
cmds := make([]commands.ExtensionCommand, 0, len(defs))
for _, d := range defs {
name := d.Name
if len(name) > 0 && name[0] != '/' {
name = "/" + name
}
ec := ui.ExtensionCommand{
ec := commands.ExtensionCommand{
Name: name,
Description: d.Description,
Execute: func(args string) (string, error) {
@@ -733,13 +725,33 @@ func runNormalMode(ctx context.Context) error {
// Build Kit options from CLI flags and create the SDK instance.
// kit.New() handles: config → skills → agent → session → extension bridge.
authHandler, authErr := kit.NewCLIMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote MCP servers.
fmt.Fprintf(os.Stderr, "Warning: Failed to create OAuth handler: %v\n", authErr)
}
// appInstancePtr is used to break the circular dependency between
// kit.New (which needs the OnMCPServerLoaded callback) and app.New
// (which is needed by the callback to send events to the TUI).
var appInstancePtr *app.App
kitOpts := &kit.Options{
Quiet: quietFlag,
Debug: debugMode,
NoSession: noSessionFlag,
Continue: continueFlag,
SessionPath: sessionPath,
AutoCompact: autoCompactFlag,
Quiet: quietFlag,
Debug: debugMode,
NoSession: noSessionFlag,
Continue: continueFlag,
SessionPath: sessionPath,
AutoCompact: autoCompactFlag,
MCPAuthHandler: authHandler,
// This callback is called when each MCP server finishes loading.
// We use a closure that captures appInstancePtr which is set after
// app.New() is called below.
OnMCPServerLoaded: func(serverName string, toolCount int, err error) {
if appInstancePtr != nil {
appInstancePtr.NotifyMCPServerLoaded(serverName, toolCount, err)
}
},
CLI: &kit.CLIOptions{
MCPConfig: mcpConfig,
ShowSpinner: true,
@@ -810,8 +822,16 @@ func runNormalMode(ctx context.Context) error {
}
appInstance := app.New(appOpts, messages)
appInstancePtr = appInstance // Wire up the MCP server loaded callback.
defer appInstance.Close()
// Wire OAuth handler to route messages through the TUI once it's running.
if authHandler != nil {
authHandler.NotifyFunc = func(serverName, message string) {
appInstance.PrintFromExtension("info", message)
}
}
// Buffer for extension messages during startup (printed after startup banner).
var startupExtensionMessages []string
@@ -835,7 +855,37 @@ func runNormalMode(ctx context.Context) error {
PrintBlock: appInstance.PrintBlockFromExtension,
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Exit: func() { appInstance.QuitFromExtension() },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
@@ -1256,7 +1306,37 @@ func runNormalMode(ctx context.Context) error {
PrintBlock: appInstance.PrintBlockFromExtension,
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Exit: func() { appInstance.QuitFromExtension() },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
@@ -1556,6 +1636,49 @@ func runNormalMode(ctx context.Context) error {
})
}
// Build prompt template and skill item provider callbacks for hot-reload.
// These are called by the TUI when ContentReloadEvent fires.
getPromptTemplates := func() []*prompts.PromptTemplate {
if noPromptTemplates {
return nil
}
homeDir, _ := os.UserHomeDir()
cwd, _ := os.Getwd()
tpls, _, err := prompts.LoadAll(prompts.LoadOptions{
Cwd: cwd,
HomeDir: homeDir,
ExtraPaths: promptTemplatePaths,
ConfigPaths: viper.GetStringSlice("prompts"),
IncludeDefaults: true,
})
if err != nil {
log.Printf("Warning: failed to reload prompt templates: %v", err)
}
return tpls
}
getSkillItems := func() []ui.SkillItem {
// Re-discover skills from disk.
if err := kitInstance.ReloadSkills(); err != nil {
log.Printf("Warning: failed to reload skills: %v", err)
return nil
}
cwd, _ := os.Getwd()
var items []ui.SkillItem
for _, s := range kitInstance.GetSkills() {
source := "user"
if strings.HasPrefix(s.Path, cwd) {
source = "project"
}
items = append(items, ui.SkillItem{
Name: s.Name,
Path: s.Path,
Source: source,
})
}
return items
}
// Build extension UI providers once (shared between both modes).
getWidgets := widgetProviderForUI(kitInstance)
getHeader := headerProviderForUI(kitInstance)
@@ -1567,10 +1690,29 @@ func runNormalMode(ctx context.Context) error {
emitBeforeFork := beforeForkProviderForUI(kitInstance)
emitBeforeSessionSwitch := beforeSessionSwitchProviderForUI(kitInstance)
getGlobalShortcuts := globalShortcutsProviderForUI(kitInstance)
getExtensionCommands := func() []ui.ExtensionCommand {
getExtensionCommands := func() []commands.ExtensionCommand {
return extensionCommandsForUI(kitInstance)
}
// Build dynamic tool name and MCP tool count providers. These are called
// by the TUI when MCPToolsReadyEvent fires to refresh the /tools list
// and startup info bar after background MCP tool loading completes.
getToolNames := func() []string {
return kitInstance.GetToolNames()
}
getMCPToolCount := func() int {
return kitInstance.GetMCPToolCount()
}
// Start a goroutine that waits for background MCP tool loading to
// complete and notifies the TUI so it can refresh tool names and counts.
if len(mcpConfig.MCPServers) > 0 {
go func() {
_ = kitInstance.WaitForMCPTools()
appInstance.NotifyMCPToolsReady()
}()
}
// Build model switching callbacks for the /model command.
setModelForUI := func(modelString string) error {
err := kitInstance.SetModel(context.Background(), modelString)
@@ -1624,9 +1766,81 @@ func runNormalMode(ctx context.Context) error {
return nil
}
// Build extension reload callback for the /reload-ext command.
reloadExtensionsForUI := func() error {
err := kitInstance.Extensions().Reload()
if err != nil {
return err
}
go appInstance.NotifyWidgetUpdate()
return nil
}
// Start file watcher for automatic extension hot-reload.
extraPaths := viper.GetStringSlice("extension")
watchDirs := extensions.WatchedDirs(extraPaths)
if len(watchDirs) > 0 {
extWatcher, watchErr := extensions.NewWatcher(watchDirs, func() {
if err := reloadExtensionsForUI(); err != nil {
log.Printf("auto-reload extensions failed: %v", err)
}
})
if watchErr != nil {
log.Printf("extension file watcher not started: %v", watchErr)
} else {
go extWatcher.Start(ctx)
defer func() { _ = extWatcher.Close() }()
}
}
// Start file watchers for automatic prompt and skill hot-reload.
{
homeDir, _ := os.UserHomeDir()
cwd, _ := os.Getwd()
// Collect prompt template directories.
promptDirs := watcher.CollectDirs(
[]string{
filepath.Join(homeDir, ".kit", "prompts"),
filepath.Join(cwd, ".kit", "prompts"),
},
append(promptTemplatePaths, viper.GetStringSlice("prompts")...),
)
// Collect skill directories.
skillDirs := watcher.CollectDirs(
[]string{
filepath.Join(homeDir, ".config", "kit", "skills"),
filepath.Join(cwd, ".agents", "skills"),
filepath.Join(cwd, ".kit", "skills"),
},
nil,
)
// Combine all content directories and start a single watcher.
allContentDirs := append(promptDirs, skillDirs...)
if len(allContentDirs) > 0 {
contentWatcher, watchErr := watcher.New(watcher.Options{
Dirs: allContentDirs,
Extensions: []string{".md", ".txt"},
Label: "prompts/skills",
OnReload: func() {
log.Printf("auto-reloading prompts and skills")
appInstance.NotifyContentReload()
},
})
if watchErr != nil {
log.Printf("content file watcher not started: %v", watchErr)
} else {
go contentWatcher.Start(ctx)
defer func() { _ = contentWatcher.Close() }()
}
}
}
// Check if running in non-interactive mode
if positionalPrompt != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI)
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
}
// Quiet mode is not allowed in interactive mode
@@ -1634,7 +1848,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet requires a prompt")
}
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, startupExtensionMessages)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
}
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
@@ -1647,7 +1861,7 @@ func runNormalMode(ctx context.Context) error {
//
// When --no-exit is set, after the prompt completes the interactive BubbleTea
// TUI is started so the user can continue the conversation.
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error) error {
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
// Expand @file references in the prompt before sending to the agent.
if cwd, err := os.Getwd(); err == nil {
prompt = ui.ProcessFileAttachments(prompt, cwd)
@@ -1690,7 +1904,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
// If --no-exit was requested, hand off to the interactive TUI.
if noExit {
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, nil)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
}
return nil
@@ -1788,7 +2002,19 @@ func writeJSONError(err error) {
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []ui.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []ui.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, startupExtensionMessages []string) error {
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
// background goroutines (watchers, extension handlers, SDK internals)
// will visually corrupt the terminal.
logDir := filepath.Join(os.TempDir(), "kit")
_ = os.MkdirAll(logDir, 0o700)
logFile, logErr := tea.LogToFile(filepath.Join(logDir, "kit.log"), "kit")
if logErr == nil {
defer func() { _ = logFile.Close() }()
}
// Determine terminal size; fall back gracefully.
termWidth, termHeight, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil || termWidth == 0 {
@@ -1799,55 +2025,47 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
cwd, _ := os.Getwd()
appModel := ui.NewAppModel(appInstance, ui.AppModelOptions{
ModelName: modelName,
ProviderName: providerName,
LoadingMessage: loadingMessage,
Cwd: cwd,
Width: termWidth,
Height: termHeight,
ServerNames: serverNames,
ToolNames: toolNames,
MCPToolCount: mcpToolCount,
ExtensionToolCount: extensionToolCount,
UsageTracker: usageTracker,
ExtensionCommands: extCommands,
PromptTemplates: promptTemplates,
ContextPaths: contextPaths,
SkillItems: skillItems,
GetWidgets: getWidgets,
GetHeader: getHeader,
GetFooter: getFooter,
GetToolRenderer: getToolRenderer,
GetEditorInterceptor: getEditorInterceptor,
GetUIVisibility: getUIVisibility,
GetStatusBarEntries: getStatusBarEntries,
EmitBeforeFork: emitBeforeFork,
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
GetGlobalShortcuts: getGlobalShortcuts,
GetExtensionCommands: getExtensionCommands,
SetModel: setModel,
EmitModelChange: emitModelChange,
ThinkingLevel: thinkingLevel,
IsReasoningModel: isReasoningModel,
SetThinkingLevel: setThinkingLevel,
SwitchSession: switchSession,
ShowSessionPicker: resumeFlag,
ModelName: modelName,
ProviderName: providerName,
LoadingMessage: loadingMessage,
Cwd: cwd,
Width: termWidth,
Height: termHeight,
ServerNames: serverNames,
ToolNames: toolNames,
GetToolNames: getToolNames,
GetMCPToolCount: getMCPToolCount,
MCPToolCount: mcpToolCount,
ExtensionToolCount: extensionToolCount,
UsageTracker: usageTracker,
ExtensionCommands: extCommands,
PromptTemplates: promptTemplates,
GetPromptTemplates: getPromptTemplates,
ContextPaths: contextPaths,
SkillItems: skillItems,
GetSkillItems: getSkillItems,
StartupExtensionMessages: startupExtensionMessages,
GetWidgets: getWidgets,
GetHeader: getHeader,
GetFooter: getFooter,
GetToolRenderer: getToolRenderer,
GetEditorInterceptor: getEditorInterceptor,
GetUIVisibility: getUIVisibility,
GetStatusBarEntries: getStatusBarEntries,
EmitBeforeFork: emitBeforeFork,
EmitBeforeSessionSwitch: emitBeforeSessionSwitch,
GetGlobalShortcuts: getGlobalShortcuts,
GetExtensionCommands: getExtensionCommands,
SetModel: setModel,
EmitModelChange: emitModelChange,
ThinkingLevel: thinkingLevel,
IsReasoningModel: isReasoningModel,
SetThinkingLevel: setThinkingLevel,
SwitchSession: switchSession,
ReloadExtensions: reloadExtensions,
ShowSessionPicker: resumeFlag,
})
// Print KIT banner and startup info to stdout before Bubble Tea takes over the screen.
fmt.Println(kitBanner())
fmt.Println()
appModel.PrintStartupInfo()
// Print any extension messages that were captured during startup.
if len(startupExtensionMessages) > 0 {
fmt.Println()
for _, msg := range startupExtensionMessages {
fmt.Println(msg)
}
fmt.Println()
}
program := tea.NewProgram(appModel)
// Register the program with the app layer so agent events are sent to the TUI.
+6 -4
View File
@@ -7,10 +7,12 @@
// development: edit your extension source, then type /reload to pick up
// changes immediately.
//
// Event handlers, slash commands, tool renderers, message renderers, and
// keyboard shortcuts update immediately. Extension-defined tools are NOT
// updated (they are baked into the agent at creation time and require a
// restart).
// Note: Extensions in autoloaded directories (~/.config/kit/extensions/
// and .kit/extensions/) are automatically reloaded on save. The /reload
// command is useful for extensions loaded via -e from other locations.
//
// Event handlers, slash commands, tool definitions, tool renderers,
// message renderers, and keyboard shortcuts all update immediately.
//
// Commands:
// /reload — hot-reload all extensions from disk
+74 -1
View File
@@ -168,6 +168,10 @@ var (
// Test
pendingTest *PendingTest
// Typing indicator
typingTicker *time.Ticker
typingStop chan struct{}
// Latest context for background goroutines
latestCtx ext.Context
latestCtxSet bool
@@ -203,8 +207,23 @@ func configDir() string {
return filepath.Join(home, ".config", "kit")
}
func globalConfigDir() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".config", "kit")
}
func configPath() string {
return filepath.Join(configDir(), "kit-telegram.json")
// Prefer project-local config, fall back to global config.
local := filepath.Join(configDir(), "kit-telegram.json")
if _, err := os.Stat(local); err == nil {
return local
}
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
if _, err := os.Stat(global); err == nil {
return global
}
// Neither exists — return local path (will be created on connect).
return local
}
func failureLogDir() string {
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
return &msg, nil
}
func tgSendChatAction(token string, chatID int64, action string) error {
_, err := telegramRequest(token, "sendChatAction", map[string]any{
"chat_id": chatID,
"action": action,
}, 15)
return err
}
// ──────────────────────────────────────────────
// Error classification
// ──────────────────────────────────────────────
@@ -637,6 +664,48 @@ func clearHealthTimer() {
}
}
// ──────────────────────────────────────────────
// Typing indicator
// ──────────────────────────────────────────────
func startTypingLoop() {
mu.Lock()
defer mu.Unlock()
if typingTicker != nil {
return
}
cfg := config
if cfg == nil || !cfg.Enabled {
return
}
token := cfg.BotToken
chatID := cfg.ChatID
typingTicker = time.NewTicker(4 * time.Second)
typingStop = make(chan struct{})
// Send immediately, then every 4 seconds.
go func() {
tgSendChatAction(token, chatID, "typing")
for {
select {
case <-typingTicker.C:
tgSendChatAction(token, chatID, "typing")
case <-typingStop:
return
}
}
}()
}
func stopTypingLoop() {
mu.Lock()
defer mu.Unlock()
if typingTicker != nil {
typingTicker.Stop()
close(typingStop)
typingTicker = nil
}
}
// ──────────────────────────────────────────────
// Polling lifecycle
// ──────────────────────────────────────────────
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
mu.Unlock()
sendShutdownDisconnectedMessage()
stopTypingLoop()
stopPolling()
clearHealthTimer()
clearFooter()
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
mu.Unlock()
report("run.start", fmt.Sprintf("runId=%d", run.ID))
startTypingLoop()
ensureProgressMessage()
updateProgressMessage()
})
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
run := activeRun
mu.Unlock()
stopTypingLoop()
if run != nil {
// Capture final response from event
if e.Response != "" {
+24 -25
View File
@@ -9,13 +9,19 @@ require (
charm.land/huh/v2 v2.0.3
charm.land/lipgloss/v2 v2.0.2
github.com/alecthomas/chroma/v2 v2.23.1
github.com/atotto/clipboard v0.1.4
github.com/aymanbagabas/go-udiff v0.4.1
github.com/charmbracelet/fang v1.0.0
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b
github.com/clipperhouse/displaywidth v0.11.0
github.com/clipperhouse/uax29/v2 v2.7.0
github.com/coder/acp-go-sdk v0.6.3
github.com/indaco/herald v0.10.0
github.com/indaco/herald-md v0.1.0
github.com/mark3labs/mcp-go v0.46.0
github.com/fsnotify/fsnotify v1.9.0
github.com/indaco/herald v0.13.0
github.com/indaco/herald-md v0.3.0
github.com/mark3labs/mcp-go v0.47.0
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/traefik/yaegi v0.16.1
@@ -29,12 +35,11 @@ require (
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/atotto/clipboard v0.1.4 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.13 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
@@ -42,18 +47,16 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
github.com/aws/smithy-go v1.24.2 // indirect
github.com/aws/smithy-go v1.24.3 // indirect
github.com/catppuccin/go v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
github.com/charmbracelet/colorprofile v0.4.3 // indirect
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260330094520-2dce04b6f8a4 // indirect
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
@@ -62,25 +65,21 @@ require (
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.11.0 // indirect
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kaptinlin/go-i18n v0.3.0 // indirect
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
@@ -106,19 +105,19 @@ require (
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
go.opentelemetry.io/otel v1.42.0 // indirect
go.opentelemetry.io/otel/metric v1.42.0 // indirect
go.opentelemetry.io/otel/trace v1.42.0 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/api v0.273.0 // indirect
google.golang.org/genai v1.52.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
google.golang.org/grpc v1.79.3 // indirect
google.golang.org/api v0.274.0 // indirect
google.golang.org/genai v1.52.1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
@@ -130,7 +129,7 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.21 // indirect
github.com/mattn/go-runewidth v0.0.22 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
+42 -42
View File
@@ -18,12 +18,12 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
@@ -38,10 +38,10 @@ github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg=
github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY=
github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk=
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
@@ -56,14 +56,14 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3x
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
@@ -173,18 +173,18 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/indaco/herald v0.10.0 h1:XzahEKX6cr50qZQrUdA3QrQBHg8uGm5jETD0UDi21BI=
github.com/indaco/herald v0.10.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.1.0 h1:zmYudYo+uamzKTBcIffJVJYrqk9xDNnVrTh+de2zciw=
github.com/indaco/herald-md v0.1.0/go.mod h1:Z1HxPCbSn+/+TFzOM/UbsmKeEk/28NNI6JOTileKXto=
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
github.com/kaptinlin/go-i18n v0.3.0 h1:wP76dvYg04bvwTb+8NB+CmdZ2kL7lSSCQ9B/kFv7QHo=
github.com/kaptinlin/go-i18n v0.3.0/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
@@ -201,12 +201,12 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mark3labs/mcp-go v0.47.0 h1:h44yeM3DduDyQgzImYWu4pt6VRkqP/0p/95AGhWngnA=
github.com/mark3labs/mcp-go v0.47.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-runewidth v0.0.22 h1:76lXsPn6FyHtTY+jt2fTTvsMUCZq1k0qwRsAMuxzKAk=
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -276,16 +276,16 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.6
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
@@ -307,20 +307,20 @@ golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/genai v1.52.0 h1:ekVIxWHtLUNbt+v0WWi4j3JT4yrHDEbysMcHQcaCQoI=
google.golang.org/genai v1.52.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA=
google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+248 -108
View File
@@ -25,11 +25,21 @@ type AgentConfig struct {
StreamingEnabled bool
DebugLogger tools.DebugLogger
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
// If nil, remote MCP servers that require OAuth will fail to connect.
AuthHandler tools.MCPAuthHandler
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. This allows SDK users to provide a custom tool set (e.g.
// CodingTools or tools with a custom WorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper is an optional function that wraps the combined tool list
// before it is passed to the LLM agent. Used by the extensions system
// to intercept tool calls/results.
@@ -38,6 +48,11 @@ type AgentConfig struct {
// ExtraTools are additional tools to include alongside core and MCP tools.
// Used by extensions to register custom tools.
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). The callback receives the server
// name, tool count, and any error. Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// ToolCallHandler is a function type for handling tool calls as they happen.
@@ -63,6 +78,10 @@ type ToolCallContentHandler func(content string)
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
type ReasoningDeltaHandler func(delta string)
// ReasoningCompleteHandler is a function type for handling reasoning/thinking completion.
// Called when the last reasoning token has been processed, before text streaming starts.
type ReasoningCompleteHandler func()
// ToolOutputHandler is a function type for handling streaming tool output chunks.
// Used by tools like bash to stream output as it arrives rather than waiting
// for the command to complete. The isStderr flag indicates if the chunk
@@ -79,6 +98,10 @@ type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCrea
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
// AgentTool implementations — no MCP layer, no serialization overhead.
// Additional tools from external MCP servers can be loaded alongside core tools.
//
// When MCP servers are configured, tool loading happens in the background so the
// agent (and UI) can start immediately. The first LLM call automatically waits
// for MCP tools to finish loading before proceeding.
type Agent struct {
toolManager *tools.MCPToolManager
fantasyAgent fantasy.Agent
@@ -92,6 +115,18 @@ type Agent struct {
coreTools []fantasy.AgentTool
extraTools []fantasy.AgentTool
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool // stored for SetModel rebuild
// providerOptions and modelConfig are stored for rebuilding the fantasy
// agent when MCP tools arrive asynchronously or on SetModel.
providerOptions fantasy.ProviderOptions
skipMaxOutputTokens bool
modelConfig *models.ProviderConfig
// mcpReady is closed when background MCP tool loading completes (success
// or failure). nil when no MCP servers are configured.
mcpReady chan struct{}
// mcpErr holds any error from background MCP loading.
mcpErr error
}
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
@@ -110,7 +145,10 @@ type GenerateWithLoopResult struct {
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
// External MCP tools are loaded from the config if any MCP servers are configured.
// If MCP servers are configured, their tools are loaded in the background —
// the agent returns immediately and is usable with core tools only. The first
// LLM call (GenerateWithLoop) automatically waits for MCP tools to finish
// loading and rebuilds the agent with the full tool set.
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Create the LLM provider
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
@@ -120,34 +158,22 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
// Register core tools (direct AgentTool implementations, no MCP overhead).
// Use caller-provided tools if set, otherwise default to all core tools.
coreTools := agentConfig.CoreTools
if len(coreTools) == 0 {
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
var coreTools []fantasy.AgentTool
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
// Explicitly zero tools - chat-only mode
coreTools = nil
} else if len(agentConfig.CoreTools) > 0 {
// Custom tools provided - use them
coreTools = agentConfig.CoreTools
} else {
// Default: load all core tools
coreTools = core.AllTools()
}
// Build the combined tool list: core tools + any external MCP tools
// Build the initial tool list: core tools + extension tools (no MCP yet).
allTools := make([]fantasy.AgentTool, len(coreTools))
copy(allTools, coreTools)
// Load external MCP tools if configured
var toolManager *tools.MCPToolManager
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
toolManager = tools.NewMCPToolManager()
toolManager.SetModel(providerResult.Model)
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
// MCP tool loading failures are non-fatal; core tools still work
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
} else {
mcpTools := toolManager.GetTools()
allTools = append(allTools, mcpTools...)
}
}
// Append any extra tools provided by extensions.
if len(agentConfig.ExtraTools) > 0 {
allTools = append(allTools, agentConfig.ExtraTools...)
@@ -159,6 +185,144 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
}
// Build agent options
agentOpts := buildAgentOptions(agentConfig, providerResult, allTools)
// Create the agent
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Determine provider type from model string
providerType := "default"
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
providerType = p
}
}
a := &Agent{
fantasyAgent: fantasyAgent,
model: providerResult.Model,
providerCloser: providerResult.Closer,
maxSteps: agentConfig.MaxSteps,
systemPrompt: agentConfig.SystemPrompt,
loadingMessage: providerResult.Message,
providerType: providerType,
streamingEnabled: agentConfig.StreamingEnabled,
coreTools: coreTools,
extraTools: agentConfig.ExtraTools,
toolWrapper: agentConfig.ToolWrapper,
providerOptions: providerResult.ProviderOptions,
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
modelConfig: agentConfig.ModelConfig,
}
// Start MCP tool loading in the background if servers are configured.
// The mcpReady channel is closed when loading completes (success or failure).
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
toolManager := tools.NewMCPToolManager()
toolManager.SetModel(providerResult.Model)
if agentConfig.AuthHandler != nil {
toolManager.SetAuthHandler(agentConfig.AuthHandler)
}
if agentConfig.DebugLogger != nil {
toolManager.SetDebugLogger(agentConfig.DebugLogger)
}
// Set per-server loaded callback if provided.
if agentConfig.OnMCPServerLoaded != nil {
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
}
a.toolManager = toolManager
a.mcpReady = make(chan struct{})
go func() {
defer close(a.mcpReady)
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
a.mcpErr = err
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
}
}()
}
return a, nil
}
// WaitForMCPTools blocks until background MCP tool loading completes.
// Returns nil if no MCP servers are configured or if loading succeeded.
// Returns the loading error if all servers failed. Safe to call multiple times.
func (a *Agent) WaitForMCPTools() error {
if a.mcpReady == nil {
return nil
}
<-a.mcpReady
return a.mcpErr
}
// MCPToolsReady returns true if MCP tool loading has completed (or was never
// started). This is a non-blocking check useful for UI status display.
func (a *Agent) MCPToolsReady() bool {
if a.mcpReady == nil {
return true
}
select {
case <-a.mcpReady:
return true
default:
return false
}
}
// ensureMCPTools waits for MCP tools to load and rebuilds the fantasy agent
// with the full tool set. Called lazily before the first LLM call.
// This is idempotent — subsequent calls after the first rebuild are no-ops.
func (a *Agent) ensureMCPTools() {
if a.mcpReady == nil {
return
}
<-a.mcpReady
// If there are MCP tools, rebuild the fantasy agent to include them.
if a.toolManager != nil && len(a.toolManager.GetTools()) > 0 {
a.rebuildFantasyAgent()
}
// Nil out the channel so future calls are instant no-ops and we
// don't rebuild again.
a.mcpReady = nil
}
// rebuildFantasyAgent reconstructs the fantasy agent with the current full
// tool set (core + MCP + extension tools). Used after MCP tools arrive
// asynchronously and by SetModel.
func (a *Agent) rebuildFantasyAgent() {
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
providerResult := &models.ProviderResult{
Model: a.model,
ProviderOptions: a.providerOptions,
SkipMaxOutputTokens: a.skipMaxOutputTokens,
}
agentOpts := buildAgentOptions(&AgentConfig{
ModelConfig: a.modelConfig,
SystemPrompt: a.systemPrompt,
MaxSteps: a.maxSteps,
}, providerResult, allTools)
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
}
// buildAgentOptions constructs the fantasy.AgentOption slice from config,
// provider result, and the combined tool list. Shared by NewAgent,
// rebuildFantasyAgent, and SetModel.
func buildAgentOptions(agentConfig *AgentConfig, providerResult *models.ProviderResult, allTools []fantasy.AgentTool) []fantasy.AgentOption {
var agentOpts []fantasy.AgentOption
if agentConfig.SystemPrompt != "" {
@@ -196,33 +360,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.ModelConfig.TopK != nil {
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
}
}
// Create the agent
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Determine provider type from model string
providerType := "default"
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
providerType = p
if agentConfig.ModelConfig.FrequencyPenalty != nil {
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
}
if agentConfig.ModelConfig.PresencePenalty != nil {
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
}
}
return &Agent{
toolManager: toolManager,
fantasyAgent: fantasyAgent,
model: providerResult.Model,
providerCloser: providerResult.Closer,
maxSteps: agentConfig.MaxSteps,
systemPrompt: agentConfig.SystemPrompt,
loadingMessage: providerResult.Message,
providerType: providerType,
streamingEnabled: agentConfig.StreamingEnabled,
coreTools: coreTools,
extraTools: agentConfig.ExtraTools,
toolWrapper: agentConfig.ToolWrapper,
}, nil
return agentOpts
}
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
@@ -231,7 +377,7 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
onResponse, onToolCallContent, nil, nil, nil, nil)
onResponse, onToolCallContent, nil, nil, nil, nil, nil)
}
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
@@ -242,10 +388,16 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
onStreamingResponse StreamingResponseHandler,
onReasoningDelta ReasoningDeltaHandler,
onReasoningComplete ReasoningCompleteHandler,
onToolOutput ToolOutputHandler,
onStepUsage StepUsageHandler,
) (*GenerateWithLoopResult, error) {
// Wait for background MCP tool loading to complete and rebuild the
// fantasy agent with the full tool set. This is a no-op when no MCP
// servers are configured or tools have already been integrated.
a.ensureMCPTools()
// Inject tool output handler into context for use by core tools (e.g., bash).
if onToolOutput != nil {
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
@@ -295,6 +447,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return nil
},
// Reasoning/thinking complete callback
OnReasoningEnd: func(id string, _ fantasy.ReasoningContent) error {
if ctx.Err() != nil {
return ctx.Err()
}
if onReasoningComplete != nil {
onReasoningComplete()
}
return nil
},
// Text streaming callback
OnTextDelta: func(id, text string) error {
if ctx.Err() != nil {
@@ -381,7 +544,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
opts fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
// Drain all pending steer messages (non-blocking).
var steered []string
var steered []SteerMessage
for {
select {
case msg := <-steerCh:
@@ -398,9 +561,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
if len(steered) > 0 {
// Inject each steer message as a user message so the
// LLM sees the redirection on the next step.
for _, text := range steered {
for _, sm := range steered {
result.Messages = append(result.Messages,
fantasy.NewUserMessage(text))
fantasy.NewUserMessage(sm.Text, sm.Files...))
}
// Notify that steer messages were consumed.
if onConsumed != nil {
@@ -435,9 +598,12 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return nil, err
}
// Fire the response callback for callers that use it (e.g. non-streaming
// callers that still want the final response notification).
if onResponse != nil && result.Response.Content.Text() != "" {
// Fire the response callback so callers (e.g. the TUI) can reset
// streaming state. This must fire even when the response text is
// empty (e.g. reasoning-only responses) so the UI properly resets
// the stream component and avoids duplicate content on the next
// flush.
if onResponse != nil {
onResponse(result.Response.Content.Text())
}
@@ -454,8 +620,9 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return nil, err
}
// For non-streaming, fire the response callback with the final text
if onResponse != nil && result.Response.Content.Text() != "" {
// For non-streaming, fire the response callback so callers can reset
// streaming state (see streaming path comment above).
if onResponse != nil {
onResponse(result.Response.Content.Text())
}
@@ -623,6 +790,14 @@ func (a *Agent) GetExtensionToolCount() int {
return len(a.extraTools)
}
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
// tools) and rebuilds the internal agent with the updated tool list. The
// model, system prompt, and all other configuration are preserved.
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
a.extraTools = extraTools
a.rebuildFantasyAgent()
}
// GetLoadingMessage returns the loading message from provider creation.
func (a *Agent) GetLoadingMessage() string {
return a.loadingMessage
@@ -640,60 +815,14 @@ func (a *Agent) GetLoadedServerNames() []string {
// system prompt, and configuration are preserved. The old provider is closed
// if it has a closer. Returns the previous model string for notification.
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
// before the first LLM call).
a.ensureMCPTools()
providerResult, err := models.CreateProvider(ctx, config)
if err != nil {
return fmt.Errorf("failed to create model provider: %v", err)
}
// Rebuild tool list (same as NewAgent).
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
}
if a.toolWrapper != nil {
allTools = a.toolWrapper(allTools)
}
// Rebuild agent options.
var agentOpts []fantasy.AgentOption
if a.systemPrompt != "" {
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
}
if len(allTools) > 0 {
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
}
if a.maxSteps > 0 {
agentOpts = append(agentOpts, fantasy.WithStopConditions(
fantasy.StepCountIs(a.maxSteps),
))
}
// Pass provider-specific options (e.g. OpenAI Responses API reasoning settings).
if providerResult.ProviderOptions != nil {
agentOpts = append(agentOpts, fantasy.WithProviderOptions(providerResult.ProviderOptions))
}
// Pass generation parameters when available.
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
}
if config.Temperature != nil {
agentOpts = append(agentOpts, fantasy.WithTemperature(float64(*config.Temperature)))
}
if config.TopP != nil {
agentOpts = append(agentOpts, fantasy.WithTopP(float64(*config.TopP)))
}
if config.TopK != nil {
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
}
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
// Close old provider.
if a.providerCloser != nil {
_ = a.providerCloser.Close()
@@ -705,9 +834,11 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
}
// Swap fields.
a.fantasyAgent = newFantasyAgent
a.model = providerResult.Model
a.providerCloser = providerResult.Closer
a.providerOptions = providerResult.ProviderOptions
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
a.modelConfig = config
// Update provider type.
if config.ModelString != "" {
@@ -716,6 +847,9 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
}
}
// Rebuild the fantasy agent with the new model and current tool set.
a.rebuildFantasyAgent()
return nil
}
@@ -725,7 +859,13 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
}
// Close closes the agent and cleans up resources.
// If MCP tools are still loading in the background, Close waits for them
// to finish before closing connections to avoid resource leaks.
func (a *Agent) Close() error {
// Wait for background MCP loading to finish before closing connections.
if a.mcpReady != nil {
<-a.mcpReady
}
var toolErr error
if a.toolManager != nil {
toolErr = a.toolManager.Close()
+21 -9
View File
@@ -36,13 +36,22 @@ type AgentCreationOptions struct {
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
// DebugLogger is an optional logger for debugging MCP communications
DebugLogger tools.DebugLogger // Optional debug logger
// AuthHandler handles OAuth authorization for remote MCP servers
AuthHandler tools.MCPAuthHandler
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used.
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ToolWrapper wraps the combined tool list before agent creation.
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ExtraTools are additional tools to include (e.g. from extensions).
ExtraTools []fantasy.AgentTool
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// CreateAgent creates an agent with optional spinner for Ollama models.
@@ -50,15 +59,18 @@ type AgentCreationOptions struct {
// Returns the created agent or an error if creation fails.
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
agentConfig := &AgentConfig{
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
CoreTools: opts.CoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
ModelConfig: opts.ModelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: opts.SystemPrompt,
MaxSteps: opts.MaxSteps,
StreamingEnabled: opts.StreamingEnabled,
DebugLogger: opts.DebugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
var agent *Agent
+15 -4
View File
@@ -1,6 +1,17 @@
package agent
import "context"
import (
"context"
"charm.land/fantasy"
)
// SteerMessage carries a steering prompt and optional file attachments
// (e.g. clipboard images) through the steer channel.
type SteerMessage struct {
Text string
Files []fantasy.FilePart
}
// steerChKey is the context key for the steer channel.
type steerChKey struct{}
@@ -11,7 +22,7 @@ type steerConsumedKey struct{}
// ContextWithSteerCh returns a new context with the steer channel attached.
// The agent's PrepareStep function checks this channel between steps and
// injects any pending steer messages as user messages before the next LLM call.
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
func ContextWithSteerCh(ctx context.Context, ch <-chan SteerMessage) context.Context {
return context.WithValue(ctx, steerChKey{}, ch)
}
@@ -23,8 +34,8 @@ func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.C
}
// steerChFromContext extracts the steer channel from the context, or nil.
func steerChFromContext(ctx context.Context) <-chan string {
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
func steerChFromContext(ctx context.Context) <-chan SteerMessage {
ch, _ := ctx.Value(steerChKey{}).(<-chan SteerMessage)
return ch
}
+147 -5
View File
@@ -162,6 +162,24 @@ func (a *App) CancelCurrentStep() {
cancel()
}
// IsBusy returns true when the agent is currently processing a turn.
func (a *App) IsBusy() bool {
a.mu.Lock()
defer a.mu.Unlock()
return a.busy
}
// Abort cancels the current agent step (if running) and clears the queue.
// Unlike InterruptAndSend, no new message is injected — the agent simply
// stops. Safe to call when idle (no-op).
func (a *App) Abort() {
a.mu.Lock()
a.queue = a.queue[:0]
cancel := a.cancelStep
a.mu.Unlock()
cancel()
}
// QueueLength returns the number of prompts currently waiting in the queue.
//
// Satisfies ui.AppController.
@@ -187,6 +205,15 @@ func (a *App) QueueLength() int {
//
// Satisfies ui.AppController.
func (a *App) Steer(prompt string) int {
return a.SteerWithFiles(prompt, nil)
}
// SteerWithFiles injects a steering message with optional file attachments
// (e.g. pasted images) into the currently running agent turn. Behaves like
// Steer but includes file parts alongside the text.
//
// Satisfies ui.AppController.
func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
a.mu.Lock()
if a.closed {
@@ -195,8 +222,8 @@ func (a *App) Steer(prompt string) int {
}
if !a.busy {
// Not busy — start immediately, same as Run().
item := queueItem{Prompt: prompt}
// Not busy — start immediately, same as RunWithFiles().
item := queueItem{Prompt: prompt, Files: files}
a.busy = true
a.wg.Add(1)
a.mu.Unlock()
@@ -211,7 +238,7 @@ func (a *App) Steer(prompt string) int {
// execution, before next LLM call). If PrepareStep doesn't fire
// (text-only response), drainQueue will pick it up after the turn.
if a.opts.Kit != nil {
a.opts.Kit.InjectSteer(prompt)
a.opts.Kit.InjectSteerWithFiles(prompt, files)
}
return 1
}
@@ -390,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
return nil
}
// CompactAsync is like CompactConversation but calls onComplete/onError
// callbacks instead of sending TUI events. Used by the extension API's
// ctx.Compact() which needs callback-based notification.
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
a.mu.Lock()
if a.closed {
a.mu.Unlock()
return fmt.Errorf("app is closed")
}
if a.busy {
a.mu.Unlock()
return fmt.Errorf("cannot compact while the agent is working")
}
if a.opts.Kit == nil {
a.mu.Unlock()
return fmt.Errorf("SDK instance not available")
}
a.busy = true
a.wg.Add(1)
a.mu.Unlock()
go func() {
defer a.wg.Done()
defer func() {
a.mu.Lock()
a.busy = false
a.mu.Unlock()
}()
// Subscribe to SDK events for streaming compaction summary to the TUI.
sendFn := func(msg tea.Msg) {
if a.program != nil {
a.program.Send(msg)
}
}
unsub := a.subscribeSDKEvents(sendFn, nil)
defer unsub()
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
if err != nil {
a.sendEvent(CompactErrorEvent{Err: err})
if onError != nil {
onError(err.Error())
}
return
}
if result == nil {
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
if onError != nil {
onError("nothing to compact")
}
return
}
// Sync in-memory store with the compacted session.
if a.opts.TreeSession != nil {
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
}
a.sendEvent(CompactCompleteEvent{
Summary: result.Summary,
OriginalTokens: result.OriginalTokens,
CompactedTokens: result.CompactedTokens,
MessagesRemoved: result.MessagesRemoved,
})
if onComplete != nil {
onComplete()
}
}()
return nil
}
// --------------------------------------------------------------------------
// Non-interactive execution
// --------------------------------------------------------------------------
@@ -530,8 +629,8 @@ func (a *App) drainQueue(first queueItem) {
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
a.mu.Lock()
steerItems := make([]queueItem, len(leftover))
for i, text := range leftover {
steerItems[i] = queueItem{Prompt: text}
for i, sm := range leftover {
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
}
a.queue = append(steerItems, a.queue...)
a.mu.Unlock()
@@ -788,6 +887,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
sendFn(StreamChunkEvent{Content: ev.Chunk})
case kit.ReasoningDeltaEvent:
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
case kit.ReasoningCompleteEvent:
sendFn(ReasoningCompleteEvent{})
case kit.ToolOutputEvent:
sendFn(ToolOutputEvent{
ToolCallID: ev.ToolCallID,
@@ -896,6 +997,47 @@ func (a *App) NotifyWidgetUpdate() {
}
}
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
// prompt templates and skills from their provider callbacks. Called by file
// watchers when .md/.txt files change in prompt or skill directories.
// In non-interactive mode this is a no-op.
func (a *App) NotifyContentReload() {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(ContentReloadEvent{})
}
}
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
// tool names and MCP tool count from provider callbacks. Called when background
// MCP tool loading completes. In non-interactive mode this is a no-op.
func (a *App) NotifyMCPToolsReady() {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPToolsReadyEvent{})
}
}
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
// display a system message when a single MCP server finishes loading. Called
// per server as background MCP tool loading progresses.
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
a.mu.Lock()
prog := a.program
a.mu.Unlock()
if prog != nil {
prog.Send(MCPServerLoadedEvent{
ServerName: serverName,
ToolCount: toolCount,
Error: err,
})
}
}
// SendEvent sends a tea.Msg to the registered program. Safe to call from
// any goroutine. No-op when no program is registered.
//
+24
View File
@@ -16,6 +16,11 @@ type ReasoningChunkEvent struct {
Delta string
}
// ReasoningCompleteEvent is sent when reasoning/thinking is finished, after
// the last reasoning token has been processed. The TUI uses this to freeze
// the reasoning duration counter.
type ReasoningCompleteEvent struct{}
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
// It carries the tool name and its arguments for display purposes.
type ToolCallStartedEvent struct {
@@ -162,6 +167,25 @@ type ModelChangedEvent struct {
// from its WidgetProvider on the next render cycle.
type WidgetUpdateEvent struct{}
// ContentReloadEvent is sent when prompt templates or skills are reloaded
// from disk (e.g. by a file watcher detecting changes). The TUI refreshes
// its autocomplete entries and internal state from the provider callbacks.
type ContentReloadEvent struct{}
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
// The TUI refreshes its tool names and MCP tool count from provider callbacks
// so that /tools and the startup info bar reflect the loaded MCP tools.
type MCPToolsReadyEvent struct{}
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
// (successfully or with error). The TUI displays a system message so users
// see real-time progress as each server initializes.
type MCPServerLoadedEvent struct {
ServerName string
ToolCount int
Error error // nil on success
}
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
// pre-fill the input editor with text. The TUI handles this by setting the
// textarea content and moving the cursor to the end.
+11 -5
View File
@@ -162,6 +162,8 @@ type Theme struct {
// and merged into the custom provider in the model registry.
type CustomModelConfig struct {
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
@@ -197,11 +199,13 @@ type Config struct {
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
Theme any `json:"theme" yaml:"theme"`
// Model generation parameters
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
// Thinking / extended reasoning
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
@@ -368,6 +372,8 @@ mcpServers:
# temperature: 0.7 # Randomness (0.0-1.0)
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
# top-k: 40 # Top K sampling
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
# API Configuration (can also use environment variables)
+1 -20
View File
@@ -67,7 +67,7 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
}
if info.IsDir() {
return readDirectory(absPath)
return fantasy.NewTextErrorResponse(fmt.Sprintf("'%s' is a directory, not a file. Use the ls tool to list directory contents.", args.Path)), nil
}
content, err := os.ReadFile(absPath)
@@ -116,25 +116,6 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
return fantasy.NewTextResponse(tr.Content), nil
}
func readDirectory(absPath string) (fantasy.ToolResponse, error) {
entries, err := os.ReadDir(absPath)
if err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to read directory: %v", err)), nil
}
var result strings.Builder
for _, entry := range entries {
name := entry.Name()
if entry.IsDir() {
name += "/"
}
result.WriteString(name + "\n")
}
tr := truncateHead(result.String(), 500, defaultMaxBytes)
return fantasy.NewTextResponse(tr.Content), nil
}
// resolvePathWithWorkDir resolves a path to an absolute path relative to the
// given workDir. If workDir is empty, os.Getwd() is used.
func resolvePathWithWorkDir(path, workDir string) (string, error) {
+26 -33
View File
@@ -130,13 +130,22 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
), fmt.Errorf("no subagent spawner in context")
}
// Detach from the parent's deadline so the subagent gets its own
// independent timeout (applied downstream in Kit.Subagent). The parent
// context may carry a tight deadline from the LLM generation loop or
// other tool timeouts that would prematurely kill the subagent.
// We preserve context values (spawner, etc.) and propagate parent
// cancellation (e.g. user hits Ctrl-C) without inheriting the deadline.
spawnCtx := detachedWithCancel(ctx)
// Build a clean context for the subagent that inherits values (e.g. the
// spawner callback) but is completely detached from the parent's
// deadline AND cancellation. The subagent gets its own independent
// timeout (applied downstream in Kit.Subagent).
//
// Why full detachment instead of propagating parent cancellation?
// The parent context may already be done (deadline exceeded or
// cancelled) by the time this tool handler executes — for example when
// the generation loop context carries a deadline, when the user
// double-ESC cancels mid-turn, or when parallel tool execution
// encounters a race between stream completion and tool dispatch. Using
// context.WithoutCancel (Go 1.21+) ensures the subagent always starts
// cleanly with a fresh timeout, following the pattern used by crush for
// shutdown-resilient child work. The subagent's own timeout
// (defaultSubagentTimeout / user-specified) provides the safety net.
spawnCtx := context.WithoutCancel(valuesContext{parent: ctx})
// Spawn in-process subagent.
result, err := spawner(spawnCtx, call.ID, args.Task, args.Model, args.SystemPrompt, timeout)
@@ -173,37 +182,21 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
}
// ---------------------------------------------------------------------------
// Context detachment
// Context helpers
// ---------------------------------------------------------------------------
// detachedContext wraps a parent context, preserving its values but removing
// its deadline and cancellation. This allows the subagent to have its own
// independent timeout while still accessing context-stored values (e.g. the
// subagent spawner function).
type detachedContext struct {
// valuesContext preserves a parent context's values (e.g. the subagent
// spawner callback) while stripping its deadline and cancellation. Combined
// with context.WithoutCancel() this gives the subagent a completely clean
// context that only inherits value-based dependencies.
type valuesContext struct {
parent context.Context
}
func (d detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (d detachedContext) Done() <-chan struct{} { return nil }
func (d detachedContext) Err() error { return nil }
func (d detachedContext) Value(key any) any { return d.parent.Value(key) }
// detachedWithCancel creates a new context that inherits values from the
// parent but has no deadline. Cancellation of the parent is propagated: when
// the parent is cancelled the returned context is also cancelled, but the
// parent's deadline does not apply to the child.
func detachedWithCancel(parent context.Context) context.Context {
child, cancel := context.WithCancel(detachedContext{parent: parent})
go func() {
select {
case <-parent.Done():
cancel()
case <-child.Done():
}
}()
return child
}
func (v valuesContext) Deadline() (time.Time, bool) { return time.Time{}, false }
func (v valuesContext) Done() <-chan struct{} { return nil }
func (v valuesContext) Err() error { return nil }
func (v valuesContext) Value(key any) any { return v.parent.Value(key) }
// truncateResponse limits the response length to avoid overwhelming context windows.
func truncateResponse(s string, maxLen int) string {
+115
View File
@@ -0,0 +1,115 @@
package core
import (
"context"
"testing"
"time"
)
func TestValuesContext_StripsDeadlineAndCancellation(t *testing.T) {
// Parent with a tight deadline.
parent, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
time.Sleep(5 * time.Millisecond) // Let deadline expire.
if parent.Err() == nil {
t.Fatal("expected parent to be expired")
}
vc := valuesContext{parent: parent}
if _, ok := vc.Deadline(); ok {
t.Error("valuesContext should report no deadline")
}
if vc.Done() != nil {
t.Error("valuesContext.Done() should return nil")
}
if vc.Err() != nil {
t.Errorf("valuesContext.Err() should be nil, got %v", vc.Err())
}
}
func TestValuesContext_PreservesValues(t *testing.T) {
type testKey struct{}
parent := context.WithValue(context.Background(), testKey{}, "hello")
vc := valuesContext{parent: parent}
got, ok := vc.Value(testKey{}).(string)
if !ok || got != "hello" {
t.Errorf("expected value 'hello', got %q (ok=%v)", got, ok)
}
}
func TestSpawnContext_SurvivesCancelledParent(t *testing.T) {
// Simulate the exact scenario from the bug: the parent generation
// context is already cancelled when the subagent tool handler runs.
parent, cancel := context.WithCancel(context.Background())
cancel() // Cancelled before detach.
// This is what executeSubagent now does:
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
// The spawn context must be alive.
if spawnCtx.Err() != nil {
t.Fatalf("spawnCtx should be alive, got err: %v", spawnCtx.Err())
}
// Adding a timeout should produce a working context.
tCtx, tCancel := context.WithTimeout(spawnCtx, 5*time.Second)
defer tCancel()
if tCtx.Err() != nil {
t.Fatalf("timeout context should be alive, got err: %v", tCtx.Err())
}
}
func TestSpawnContext_SurvivesDeadlineExceededParent(t *testing.T) {
// Simulate: parent had a deadline that already expired.
parent, pCancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer pCancel()
time.Sleep(5 * time.Millisecond)
if parent.Err() != context.DeadlineExceeded {
t.Fatalf("expected parent deadline exceeded, got: %v", parent.Err())
}
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
if spawnCtx.Err() != nil {
t.Fatalf("spawnCtx should be alive after deadline-exceeded parent, got: %v", spawnCtx.Err())
}
}
func TestSpawnContext_PreservesSpawnerValue(t *testing.T) {
// Verify the subagent spawner callback survives context detachment.
called := false
spawner := SubagentSpawnFunc(func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error) {
called = true
return &SubagentSpawnResult{Response: "ok"}, nil
})
parent := WithSubagentSpawner(context.Background(), spawner)
// Cancel the parent.
parentCtx, cancel := context.WithCancel(parent)
cancel()
spawnCtx := context.WithoutCancel(valuesContext{parent: parentCtx})
// Should be able to retrieve the spawner from the detached context.
recovered := getSubagentSpawner(spawnCtx)
if recovered == nil {
t.Fatal("spawner should be recoverable from detached context")
}
result, err := recovered(spawnCtx, "tc1", "test task", "", "", time.Minute)
if err != nil {
t.Fatalf("spawner call failed: %v", err)
}
if !called {
t.Error("spawner was not called")
}
if result.Response != "ok" {
t.Errorf("expected 'ok', got %q", result.Response)
}
}
+100
View File
@@ -77,6 +77,64 @@ type Context struct {
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
CancelAndSend func(string)
// Abort cancels the current agent turn (if running) and clears the
// message queue. Unlike CancelAndSend, no new message is injected —
// the agent simply stops. Safe to call when idle (no-op).
//
// Example:
//
// ctx.Abort() // stop whatever the agent is doing
Abort func()
// IsIdle returns true when the agent is not processing a turn.
// Extensions can use this to decide whether to dispatch immediately
// or queue work for later.
//
// Example:
//
// if ctx.IsIdle() {
// ctx.SendMessage("start new task")
// }
IsIdle func() bool
// Compact triggers context compaction, summarising older messages to
// free context window space. Returns an error if compaction cannot
// start (e.g. agent is busy or app is closed). The actual compaction
// runs asynchronously; use OnComplete/OnError callbacks in
// CompactConfig to observe the result.
//
// Example:
//
// err := ctx.Compact(ext.CompactConfig{
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
// })
Compact func(CompactConfig) error
// SendMultimodalMessage injects a message with file attachments (images,
// documents) into the conversation and triggers a new agent turn. Files
// are described by FilePart structs containing the raw bytes, filename,
// and MIME type. If the agent is busy the message is queued.
//
// Example:
//
// data, _ := os.ReadFile("photo.jpg")
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
// })
SendMultimodalMessage func(text string, files []FilePart)
// GetSessionUsage returns aggregated token usage and cost statistics
// for the current session. This includes total input/output tokens,
// cache read/write tokens, total cost, and request count.
//
// Example:
//
// usage := ctx.GetSessionUsage()
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
GetSessionUsage func() SessionUsage
// SetWidget places or updates a persistent widget in the TUI. Widgets
// remain visible across agent turns until explicitly removed. The
// widget is identified by WidgetConfig.ID; calling SetWidget with the
@@ -937,6 +995,48 @@ type StatusBarEntry struct {
Priority int
}
// CompactConfig configures a programmatic context compaction request.
type CompactConfig struct {
// CustomInstructions is optional text appended to the summary prompt
// (e.g. "Focus on the API design decisions"). Empty uses the default.
CustomInstructions string
// OnComplete is called when compaction finishes successfully.
// May be nil if the caller doesn't need notification.
OnComplete func()
// OnError is called when compaction fails. The argument is the error message.
// May be nil if the caller doesn't need notification.
OnError func(errMsg string)
}
// FilePart describes a file attachment for multimodal messages. Extensions
// use this with SendMultimodalMessage to attach images or documents.
type FilePart struct {
// Filename is the name of the file (e.g. "photo.jpg").
Filename string
// Data is the raw file content.
Data []byte
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
MediaType string
}
// SessionUsage contains aggregated token usage and cost statistics for
// the current session. Extensions use this with GetSessionUsage() to
// report usage information.
type SessionUsage struct {
// TotalInputTokens is the sum of input tokens across all requests.
TotalInputTokens int
// TotalOutputTokens is the sum of output tokens across all requests.
TotalOutputTokens int
// TotalCacheReadTokens is the sum of cache read tokens.
TotalCacheReadTokens int
// TotalCacheWriteTokens is the sum of cache write tokens.
TotalCacheWriteTokens int
// TotalCost is the total cost in USD across all requests.
TotalCost float64
// RequestCount is the number of LLM requests made in this session.
RequestCount int
}
// PrintBlockOpts configures a custom styled block for PrintBlock.
type PrintBlockOpts struct {
// Text is the main content to display.
+25 -26
View File
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
// Install clones a git repository to the appropriate scope.
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
return i.install(source, scope, nil)
}
// install is the internal implementation that supports optional include paths.
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
targetDir := i.getInstallPath(source, scope)
// Check if already installed
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
Pinned: source.Pinned,
Scope: scope,
Installed: time.Now(),
Include: includePaths,
}
if err := i.addToManifest(entry, scope); err != nil {
// Don't fail the install, just log the error
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
cleanCmd.Dir = targetDir
_ = cleanCmd.Run() // Ignore errors - clean is best effort
// Update manifest timestamp
// Update manifest timestamp, preserving existing fields like Include
existing, _ := i.loadManifest(scope)
var include []string
var installed time.Time
if existing != nil {
for _, p := range existing.Packages {
if p.Host+"/"+p.Path == source.Identity() {
include = p.Include
installed = p.Installed
break
}
}
}
if installed.IsZero() {
installed = time.Now()
}
entry := ManifestEntry{
Source: source.String(),
Repo: source.Repo,
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
Ref: "",
Pinned: false,
Scope: scope,
Installed: time.Now(),
Installed: installed,
Updated: time.Now(),
Include: include,
}
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
// InstallWithInclude clones a repo and installs only the specified extensions.
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
// First, do a regular install
if err := i.Install(source, scope); err != nil {
return err
}
// If specific includes were requested, update the manifest
if len(includePaths) > 0 {
entry := ManifestEntry{
Source: source.String(),
Repo: source.Repo,
Host: source.Host,
Path: source.Path,
Ref: source.Ref,
Pinned: source.Pinned,
Scope: scope,
Include: includePaths,
}
if err := addEntryToManifest(entry, scope); err != nil {
return fmt.Errorf("updating manifest with includes: %w", err)
}
}
return nil
return i.install(source, scope, includePaths)
}
// CleanupTempDir removes a temporary directory used for preview.
+11 -18
View File
@@ -2,11 +2,11 @@ package extensions
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
"github.com/traefik/yaegi/interp"
"github.com/traefik/yaegi/stdlib"
"github.com/traefik/yaegi/stdlib/unrestricted"
@@ -34,15 +34,11 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
for _, p := range paths {
ext, err := loadSingleExtension(p)
if err != nil {
log.Warn("skipping extension", "path", p, "err", err)
log.Printf("WARN skipping extension: path=%s err=%v", p, err)
continue
}
loaded = append(loaded, *ext)
log.Debug("loaded extension", "path", p,
"handlers", countHandlers(ext),
"tools", len(ext.Tools),
"commands", len(ext.Commands),
"tool_renderers", len(ext.ToolRenderers))
log.Printf("DEBUG loaded extension: path=%s handlers=%d tools=%d commands=%d tool_renderers=%d", p, countHandlers(ext), len(ext.Tools), len(ext.Commands), len(ext.ToolRenderers))
}
return loaded, nil
}
@@ -133,7 +129,7 @@ func findExtensionsInDir(dir string) []string {
for _, entry := range entries {
full := filepath.Join(dir, entry.Name())
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
results = append(results, full)
} else if entry.IsDir() {
main := filepath.Join(full, "main.go")
@@ -190,9 +186,13 @@ func findExtensionsInRepo(repoPath string) []string {
isExtDir := base == "extensions" || base == "ext" ||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
// Allow walking into examples/ so we can reach examples/extensions/ etc,
// but don't treat examples/ itself or non-extension subdirs as extension locations.
if relPath == "examples" {
return nil
}
if !isExtDir && !isExamplesSubdir {
if !isExtDir {
mainPath := filepath.Join(path, "main.go")
if _, err := os.Stat(mainPath); err == nil {
if relPath == base { // Top-level directory
@@ -202,13 +202,6 @@ func findExtensionsInRepo(repoPath string) []string {
}
return filepath.SkipDir
}
if isExamplesSubdir || isExtDir {
if !multiFileDirs[relPath] {
multiFileDirs[relPath] = true
results = append(results, mainPath)
}
return filepath.SkipDir
}
}
return filepath.SkipDir
}
@@ -227,7 +220,7 @@ func findExtensionsInRepo(repoPath string) []string {
}
// It's a file
if !strings.HasSuffix(info.Name(), ".go") {
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
return nil
}
+7 -16
View File
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
isExtDir := base == "extensions" || base == "ext" ||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
// Or check if it's a subdirectory of examples/ that might contain extensions
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
// Allow walking into examples/ so we can reach examples/extensions/ etc,
// but don't treat examples/ itself or non-extension subdirs as extension locations.
if relPath == "examples" {
return nil
}
if !isExtDir && !isExamplesSubdir {
if !isExtDir {
// Check for main.go before skipping
mainPath := filepath.Join(path, "main.go")
if _, err := os.Stat(mainPath); err == nil {
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
}
return filepath.SkipDir
}
// Inside a valid extensions directory
if isExamplesSubdir || isExtDir {
if !multiFileDirs[relPath] {
multiFileDirs[relPath] = true
previews = append(previews, ExtensionPreview{
Path: "./" + relPath + "/main.go",
Name: deriveExtensionName(relPath+"/main.go", true),
IsMain: true,
})
}
return filepath.SkipDir
}
}
// Not an extension location
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
}
// It's a file - check if it's a valid extension
if !strings.HasSuffix(info.Name(), ".go") {
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
return nil
}
+18 -8
View File
@@ -2,12 +2,12 @@ package extensions
import (
"fmt"
"log"
"os"
"sort"
"strings"
"sync"
"github.com/charmbracelet/log"
"github.com/spf13/viper"
)
@@ -86,6 +86,21 @@ func normalizeContext(ctx Context) Context {
if ctx.CancelAndSend == nil {
ctx.CancelAndSend = func(string) {}
}
if ctx.Abort == nil {
ctx.Abort = func() {}
}
if ctx.IsIdle == nil {
ctx.IsIdle = func() bool { return true }
}
if ctx.Compact == nil {
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
}
if ctx.SendMultimodalMessage == nil {
ctx.SendMultimodalMessage = func(string, []FilePart) {}
}
if ctx.GetSessionUsage == nil {
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
}
if ctx.SetWidget == nil {
ctx.SetWidget = func(WidgetConfig) {}
}
@@ -355,10 +370,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
log.Warn("extension handler error",
"path", ext.Path,
"event", event.Type(),
"err", err)
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
continue
}
if result == nil {
@@ -692,9 +704,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
safeInvoke := func(h func(string)) {
defer func() {
if rec := recover(); rec != nil {
log.Warn("custom event handler panicked",
"event", name,
"err", fmt.Sprintf("%v", rec))
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
}
}()
h(data)
+3
View File
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
// Session types
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
// Option types
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
// LLM completion types
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
"FilePart": reflect.ValueOf((*FilePart)(nil)),
// Status bar types
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
+192
View File
@@ -0,0 +1,192 @@
package extensions
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
// Watcher monitors extension directories for file changes and triggers
// a reload callback when .go files are created, modified, or removed.
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
type Watcher struct {
watcher *fsnotify.Watcher
onReload func()
debounce time.Duration
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
}
// NewWatcher creates a file watcher that monitors the given directories
// for .go file changes. When a change is detected (after debouncing),
// onReload is called. The watcher must be started with Start() and
// stopped with Close().
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
fsw, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating file watcher: %w", err)
}
for _, dir := range dirs {
// Watch the directory itself.
if err := fsw.Add(dir); err != nil {
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
continue
}
// Also watch immediate subdirectories (for */main.go pattern).
entries, err := os.ReadDir(dir)
if err != nil {
continue
}
for _, entry := range entries {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
if err := fsw.Add(subdir); err != nil {
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
}
}
}
}
return &Watcher{
watcher: fsw,
onReload: onReload,
debounce: 300 * time.Millisecond,
done: make(chan struct{}),
}, nil
}
// Start begins watching for file changes. It blocks until the context
// is cancelled or Close() is called. Typically called in a goroutine.
func (w *Watcher) Start(ctx context.Context) {
w.mu.Lock()
ctx, w.cancel = context.WithCancel(ctx)
w.mu.Unlock()
defer close(w.done)
var timer *time.Timer
var timerC <-chan time.Time
for {
select {
case <-ctx.Done():
if timer != nil {
timer.Stop()
}
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// Only care about .go files.
if !strings.HasSuffix(event.Name, ".go") {
continue
}
// React to write, create, remove, rename events.
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
continue
}
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
// Debounce: reset timer on each event.
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
case <-timerC:
timerC = nil
timer = nil
log.Printf("DEBUG watcher: reloading extensions")
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Printf("WARN watcher: error: %v", err)
}
}
}
// Close stops the watcher and releases resources.
func (w *Watcher) Close() error {
w.mu.Lock()
cancel := w.cancel
w.mu.Unlock()
if cancel != nil {
cancel()
}
// Wait for the event loop to finish.
<-w.done
return w.watcher.Close()
}
// WatchedDirs returns the directories to watch for extension changes.
// This includes the global extensions directory and the project-local
// .kit/extensions/ directory (if they exist). Explicit -e paths that
// point to directories are also included; explicit file paths cause
// their parent directory to be watched instead.
func WatchedDirs(extraPaths []string) []string {
var dirs []string
seen := make(map[string]bool)
add := func(dir string) {
abs, err := filepath.Abs(dir)
if err != nil {
return
}
if seen[abs] {
return
}
// Verify the directory exists.
info, err := os.Stat(abs)
if err != nil || !info.IsDir() {
return
}
seen[abs] = true
dirs = append(dirs, abs)
}
// Global extensions dir.
add(globalExtensionsDir())
// Project-local extensions dir.
add(filepath.Join(".kit", "extensions"))
// Explicit paths that are directories.
for _, p := range extraPaths {
info, err := os.Stat(p)
if err != nil {
continue
}
if info.IsDir() {
add(p)
} else {
// For explicit files, watch the parent directory.
add(filepath.Dir(p))
}
}
return dirs
}
+158
View File
@@ -0,0 +1,158 @@
package extensions
import (
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
func TestWatcher_ReloadsOnGoFileChange(t *testing.T) {
dir := t.TempDir()
// Write an initial extension file.
extFile := filepath.Join(dir, "test.go")
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := NewWatcher([]string{dir}, func() {
reloadCount.Add(1)
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Modify the file.
time.Sleep(50 * time.Millisecond) // let watcher settle
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
t.Fatal(err)
}
// Wait for debounce (300ms) + margin.
time.Sleep(600 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload, got %d", got)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
}
func TestWatcher_IgnoresNonGoFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := NewWatcher([]string{dir}, func() {
reloadCount.Add(1)
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Write a non-.go file.
time.Sleep(50 * time.Millisecond)
txtFile := filepath.Join(dir, "notes.txt")
if err := os.WriteFile(txtFile, []byte("hello"), 0o644); err != nil {
t.Fatal(err)
}
// Wait past the debounce window.
time.Sleep(600 * time.Millisecond)
if got := reloadCount.Load(); got != 0 {
t.Errorf("expected 0 reloads for .txt file, got %d", got)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
}
func TestWatcher_Debounces(t *testing.T) {
dir := t.TempDir()
extFile := filepath.Join(dir, "ext.go")
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := NewWatcher([]string{dir}, func() {
reloadCount.Add(1)
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(50 * time.Millisecond)
// Rapid-fire writes (simulating editor save: write temp, rename, etc.).
for range 5 {
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
t.Fatal(err)
}
time.Sleep(50 * time.Millisecond)
}
// Wait for debounce to fire.
time.Sleep(600 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 debounced reload, got %d", got)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
}
func TestWatchedDirs_Deduplicates(t *testing.T) {
dir := t.TempDir()
dirs := WatchedDirs([]string{dir, dir})
count := 0
for _, d := range dirs {
abs, _ := filepath.Abs(dir)
if d == abs {
count++
}
}
if count != 1 {
t.Errorf("expected directory to appear once, got %d", count)
}
}
func TestWatchedDirs_FileParent(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "ext.go")
if err := os.WriteFile(file, []byte("package main\n"), 0o644); err != nil {
t.Fatal(err)
}
dirs := WatchedDirs([]string{file})
abs, _ := filepath.Abs(dir)
found := false
for _, d := range dirs {
if d == abs {
found = true
}
}
if !found {
t.Errorf("expected parent dir %s in watched dirs %v", abs, dirs)
}
}
+86 -30
View File
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
// CoreTools overrides the default core tool set. If empty, core.AllTools()
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
CoreTools []fantasy.AgentTool
// DisableCoreTools, when true, prevents loading any core tools.
// If both DisableCoreTools is true and CoreTools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// ExtraTools are additional tools added alongside core, MCP, and extension
// tools. They do not replace the defaults — they extend them.
ExtraTools []fantasy.AgentTool
@@ -40,6 +44,30 @@ type AgentSetupOptions struct {
// wrapping. Used by the SDK hook system. Both wrappers compose:
// extension wrapper runs first (inner), then this wrapper (outer).
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
// ProviderConfig, when non-nil, is used directly instead of calling
// BuildProviderConfig(). Callers that already hold viperInitMu can
// pre-build this and release the lock before calling SetupAgent, so the
// slow agent/MCP initialisation runs concurrently with other New() calls.
ProviderConfig *models.ProviderConfig
// Debug enables debug logging. When zero-value, viper is consulted.
// Only meaningful when ProviderConfig is also set.
Debug bool
// NoExtensions skips extension loading. When false, viper is consulted.
// Only meaningful when ProviderConfig is also set.
NoExtensions bool
// MaxSteps overrides the agent step limit. 0 means use viper value.
// Only meaningful when ProviderConfig is also set.
MaxSteps int
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
// is also set.
StreamingEnabled bool
// AuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports are configured with OAuth support.
AuthHandler tools.MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
}
// AgentSetupResult bundles the created agent and any debug logger so the caller
@@ -63,23 +91,27 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
temperature := float32(viper.GetFloat64("temperature"))
topP := float32(viper.GetFloat64("top-p"))
topK := int32(viper.GetInt("top-k"))
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
numGPU := int32(viper.GetInt("num-gpu-layers"))
mainGPU := int32(viper.GetInt("main-gpu"))
cfg := &models.ProviderConfig{
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
ModelString: viper.GetString("model"),
SystemPrompt: systemPrompt,
ProviderAPIKey: viper.GetString("provider-api-key"),
ProviderURL: viper.GetString("provider-url"),
MaxTokens: viper.GetInt("max-tokens"),
Temperature: &temperature,
TopP: &topP,
TopK: &topK,
FrequencyPenalty: &frequencyPenalty,
PresencePenalty: &presencePenalty,
StopSequences: viper.GetStringSlice("stop-sequences"),
NumGPU: &numGPU,
MainGPU: &mainGPU,
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
}
return cfg, systemPrompt, nil
@@ -88,15 +120,36 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
// SetupAgent creates an agent from the current viper state + the provided
// options. It wraps BuildProviderConfig and agent.CreateAgent.
func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, error) {
modelConfig, systemPrompt, err := BuildProviderConfig()
if err != nil {
return nil, err
var modelConfig *models.ProviderConfig
var systemPrompt string
if opts.ProviderConfig != nil {
// Pre-built config supplied by caller (e.g. Kit.New after releasing
// viperInitMu). Use it directly — no viper reads needed here.
modelConfig = opts.ProviderConfig
systemPrompt = modelConfig.SystemPrompt
} else {
var err error
modelConfig, systemPrompt, err = BuildProviderConfig()
if err != nil {
return nil, err
}
}
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
// fields (set when ProviderConfig was pre-built) over viper fallback.
debugEnabled := opts.Debug || viper.GetBool("debug")
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
maxSteps := opts.MaxSteps
if maxSteps == 0 {
maxSteps = viper.GetInt("max-steps")
}
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
// Create the appropriate debug logger.
var debugLogger tools.DebugLogger
var bufferedLogger *tools.BufferedDebugLogger
if viper.GetBool("debug") {
if debugEnabled {
if opts.UseBufferedLogger {
bufferedLogger = tools.NewBufferedDebugLogger(true)
debugLogger = bufferedLogger
@@ -108,7 +161,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
// Load extensions unless --no-extensions is set.
var extRunner *extensions.Runner
var extCreationOpts extensionCreationOpts
if !viper.GetBool("no-extensions") {
if !noExtensions {
var extErr error
extRunner, extCreationOpts, extErr = loadExtensions()
if extErr != nil {
@@ -137,18 +190,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
}
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: viper.GetInt("max-steps"),
StreamingEnabled: viper.GetBool("stream"),
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
CoreTools: opts.CoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
ModelConfig: modelConfig,
MCPConfig: opts.MCPConfig,
SystemPrompt: systemPrompt,
MaxSteps: maxSteps,
StreamingEnabled: streamingEnabled,
ShowSpinner: opts.ShowSpinner,
Quiet: opts.Quiet,
SpinnerFunc: opts.SpinnerFunc,
DebugLogger: debugLogger,
AuthHandler: opts.AuthHandler,
CoreTools: opts.CoreTools,
DisableCoreTools: opts.DisableCoreTools,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)
+4
View File
@@ -37,6 +37,8 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
Attachment: cfg.Attachment,
Reasoning: cfg.Reasoning,
Temperature: cfg.Temperature,
BaseURL: cfg.BaseURL,
APIKey: cfg.APIKey,
Cost: Cost{
Input: cfg.Cost.Input,
Output: cfg.Cost.Output,
@@ -52,6 +54,8 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
// This is a duplicate here to avoid circular dependencies with internal/config.
type CustomModelConfig struct {
Name string `json:"name" yaml:"name"`
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
Family string `json:"family,omitempty" yaml:"family,omitempty"`
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
File diff suppressed because one or more lines are too long
+43 -148
View File
@@ -10,7 +10,6 @@ import (
"maps"
"net/http"
"os"
"regexp"
"strings"
"time"
@@ -144,20 +143,22 @@ func ParseThinkingLevel(s string) ThinkingLevel {
// ProviderConfig holds configuration for creating LLM providers.
type ProviderConfig struct {
ModelString string
SystemPrompt string
ProviderAPIKey string
ProviderURL string
MaxTokens int
Temperature *float32
TopP *float32
TopK *int32
StopSequences []string
NumGPU *int32
MainGPU *int32
TLSSkipVerify bool
ThinkingLevel ThinkingLevel
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
ModelString string
SystemPrompt string
ProviderAPIKey string
ProviderURL string
MaxTokens int
Temperature *float32
TopP *float32
TopK *int32
FrequencyPenalty *float32
PresencePenalty *float32
StopSequences []string
NumGPU *int32
MainGPU *int32
TLSSkipVerify bool
ThinkingLevel ThinkingLevel
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
}
// ProviderResult contains the result of provider creation.
@@ -525,13 +526,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
switch level {
case ThinkingMinimal:
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
return new(openai.ReasoningEffortMinimal)
case ThinkingLow:
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
return new(openai.ReasoningEffortLow)
case ThinkingMedium:
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
return new(openai.ReasoningEffortMedium)
case ThinkingHigh:
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
return new(openai.ReasoningEffortHigh)
default:
return nil
}
@@ -1000,139 +1001,29 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
return &ProviderResult{Model: model}, nil
}
// thinkTagRegex matches <think>...</think> tags for extracting reasoning content
// from models that wrap thinking in XML-like tags (e.g., Qwen, DeepSeek).
var thinkTagRegex = regexp.MustCompile(`(?s)<think>(.*?)</think>`)
// customExtraContentFunc extracts reasoning from <think> tags in the content field.
// This handles models like Qwen and DeepSeek that return reasoning wrapped in XML tags
// rather than using a separate reasoning_content field.
func customExtraContentFunc(choice openaisdk.ChatCompletionChoice) []fantasy.Content {
var content []fantasy.Content
if choice.Message.Content == "" {
return content
}
// Check for <think> tags in the content
matches := thinkTagRegex.FindStringSubmatch(choice.Message.Content)
if len(matches) > 1 {
// Found reasoning content in <think> tags
reasoning := strings.TrimSpace(matches[1])
if reasoning != "" {
content = append(content, fantasy.ReasoningContent{
Text: reasoning,
})
}
}
return content
}
// customStreamExtraFunc handles streaming responses with <think> tags.
// It extracts reasoning content and emits proper reasoning events.
func customStreamExtraFunc(
chunk openaisdk.ChatCompletionChunk,
yield func(fantasy.StreamPart) bool,
ctx map[string]any,
) (map[string]any, bool) {
if len(chunk.Choices) == 0 {
return ctx, true
}
const reasoningStartedKey = "reasoning_started"
const reasoningBufferKey = "reasoning_buffer"
const inThinkTagKey = "in_think_tag"
reasoningStarted, _ := ctx[reasoningStartedKey].(bool)
inThinkTag, _ := ctx[inThinkTagKey].(bool)
reasoningBuffer, _ := ctx[reasoningBufferKey].(string)
for i, choice := range chunk.Choices {
content := choice.Delta.Content
if content == "" {
continue
}
// Check for <think> tag start
if strings.Contains(content, "<think>") {
inThinkTag = true
ctx[inThinkTagKey] = true
// Emit reasoning start event
if !reasoningStarted {
reasoningStarted = true
ctx[reasoningStartedKey] = true
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeReasoningStart,
ID: fmt.Sprintf("%d", i),
}) {
return ctx, false
}
}
// Extract content after <think>
parts := strings.SplitN(content, "<think>", 2)
if len(parts) > 1 && parts[1] != "" {
reasoningBuffer += parts[1]
ctx[reasoningBufferKey] = reasoningBuffer
}
continue
}
// Check for </think> tag end
if strings.Contains(content, "</think>") {
inThinkTag = false
ctx[inThinkTagKey] = false
// Extract content before </think>
parts := strings.SplitN(content, "</think>", 2)
if len(parts) > 0 {
reasoningBuffer += parts[0]
}
// Emit the accumulated reasoning
if reasoningBuffer != "" {
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeReasoningDelta,
ID: fmt.Sprintf("%d", i),
Delta: reasoningBuffer,
}) {
return ctx, false
}
ctx[reasoningBufferKey] = ""
}
// Emit reasoning end
if !yield(fantasy.StreamPart{
Type: fantasy.StreamPartTypeReasoningEnd,
ID: fmt.Sprintf("%d", i),
}) {
return ctx, false
}
continue
}
// Accumulate reasoning content while in think tag
if inThinkTag {
reasoningBuffer += content
ctx[reasoningBufferKey] = reasoningBuffer
}
}
return ctx, true
}
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
return openai.DefaultToPrompt(prompt, systemPrompt, user)
}
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
if config.ProviderURL == "" {
return nil, fmt.Errorf("custom provider requires --provider-url")
// Resolve base URL: per-model override > global provider-url flag/config
registry := GetGlobalRegistry()
modelInfo := registry.LookupModel("custom", modelName)
baseURL := config.ProviderURL
if modelInfo != nil && modelInfo.BaseURL != "" {
baseURL = modelInfo.BaseURL
}
if baseURL == "" {
return nil, fmt.Errorf("custom provider requires --provider-url or a baseUrl in the model config")
}
apiKey := config.ProviderAPIKey
if modelInfo != nil && modelInfo.APIKey != "" {
apiKey = modelInfo.APIKey
}
if apiKey == "" {
apiKey = os.Getenv("CUSTOM_API_KEY")
}
@@ -1141,15 +1032,13 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
apiKey = "custom"
}
// Use the openai provider directly with custom hooks to handle <think> tags
// from models like Qwen and DeepSeek that wrap reasoning in XML tags.
// <think> tag extraction is handled transparently at the agent layer,
// so no provider-level hooks are needed here.
var opts []openai.Option
opts = append(opts, openai.WithBaseURL(config.ProviderURL))
opts = append(opts, openai.WithBaseURL(baseURL))
opts = append(opts, openai.WithAPIKey(apiKey))
opts = append(opts, openai.WithName("custom"))
opts = append(opts, openai.WithLanguageModelOptions(
openai.WithLanguageModelExtraContentFunc(customExtraContentFunc),
openai.WithLanguageModelStreamExtraFunc(customStreamExtraFunc),
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
))
@@ -1277,6 +1166,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
if config.TopK != nil {
options["top_k"] = int(*config.TopK)
}
if config.FrequencyPenalty != nil {
options["frequency_penalty"] = *config.FrequencyPenalty
}
if config.PresencePenalty != nil {
options["presence_penalty"] = *config.PresencePenalty
}
if len(config.StopSequences) > 0 {
options["stop"] = config.StopSequences
}
+50 -2
View File
@@ -24,6 +24,8 @@ type ModelInfo struct {
Cost Cost
Limit Limit
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
BaseURL string // Per-model base URL override (custom models only)
APIKey string // Per-model API key override (custom models only)
}
// SupportsCaching returns true if this model family supports prompt caching.
@@ -367,8 +369,8 @@ func (r *ModelsRegistry) GetFantasyProviders() []string {
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
// Ollama is always supported (via openaicompat pointed at localhost)
if providerID == "ollama" {
// Ollama and custom are always supported (model names are user-defined).
if providerID == "ollama" || providerID == "custom" {
return true
}
@@ -400,6 +402,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
return &info
}
// ValidateModelString checks whether a model string is well-formed and refers
// to a known provider. It returns a user-friendly error with suggestions when
// the model or provider is unrecognised. Passing validation does not guarantee
// that API authentication will succeed — it only catches obvious mistakes
// (typos, missing provider prefix, non-existent provider names) early so that
// callers such as subagent spawning can return fast feedback.
//
// Unknown models under a known provider are allowed (the provider API is the
// authority), but a completely unknown provider is rejected.
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
provider, modelName, err := ParseModelString(modelString)
if err != nil {
return err
}
// Ollama and custom are always valid — model names are user-defined.
if provider == "ollama" || provider == "custom" {
return nil
}
// Check if the provider exists in the registry.
providerInfo := r.GetProviderInfo(provider)
if providerInfo == nil {
known := r.GetSupportedProviders()
return fmt.Errorf(
"unknown provider %q in model string %q. Known providers: %s",
provider, modelString, strings.Join(known, ", "),
)
}
// Provider exists — check if the model is known. An unknown model is
// only a warning (the provider API decides), but we surface suggestions
// so the caller can self-correct.
if r.LookupModel(provider, modelName) == nil {
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
return fmt.Errorf(
"model %q not found for provider %s. Did you mean one of: %s",
modelName, provider, strings.Join(suggestions, ", "),
)
}
// No suggestions — let it through; the provider API is the authority.
}
return nil
}
// Global registry instance
var globalRegistry = NewModelsRegistry()
+92
View File
@@ -0,0 +1,92 @@
package models
import (
"strings"
"testing"
)
func TestValidateModelString(t *testing.T) {
registry := GetGlobalRegistry()
tests := []struct {
name string
model string
wantErr bool
errSubstr string // expected substring in error message (empty = don't check)
}{
{
name: "valid anthropic model",
model: "anthropic/claude-sonnet-4-6",
wantErr: false,
},
{
name: "missing provider prefix",
model: "claude-sonnet-4-6",
wantErr: true,
errSubstr: "invalid model format",
},
{
name: "empty string",
model: "",
wantErr: true,
errSubstr: "invalid model format",
},
{
name: "unknown provider",
model: "fakeprovider/some-model",
wantErr: true,
errSubstr: "unknown provider",
},
{
name: "ollama always valid",
model: "ollama/llama3",
wantErr: false,
},
{
name: "custom always valid",
model: "custom/my-fine-tune",
wantErr: false,
},
{
name: "empty provider",
model: "/claude-sonnet-4-6",
wantErr: true,
errSubstr: "invalid model format",
},
{
name: "empty model name",
model: "anthropic/",
wantErr: true,
errSubstr: "invalid model format",
},
{
name: "unknown model under known provider (no suggestions)",
model: "anthropic/totally-unknown-xyz-999",
wantErr: false, // no suggestions → passes through
},
{
name: "typo model under known provider with suggestions",
model: "anthropic/claude-sonet", // misspelled "sonnet"
wantErr: true,
errSubstr: "Did you mean",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := registry.ValidateModelString(tt.model)
if tt.wantErr && err == nil {
t.Errorf("ValidateModelString(%q) = nil, want error", tt.model)
}
if !tt.wantErr && err != nil {
t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err)
}
if tt.errSubstr != "" && err != nil {
if !strings.Contains(err.Error(), tt.errSubstr) {
t.Errorf("ValidateModelString(%q) error = %q, want substring %q",
tt.model, err.Error(), tt.errSubstr)
}
}
})
}
}
+2 -6
View File
@@ -2,11 +2,10 @@ package prompts
import (
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/charmbracelet/log"
)
// LoadOptions configures how templates are discovered and loaded.
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
DroppedPath: tpl.FilePath,
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
})
log.Debug("template collision",
"name", tpl.Name,
"dropped", tpl.FilePath,
"kept", existing.FilePath)
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
} else {
tpl.Source = source
seen[tpl.Name] = tpl
+33
View File
@@ -24,6 +24,7 @@ const (
EntryTypeSessionInfo EntryType = "session_info"
EntryTypeExtensionData EntryType = "extension_data"
EntryTypeCompaction EntryType = "compaction"
EntryTypeSystemPrompt EntryType = "system_prompt"
)
// CurrentVersion is the session format version for JSONL tree sessions.
@@ -117,6 +118,19 @@ type CompactionEntry struct {
ModifiedFiles []string `json:"modified_files,omitempty"`
}
// SystemPromptEntry records the system prompt and model used for the session.
// This is primarily for sharing/debugging to see what instructions were
// active during the conversation. It does NOT participate in the tree
// structure (no ParentID) and is not used when building LLM context.
type SystemPromptEntry struct {
Type EntryType `json:"type"` // always "system_prompt"
ID string `json:"id"` // unique entry ID
Timestamp time.Time `json:"timestamp"` // when captured
Content string `json:"content"` // the system prompt text
Model string `json:"model"` // the model used (e.g., "claude-sonnet-4-5")
Provider string `json:"provider"` // the provider used (e.g., "anthropic")
}
// GenerateEntryID creates a unique entry identifier (16 hex chars).
func GenerateEntryID() string {
bytes := make([]byte, 8)
@@ -217,6 +231,18 @@ func NewCompactionEntry(parentID, summary, firstKeptEntryID string, tokensBefore
}
}
// NewSystemPromptEntry creates a SystemPromptEntry.
func NewSystemPromptEntry(content, model, provider string) *SystemPromptEntry {
return &SystemPromptEntry{
Type: EntryTypeSystemPrompt,
ID: GenerateEntryID(),
Timestamp: time.Now(),
Content: content,
Model: model,
Provider: provider,
}
}
// --- JSONL marshaling helpers ---
// MarshalEntry serializes any entry to a JSON line (no trailing newline).
@@ -295,6 +321,13 @@ func UnmarshalEntry(data []byte) (any, error) {
}
return &e, nil
case EntryTypeSystemPrompt:
var e SystemPromptEntry
if err := json.Unmarshal(data, &e); err != nil {
return nil, fmt.Errorf("failed to unmarshal system_prompt entry: %w", err)
}
return &e, nil
default:
return nil, fmt.Errorf("unknown entry type: %q", env.Type)
}
+113
View File
@@ -0,0 +1,113 @@
package session
import (
"encoding/json"
"testing"
)
func TestSystemPromptEntry(t *testing.T) {
// Test creation
content := "You are a helpful coding assistant."
model := "claude-sonnet-4-5"
provider := "anthropic"
entry := NewSystemPromptEntry(content, model, provider)
if entry.Type != EntryTypeSystemPrompt {
t.Errorf("Expected type %q, got %q", EntryTypeSystemPrompt, entry.Type)
}
if entry.Content != content {
t.Errorf("Expected content %q, got %q", content, entry.Content)
}
if entry.Model != model {
t.Errorf("Expected model %q, got %q", model, entry.Model)
}
if entry.Provider != provider {
t.Errorf("Expected provider %q, got %q", provider, entry.Provider)
}
if entry.ID == "" {
t.Error("Expected non-empty ID")
}
// Test marshaling
data, err := MarshalEntry(entry)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Test unmarshaling
unmarshaled, err := UnmarshalEntry(data)
if err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
sysPrompt, ok := unmarshaled.(*SystemPromptEntry)
if !ok {
t.Fatalf("Expected *SystemPromptEntry, got %T", unmarshaled)
}
if sysPrompt.Type != EntryTypeSystemPrompt {
t.Errorf("Unmarshaled: expected type %q, got %q", EntryTypeSystemPrompt, sysPrompt.Type)
}
if sysPrompt.Content != content {
t.Errorf("Unmarshaled: expected content %q, got %q", content, sysPrompt.Content)
}
if sysPrompt.Model != model {
t.Errorf("Unmarshaled: expected model %q, got %q", model, sysPrompt.Model)
}
if sysPrompt.Provider != provider {
t.Errorf("Unmarshaled: expected provider %q, got %q", provider, sysPrompt.Provider)
}
if sysPrompt.ID != entry.ID {
t.Errorf("Unmarshaled: expected ID %q, got %q", entry.ID, sysPrompt.ID)
}
}
func TestSystemPromptEntryJSONStructure(t *testing.T) {
content := "Test system prompt content"
model := "gpt-4o"
provider := "openai"
entry := NewSystemPromptEntry(content, model, provider)
data, err := MarshalEntry(entry)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var raw map[string]any
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("Failed to unmarshal to raw map: %v", err)
}
if raw["type"] != "system_prompt" {
t.Errorf("Expected type 'system_prompt', got %v", raw["type"])
}
if raw["content"] != content {
t.Errorf("Expected content %q, got %v", content, raw["content"])
}
if raw["model"] != model {
t.Errorf("Expected model %q, got %v", model, raw["model"])
}
if raw["provider"] != provider {
t.Errorf("Expected provider %q, got %v", provider, raw["provider"])
}
if raw["id"] == "" || raw["id"] == nil {
t.Error("Expected non-empty id field")
}
if raw["timestamp"] == "" || raw["timestamp"] == nil {
t.Error("Expected non-empty timestamp field")
}
}
+181
View File
@@ -114,6 +114,187 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
return tm, nil
}
// ForkToNewSession creates a new session file containing the history up to and
// including the target entry ID. This matches Pi's /fork behavior: it creates
// a completely new session file with a parent_session reference, copying all
// entries from the root to the target point.
func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManager, error) {
tm.mu.RLock()
defer tm.mu.RUnlock()
// Get the branch from root to target (root-to-leaf order).
branch := tm.getBranchLocked(targetID)
if len(branch) == 0 {
return nil, fmt.Errorf("target entry %q not found", targetID)
}
// Create a new session file.
newTm, err := CreateTreeSession(cwd)
if err != nil {
return nil, err
}
// Set the parent session reference in the header.
newTm.header.ParentSession = tm.filePath
newTm.header.ParentSessionID = tm.header.ID
// Rewrite the header with the parent reference.
// We need to close and recreate the file to rewrite the header.
if err := newTm.file.Close(); err != nil {
return nil, fmt.Errorf("failed to close new session file: %w", err)
}
// Recreate the file and write the updated header.
f, err := os.Create(newTm.filePath)
if err != nil {
return nil, fmt.Errorf("failed to recreate session file: %w", err)
}
newTm.file = f
if err := newTm.writeEntry(&newTm.header); err != nil {
_ = f.Close()
return nil, fmt.Errorf("failed to write session header: %w", err)
}
// Copy entries from the branch to the new session.
// We need to remap IDs since the new session is independent.
idMap := make(map[string]string) // old ID -> new ID
var prevNewID string
for _, entry := range branch {
oldID := tm.EntryID(entry)
newID := GenerateEntryID()
idMap[oldID] = newID
// Create a copy of the entry with the new ID and remapped parent.
var newEntry any
switch e := entry.(type) {
case *MessageEntry:
newEntry = &MessageEntry{
Entry: Entry{
Type: EntryTypeMessage,
ID: newID,
ParentID: prevNewID, // Chain sequentially in new session
Timestamp: e.Timestamp,
},
Role: e.Role,
Parts: e.Parts,
Model: e.Model,
Provider: e.Provider,
}
// Copy label if present.
if label, ok := tm.labels[oldID]; ok {
newTm.labels[newID] = label
}
case *ModelChangeEntry:
newEntry = &ModelChangeEntry{
Entry: Entry{
Type: EntryTypeModelChange,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
Provider: e.Provider,
ModelID: e.ModelID,
}
case *LabelEntry:
// Remap the target ID if it's in our copied branch.
newTargetID := e.TargetID
if mapped, ok := idMap[e.TargetID]; ok {
newTargetID = mapped
}
newEntry = &LabelEntry{
Entry: Entry{
Type: EntryTypeLabel,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
TargetID: newTargetID,
Label: e.Label,
}
case *SessionInfoEntry:
newEntry = &SessionInfoEntry{
Entry: Entry{
Type: EntryTypeSessionInfo,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
Name: e.Name,
}
newTm.sessionName = e.Name
case *ExtensionDataEntry:
newEntry = &ExtensionDataEntry{
Entry: Entry{
Type: EntryTypeExtensionData,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
ExtType: e.ExtType,
Data: e.Data,
}
case *BranchSummaryEntry:
// Remap the from ID if it's in our copied branch.
newFromID := e.FromID
if mapped, ok := idMap[e.FromID]; ok {
newFromID = mapped
}
newEntry = &BranchSummaryEntry{
Entry: Entry{
Type: EntryTypeBranchSummary,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
FromID: newFromID,
Summary: e.Summary,
}
case *CompactionEntry:
// Remap the first kept entry ID if it's in our copied branch.
newFirstKeptID := e.FirstKeptEntryID
if mapped, ok := idMap[e.FirstKeptEntryID]; ok {
newFirstKeptID = mapped
}
newEntry = &CompactionEntry{
Entry: Entry{
Type: EntryTypeCompaction,
ID: newID,
ParentID: prevNewID,
Timestamp: e.Timestamp,
},
Summary: e.Summary,
FirstKeptEntryID: newFirstKeptID,
TokensBefore: e.TokensBefore,
TokensAfter: e.TokensAfter,
MessagesRemoved: e.MessagesRemoved,
ReadFiles: e.ReadFiles,
ModifiedFiles: e.ModifiedFiles,
}
}
if newEntry != nil {
if err := newTm.appendAndPersist(newEntry); err != nil {
_ = f.Close()
return nil, fmt.Errorf("failed to copy entry: %w", err)
}
prevNewID = newID
}
}
// Set the leaf to the last entry in the new session.
newTm.leafID = prevNewID
return newTm, nil
}
// OpenTreeSession opens an existing JSONL session file.
func OpenTreeSession(path string) (*TreeManager, error) {
data, err := os.ReadFile(path)
+84 -10
View File
@@ -68,6 +68,7 @@ type MCPConnectionPool struct {
cancel context.CancelFunc
debug bool
debugLogger DebugLogger
oauthFlow *OAuthFlowRunner
}
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
@@ -75,7 +76,7 @@ type MCPConnectionPool struct {
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
}
@@ -90,6 +91,10 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
debug: debug,
}
if authHandler != nil {
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
}
go pool.startHealthCheck()
return pool
}
@@ -103,6 +108,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
p.debugLogger = logger
}
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
// When set, the pool can trigger OAuth re-authorization when a tool call fails
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
p.mu.Lock()
defer p.mu.Unlock()
p.oauthFlow = flow
}
// GetConnection retrieves or creates a connection for the specified MCP server.
// If a healthy, non-idle connection exists in the pool, it will be reused.
// Otherwise, a new connection is created and added to the pool.
@@ -230,18 +244,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
// createConnection creates a new connection
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
client, err := p.createMCPClient(ctx, serverName, serverConfig)
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return nil, err
// SSE transport can return OAuth error during Start()
if p.oauthFlow != nil && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
// Retry after successful auth
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
if err := p.initializeClient(ctx, client); err != nil {
_ = client.Close()
return nil, err
if err := p.initializeClient(ctx, mcpClient); err != nil {
// Streamable HTTP transport returns OAuth error during Initialize()
if p.oauthFlow != nil && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
_ = mcpClient.Close()
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
// Retry initialization after successful auth
if err := p.initializeClient(ctx, mcpClient); err != nil {
_ = mcpClient.Close()
return nil, err
}
} else {
_ = mcpClient.Close()
return nil, err
}
}
conn := &MCPConnection{
client: client,
client: mcpClient,
serverName: serverName,
serverConfig: serverConfig,
lastUsed: time.Now(),
@@ -323,13 +362,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithOAuth(transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
if err != nil {
return nil, err
}
if err := sseClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start SSE client: %v", err)
return nil, fmt.Errorf("failed to start SSE client: %w", err)
}
return sseClient, nil
@@ -354,13 +409,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
if err != nil {
return nil, err
}
if err := streamableClient.Start(ctx); err != nil {
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
}
return streamableClient, nil
@@ -381,7 +452,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
_, err := client.Initialize(initCtx, initRequest)
if err != nil {
return fmt.Errorf("initialization timeout or failed: %v", err)
return fmt.Errorf("initialization timeout or failed: %w", err)
}
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
@@ -539,6 +610,9 @@ func (p *MCPConnectionPool) Close() error {
// isConnectionError checks if the error is connection-related
func isConnectionError(err error) bool {
if IsOAuthError(err) {
return false // OAuth errors are recoverable, not connection failures
}
errStr := err.Error()
return strings.Contains(errStr, "Connection not found") ||
strings.Contains(errStr, "transport error") ||
+24 -3
View File
@@ -59,9 +59,30 @@ func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantas
},
})
if err != nil {
// Mark connection as unhealthy for automatic recovery
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
// Handle OAuth re-authorization: token may have expired mid-session.
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
}
// Retry the tool call after successful re-auth.
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{
Method: "tools/call",
},
Params: mcp.CallToolParams{
Name: t.mapping.originalName,
Arguments: arguments,
},
})
if err != nil {
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
}
} else {
// Mark connection as unhealthy for automatic recovery
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
}
}
// Marshal the MCP result to JSON string
+86 -21
View File
@@ -4,8 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"maps"
"slices"
"strings"
"sync"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
@@ -21,10 +23,16 @@ type MCPToolManager struct {
connectionPool *MCPConnectionPool
tools []fantasy.AgentTool
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
mu sync.Mutex // protects tools and toolMap during parallel loading
model fantasy.LanguageModel // LLM model for sampling
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
config *config.Config
debug bool
debugLogger DebugLogger
// onServerLoaded, if non-nil, is called when each server finishes loading.
// Called with server name, tool count, and error (nil on success).
onServerLoaded func(serverName string, toolCount int, err error)
}
// toolMapping stores the mapping between prefixed tool names and their original details
@@ -53,6 +61,14 @@ func (m *MCPToolManager) SetModel(model fantasy.LanguageModel) {
m.model = model
}
// SetAuthHandler sets the OAuth handler for remote MCP server authentication.
// When set, remote transports (streamable HTTP, SSE) are configured with OAuth
// support, enabling automatic authorization flows when servers require authentication.
// This method should be called before LoadTools.
func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
m.authHandler = handler
}
// SetDebugLogger sets the debug logger for the tool manager.
// The logger will be used to output detailed debugging information about MCP connections,
// tool loading, and execution. If a connection pool exists, it will also be configured
@@ -64,48 +80,87 @@ func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
}
}
// SetOnServerLoaded sets the callback that's invoked when each MCP server finishes
// loading. The callback receives the server name, tool count, and any error.
// Call this before LoadTools to receive per-server notifications.
func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount int, err error)) {
m.onServerLoaded = cb
}
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
// It initializes the connection pool, connects to each configured server, and loads their tools.
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
// Returns an error only if all configured servers fail to load; partial failures are logged as warnings.
// This method is thread-safe and idempotent.
func (m *MCPToolManager) LoadTools(ctx context.Context, config *config.Config) error {
func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) error {
// Initialize connection pool
m.config = config
m.debug = config.Debug
m.config = cfg
m.debug = cfg.Debug
if m.debugLogger == nil {
m.debugLogger = NewSimpleDebugLogger(config.Debug)
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
}
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, config.Debug)
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
m.connectionPool.SetDebugLogger(m.debugLogger)
var loadErrors []string
// Load all servers in parallel. Each server connection (subprocess
// spawn, MCP initialize handshake, ListTools) is independent and
// typically dominated by process startup latency. Running them
// concurrently reduces total wall-clock time from O(n * avg) to
// O(max).
type serverResult struct {
name string
err error
}
for serverName, serverConfig := range config.MCPServers {
if err := m.loadServerTools(ctx, serverName, serverConfig); err != nil {
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", serverName, err))
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", serverName, err)
continue
results := make(chan serverResult, len(cfg.MCPServers))
var wg sync.WaitGroup
for serverName, serverConfig := range cfg.MCPServers {
wg.Add(1)
go func(name string, sc config.MCPServerConfig) {
defer wg.Done()
count, err := m.loadServerTools(ctx, name, sc)
results <- serverResult{name: name, err: err}
// Notify callback if set (for real-time UI updates).
if m.onServerLoaded != nil {
m.onServerLoaded(name, count, err)
}
}(serverName, serverConfig)
}
// Close results channel once all goroutines finish.
go func() {
wg.Wait()
close(results)
}()
var loadErrors []string
for r := range results {
if r.err != nil {
loadErrors = append(loadErrors, fmt.Sprintf("server %s: %v", r.name, r.err))
fmt.Printf("Warning: Failed to load MCP server '%s': %v\n", r.name, r.err)
}
}
// If all servers failed to load, return an error
if len(loadErrors) == len(config.MCPServers) && len(config.MCPServers) > 0 {
if len(loadErrors) == len(cfg.MCPServers) && len(cfg.MCPServers) > 0 {
return fmt.Errorf("all MCP servers failed to load: %s", strings.Join(loadErrors, "; "))
}
return nil
}
// loadServerTools loads tools from a single MCP server
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) error {
// loadServerTools loads tools from a single MCP server.
// Thread-safe: may be called concurrently for different servers.
// Returns the number of tools loaded from this server, or -1 on error.
func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (int, error) {
// Add debug logging
m.debugLogConnectionInfo(serverName, serverConfig)
// Get connection from pool
conn, err := m.connectionPool.GetConnection(ctx, serverName, serverConfig)
if err != nil {
return fmt.Errorf("failed to get connection from pool: %v", err)
return -1, fmt.Errorf("failed to get connection from pool: %v", err)
}
// Get tools from this server
@@ -113,7 +168,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
if err != nil {
// Handle connection error
m.connectionPool.HandleConnectionError(serverName, err)
return fmt.Errorf("failed to list tools: %v", err)
return -1, fmt.Errorf("failed to list tools: %v", err)
}
// Create name set for allowed tools
@@ -125,6 +180,10 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
}
}
// Build tools locally before acquiring the lock.
var localTools []fantasy.AgentTool
localMap := make(map[string]*toolMapping)
// Convert MCP tools to fantasy AgentTools with prefixed names
for _, mcpTool := range listResults.Tools {
// Filter tools based on allowedTools/excludedTools
@@ -142,7 +201,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Convert MCP InputSchema to map[string]any for fantasy ToolInfo
marshaledSchema, err := json.Marshal(mcpTool.InputSchema)
if err != nil {
return fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(marshal): %w, tool name: %s", err, mcpTool.Name)
}
// Fix for JSON Schema draft-07 vs draft-04 compatibility
@@ -151,7 +210,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
// Parse into map[string]any for fantasy's parameters format
var schemaMap map[string]any
if err := json.Unmarshal(marshaledSchema, &schemaMap); err != nil {
return fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
return -1, fmt.Errorf("conv mcp tool input schema fail(unmarshal): %w, tool name: %s", err, mcpTool.Name)
}
// Extract properties and required from the schema
@@ -184,7 +243,7 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
serverConfig: serverConfig,
manager: m,
}
m.toolMap[prefixedName] = mapping
localMap[prefixedName] = mapping
// Create fantasy AgentTool
fantasyTool := &mcpFantasyTool{
@@ -197,10 +256,16 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
mapping: mapping,
}
m.tools = append(m.tools, fantasyTool)
localTools = append(localTools, fantasyTool)
}
return nil
// Merge into the manager under the lock.
m.mu.Lock()
maps.Copy(m.toolMap, localMap)
m.tools = append(m.tools, localTools...)
m.mu.Unlock()
return len(localTools), nil
}
// GetTools returns all loaded tools as fantasy AgentTools from all configured MCP servers.
+109
View File
@@ -0,0 +1,109 @@
package tools
import (
"context"
"fmt"
"net/url"
"github.com/mark3labs/mcp-go/client"
)
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
// The SDK-level kit.MCPAuthHandler is adapted to this interface in cmd/root.go
// or pkg/kit/kit.go, keeping the tools package decoupled from the SDK.
type MCPAuthHandler interface {
// RedirectURI returns the OAuth redirect URI for transport setup.
RedirectURI() string
// HandleAuth is called when a server requires OAuth authorization.
// It receives the server name and the authorization URL the user must visit.
// It returns the full callback URL (containing code and state query params)
// after the user completes authorization.
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
// registration, PKCE generation, user authorization (via MCPAuthHandler),
// and token exchange.
type OAuthFlowRunner struct {
handler MCPAuthHandler
}
// NewOAuthFlowRunner creates a new OAuthFlowRunner with the given auth handler.
func NewOAuthFlowRunner(handler MCPAuthHandler) *OAuthFlowRunner {
return &OAuthFlowRunner{handler: handler}
}
// RunAuthFlow executes the OAuth authorization flow for the given server.
// It extracts the OAuthHandler from the error, performs dynamic client registration
// if needed, generates PKCE parameters, delegates to the MCPAuthHandler for user
// interaction, and exchanges the authorization code for a token.
func (r *OAuthFlowRunner) RunAuthFlow(ctx context.Context, serverName string, authErr error) error {
// Extract the OAuthHandler from the authorization-required error.
oauthHandler := client.GetOAuthHandler(authErr)
if oauthHandler == nil {
return fmt.Errorf("oauth flow: failed to extract OAuth handler from error: %w", authErr)
}
// Perform dynamic client registration if no client ID is configured yet.
if oauthHandler.GetClientID() == "" {
if err := oauthHandler.RegisterClient(ctx, "kit"); err != nil {
return fmt.Errorf("oauth flow: dynamic client registration failed: %w", err)
}
}
// Generate PKCE code verifier and challenge.
codeVerifier, err := client.GenerateCodeVerifier()
if err != nil {
return fmt.Errorf("oauth flow: failed to generate code verifier: %w", err)
}
codeChallenge := client.GenerateCodeChallenge(codeVerifier)
// Generate a random state parameter for CSRF protection.
state, err := client.GenerateState()
if err != nil {
return fmt.Errorf("oauth flow: failed to generate state: %w", err)
}
// Build the authorization URL the user needs to visit.
authURL, err := oauthHandler.GetAuthorizationURL(ctx, state, codeChallenge)
if err != nil {
return fmt.Errorf("oauth flow: failed to get authorization URL: %w", err)
}
// Delegate to the MCPAuthHandler for user-facing authorization (e.g. open
// browser, wait for redirect). It returns the full callback URL containing
// the authorization code and state.
callbackURL, err := r.handler.HandleAuth(ctx, serverName, authURL)
if err != nil {
return fmt.Errorf("oauth flow: user authorization failed: %w", err)
}
// Parse the callback URL to extract the authorization code and state.
parsed, err := url.Parse(callbackURL)
if err != nil {
return fmt.Errorf("oauth flow: failed to parse callback URL: %w", err)
}
code := parsed.Query().Get("code")
returnedState := parsed.Query().Get("state")
if code == "" {
return fmt.Errorf("oauth flow: callback URL missing 'code' parameter")
}
if returnedState == "" {
return fmt.Errorf("oauth flow: callback URL missing 'state' parameter")
}
// Exchange the authorization code for an access token.
if err := oauthHandler.ProcessAuthorizationResponse(ctx, code, returnedState, codeVerifier); err != nil {
return fmt.Errorf("oauth flow: token exchange failed: %w", err)
}
return nil
}
// IsOAuthError returns true if the error is an OAuthAuthorizationRequiredError.
func IsOAuthError(err error) bool {
return client.IsOAuthAuthorizationRequiredError(err)
}
+155
View File
@@ -0,0 +1,155 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"github.com/mark3labs/mcp-go/client/transport"
)
// Compile-time check that FileTokenStore implements transport.TokenStore.
var _ transport.TokenStore = (*FileTokenStore)(nil)
// FileTokenStore is a file-backed implementation of transport.TokenStore that
// persists OAuth tokens as JSON on disk. Tokens are stored in a shared JSON file
// keyed by server URL, allowing multiple MCP servers to maintain independent tokens.
//
// The token file is located at $XDG_CONFIG_HOME/.kit/mcp_tokens.json, falling back
// to ~/.config/.kit/mcp_tokens.json when XDG_CONFIG_HOME is not set.
//
// FileTokenStore is safe for concurrent use.
type FileTokenStore struct {
serverKey string
filePath string
mu sync.RWMutex
}
// NewFileTokenStore creates a new FileTokenStore for the given server URL.
// The serverKey is used as the map key in the shared token file, and should
// typically be the MCP server's base URL.
//
// Returns an error if the token file path cannot be resolved.
func NewFileTokenStore(serverKey string) (*FileTokenStore, error) {
filePath, err := resolveTokenFilePath()
if err != nil {
return nil, fmt.Errorf("resolving token file path: %w", err)
}
return &FileTokenStore{
serverKey: serverKey,
filePath: filePath,
}, nil
}
// GetToken returns the stored token for this store's server key.
// Returns transport.ErrNoToken if no token exists for the server key or if
// the token file does not yet exist.
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
func (s *FileTokenStore) GetToken(ctx context.Context) (*transport.Token, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
s.mu.RLock()
defer s.mu.RUnlock()
tokens, err := readTokenFile(s.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil, transport.ErrNoToken
}
return nil, fmt.Errorf("reading token file: %w", err)
}
token, ok := tokens[s.serverKey]
if !ok {
return nil, transport.ErrNoToken
}
return token, nil
}
// SaveToken persists the given token for this store's server key.
// If the token file or its parent directories do not exist, they are created.
// Existing tokens for other server keys are preserved.
// Returns context.Canceled or context.DeadlineExceeded if the context is done.
func (s *FileTokenStore) SaveToken(ctx context.Context, token *transport.Token) error {
if err := ctx.Err(); err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
tokens, err := readTokenFile(s.filePath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("reading token file: %w", err)
}
if tokens == nil {
tokens = make(map[string]*transport.Token)
}
tokens[s.serverKey] = token
if err := writeTokenFile(s.filePath, tokens); err != nil {
return fmt.Errorf("writing token file: %w", err)
}
return nil
}
// resolveTokenFilePath determines the path to the token file using
// XDG_CONFIG_HOME if set, otherwise falling back to ~/.config/.kit/.
func resolveTokenFilePath() (string, error) {
configDir := os.Getenv("XDG_CONFIG_HOME")
if configDir == "" {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("determining user home directory: %w", err)
}
configDir = filepath.Join(home, ".config")
}
return filepath.Join(configDir, ".kit", "mcp_tokens.json"), nil
}
// readTokenFile reads and unmarshals the token file into a server-keyed map.
// Returns os.ErrNotExist (via os.IsNotExist) if the file does not exist.
func readTokenFile(path string) (map[string]*transport.Token, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var tokens map[string]*transport.Token
if err := json.Unmarshal(data, &tokens); err != nil {
return nil, fmt.Errorf("unmarshaling token file: %w", err)
}
return tokens, nil
}
// writeTokenFile marshals the token map and writes it to disk, creating
// parent directories as needed. The file is written with 0600 permissions
// to protect sensitive token data.
func writeTokenFile(path string, tokens map[string]*transport.Token) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0700); err != nil {
return fmt.Errorf("creating token directory %s: %w", dir, err)
}
data, err := json.MarshalIndent(tokens, "", " ")
if err != nil {
return fmt.Errorf("marshaling tokens: %w", err)
}
if err := os.WriteFile(path, data, 0600); err != nil {
return fmt.Errorf("writing token file %s: %w", path, err)
}
return nil
}
+3 -1
View File
@@ -4,6 +4,8 @@ import (
"image/color"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/style"
)
// blockRenderer handles rendering of content blocks with configurable options
@@ -175,7 +177,7 @@ func renderContentBlock(content string, containerWidth int, options ...rendering
borderChars = 1
}
theme := GetTheme()
theme := style.GetTheme()
// Resolve foreground color: caller override or theme default.
fgColor := theme.Text
+19 -150
View File
@@ -6,6 +6,7 @@ import (
tea "charm.land/bubbletea/v2"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/ui/core"
)
// ==========================================================================
@@ -59,7 +60,7 @@ func TestInputComponent_SubmitEmitsSubmitMsg(t *testing.T) {
t.Fatal("expected a cmd from pressing enter on non-empty input")
}
sm, ok := msg.(submitMsg)
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("expected submitMsg, got %T", msg)
}
@@ -83,7 +84,7 @@ func TestInputComponent_CtrlD_SubmitEmitsSubmitMsg(t *testing.T) {
if msg == nil {
t.Fatal("expected a cmd from ctrl+d on non-empty input")
}
sm, ok := msg.(submitMsg)
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("expected submitMsg from ctrl+d, got %T", msg)
}
@@ -175,7 +176,7 @@ func TestInputComponent_ClearForwardsAsSubmitMsg(t *testing.T) {
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
}
msg := runCmd(cmd)
sm, ok := msg.(submitMsg)
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
}
@@ -230,7 +231,7 @@ func TestInputComponent_ClearQueue_ForwardsAsSubmitMsg(t *testing.T) {
t.Fatalf("%s: expected submitMsg cmd, got nil", alias)
}
msg := runCmd(cmd)
sm, ok := msg.(submitMsg)
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("%s: expected submitMsg, got %T", alias, msg)
}
@@ -258,7 +259,7 @@ func TestInputComponent_UnknownSlashCommand_ForwardsAsSubmit(t *testing.T) {
if msg == nil {
t.Fatal("expected submitMsg for unknown slash command")
}
sm, ok := msg.(submitMsg)
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("expected submitMsg for unknown slash command, got %T", msg)
}
@@ -701,167 +702,35 @@ func TestStreamComponent_StaleFlushTick_Discarded(t *testing.T) {
// TestStreamComponent_ConsumeOverflow_NoHeight verifies that when height is
// unconstrained (0), ConsumeOverflow always returns "".
func TestStreamComponent_ConsumeOverflow_NoHeight(t *testing.T) {
func TestStreamComponent_ConsumeOverflow_NoOp(t *testing.T) {
c := newTestStream()
// Commit some content directly.
c.streamContent.WriteString("line1\nline2\nline3")
c.phase = streamPhaseActive
c.renderDirty = true
// ConsumeOverflow is a no-op in alt screen mode — always returns "".
if got := c.ConsumeOverflow(); got != "" {
t.Fatalf("expected empty with height=0, got %q", got)
t.Fatalf("expected empty from no-op ConsumeOverflow, got %q", got)
}
}
// TestStreamComponent_ConsumeOverflow_NoOverflow verifies that when content fits
// within the allocated height, ConsumeOverflow returns "".
func TestStreamComponent_ConsumeOverflow_NoOverflow(t *testing.T) {
c := newTestStream()
c.streamContent.WriteString("line1\nline2")
c.phase = streamPhaseActive
c.renderDirty = true
c.height = 20 // plenty of room
// Also returns "" with a height set.
c.height = 2
if got := c.ConsumeOverflow(); got != "" {
t.Fatalf("expected empty when content fits, got %q", got)
t.Fatalf("expected empty from no-op ConsumeOverflow with height, got %q", got)
}
}
// TestStreamComponent_ConsumeOverflow_EmitsTopLines verifies that when the
// rendered content has more lines than the allocated height, ConsumeOverflow
// returns the top overflow lines and advances the internal pointer.
func TestStreamComponent_ConsumeOverflow_EmitsTopLines(t *testing.T) {
// TestStreamComponent_GetRenderedContent_ReturnsAll verifies that
// GetRenderedContent returns all accumulated content.
func TestStreamComponent_GetRenderedContent_ReturnsAll(t *testing.T) {
c := newTestStream()
c.height = 2
// Build raw content that when "rendered" (plain text for this test)
// is 5 lines — we bypass the markdown renderer by writing directly to
// streamContent and using a nil renderer.
c.renderer = nil
c.phase = streamPhaseActive
c.streamContent.WriteString("a\nb\nc\nd\ne")
c.phase = streamPhaseActive
c.renderDirty = true
// First call: should return lines a, b, c (5 lines - 2 visible = 3 overflow).
overflow1 := c.ConsumeOverflow()
if overflow1 == "" {
t.Fatal("expected overflow, got empty")
}
overflowLines := strings.Split(overflow1, "\n")
if len(overflowLines) != 3 {
t.Fatalf("expected 3 overflow lines, got %d: %q", len(overflowLines), overflow1)
}
if overflowLines[0] != "a" || overflowLines[1] != "b" || overflowLines[2] != "c" {
t.Fatalf("unexpected overflow lines: %v", overflowLines)
}
// Second call without new content should return "" (pointer already advanced).
overflow2 := c.ConsumeOverflow()
if overflow2 != "" {
t.Fatalf("expected empty on second call, got %q", overflow2)
}
}
// TestStreamComponent_ConsumeOverflow_IncrementalFlush verifies that as new
// content arrives, ConsumeOverflow incrementally returns only newly overflowed
// lines on each call.
func TestStreamComponent_ConsumeOverflow_IncrementalFlush(t *testing.T) {
c := newTestStream()
c.height = 2
c.renderer = nil
c.phase = streamPhaseActive
// Start with 3 lines — 1 overflows.
c.streamContent.WriteString("a\nb\nc")
c.renderDirty = true
overflow1 := c.ConsumeOverflow()
if overflow1 != "a" {
t.Fatalf("expected 'a', got %q", overflow1)
}
// Add 2 more lines — 2 additional overflows.
c.streamContent.WriteString("\nd\ne")
c.renderDirty = true
overflow2 := c.ConsumeOverflow()
want := "b\nc"
if overflow2 != want {
t.Fatalf("expected %q, got %q", want, overflow2)
}
}
// TestStreamComponent_ConsumeOverflow_ResetClearsPointer verifies that Reset()
// resets the scrollback pointer so the next response starts fresh.
func TestStreamComponent_ConsumeOverflow_ResetClearsPointer(t *testing.T) {
c := newTestStream()
c.height = 1
c.renderer = nil
c.phase = streamPhaseActive
c.streamContent.WriteString("a\nb")
c.renderDirty = true
overflow := c.ConsumeOverflow()
if overflow != "a" {
t.Fatalf("expected 'a', got %q", overflow)
}
c.Reset()
if c.scrollbackFlushedLines != 0 {
t.Fatalf("expected scrollbackFlushedLines=0 after Reset, got %d", c.scrollbackFlushedLines)
}
}
// TestStreamComponent_GetRenderedContent_SkipsFlushedLines verifies that
// GetRenderedContent skips lines already emitted via ConsumeOverflow so the
// caller doesn't re-print content already in the terminal scrollback.
func TestStreamComponent_GetRenderedContent_SkipsFlushedLines(t *testing.T) {
c := newTestStream()
c.height = 2
c.renderer = nil
c.phase = streamPhaseActive
// 5 lines → 3 overflow, 2 visible.
c.streamContent.WriteString("a\nb\nc\nd\ne")
c.renderDirty = true
// Consume the overflow: lines a, b, c.
overflow := c.ConsumeOverflow()
if overflow != "a\nb\nc" {
t.Fatalf("expected 'a\\nb\\nc', got %q", overflow)
}
if c.scrollbackFlushedLines != 3 {
t.Fatalf("expected flushedLines=3, got %d", c.scrollbackFlushedLines)
}
// GetRenderedContent should only return the non-flushed portion: d, e.
got := c.GetRenderedContent()
if got != "d\ne" {
t.Fatalf("expected 'd\\ne', got %q", got)
}
}
// TestStreamComponent_GetRenderedContent_AllFlushed verifies that when all
// lines have been pushed via ConsumeOverflow, GetRenderedContent returns "".
func TestStreamComponent_GetRenderedContent_AllFlushed(t *testing.T) {
c := newTestStream()
c.height = 1
c.renderer = nil
c.phase = streamPhaseActive
// 2 lines → height=1, so 1 overflow.
c.streamContent.WriteString("a\nb")
c.renderDirty = true
// Consume overflow (line a), leaving 1 visible line (b).
_ = c.ConsumeOverflow()
// Now bump height so everything overflows — simulate a resize that made
// the viewable area 0, forcing all content to be "flushed".
c.scrollbackFlushedLines = 2 // pretend both lines were flushed
got := c.GetRenderedContent()
if got != "" {
t.Fatalf("expected empty when all lines flushed, got %q", got)
if got != "a\nb\nc\nd\ne" {
t.Fatalf("expected full content, got %q", got)
}
}
+3 -29
View File
@@ -5,9 +5,10 @@ import (
"os"
"time"
"charm.land/fantasy"
"charm.land/lipgloss/v2"
"golang.org/x/term"
"github.com/mark3labs/kit/internal/ui/style"
)
// CLI manages the command-line interface for KIT, providing message rendering,
@@ -125,7 +126,7 @@ func (c *CLI) DisplayInfo(message string) {
// DisplayExtensionBlock renders a custom styled block with the given border
// color and optional subtitle. Used by extensions via ctx.PrintBlock.
func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
theme := GetTheme()
theme := style.GetTheme()
borderClr := theme.Info
if borderColor != "" {
@@ -171,33 +172,6 @@ func (c *CLI) DisplayDebugConfig(config map[string]any) {
fmt.Println(c.renderer.RenderDebugConfigMessage(config, time.Now()).Content)
}
// UpdateUsageFromResponse records token usage using metadata from the fantasy
// response. Only actual API-reported tokens are used for cost tracking.
// If the provider doesn't report token counts, no usage is recorded.
func (c *CLI) UpdateUsageFromResponse(response *fantasy.Response, inputText string) {
if c.usageTracker == nil {
return
}
usage := response.Usage
inputTokens := int(usage.InputTokens)
outputTokens := int(usage.OutputTokens)
// Only use actual API-reported tokens for cost tracking.
// We intentionally do NOT estimate tokens - estimation is inaccurate
// and should never be used for cost calculations.
if inputTokens > 0 {
cacheReadTokens := int(usage.CacheReadTokens)
cacheWriteTokens := int(usage.CacheCreationTokens)
c.usageTracker.UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens)
// Per-response usage is a single API call, so it represents the
// actual context window fill level.
c.usageTracker.SetContextTokens(inputTokens + outputTokens)
}
// If inputTokens is 0, the provider didn't report usage - we skip recording
// rather than estimating, to ensure cost accuracy.
}
// DisplayUsageAfterResponse renders and displays token usage information immediately
// following an AI response. This provides real-time feedback about the cost and
// token consumption of each interaction.
-96
View File
@@ -1,96 +0,0 @@
package ui
import (
"fmt"
"runtime"
tea "charm.land/bubbletea/v2"
"github.com/atotto/clipboard"
)
// CopyToClipboard writes text to both the system clipboard and via OSC 52.
// Returns a tea.Cmd that can be used in Bubble Tea's Update flow.
func CopyToClipboard(text string) tea.Cmd {
if text == "" {
return nil
}
return tea.Sequence(
// Method 1: OSC 52 escape sequence (works in modern terminals)
tea.SetClipboard(text),
// Method 2: Native system clipboard (atotto/clipboard)
func() tea.Msg {
// Best effort - ignore errors
_ = clipboard.WriteAll(text)
return nil
},
)
}
// CopyToClipboardWithMessage writes text to clipboard and returns a toast notification.
func CopyToClipboardWithMessage(text string, message string) tea.Cmd {
if text == "" {
return nil
}
return tea.Sequence(
CopyToClipboard(text),
func() tea.Msg {
return ToastMsg{Message: message, Type: ToastInfo}
},
)
}
// ToastType represents the type of toast notification.
type ToastType int
const (
ToastInfo ToastType = iota
ToastSuccess
ToastWarning
ToastError
)
// ToastMsg is a message to display a toast notification.
type ToastMsg struct {
Message string
Type ToastType
}
// IsClipboardSupported returns true if the clipboard is supported on this platform.
func IsClipboardSupported() bool {
// atotto/clipboard supports Linux (with xclip or xsel), macOS, Windows
switch runtime.GOOS {
case "darwin", "windows":
return true
case "linux":
// Check if xclip or xsel is available
// This is a best-effort check
return true
default:
return false
}
}
// CopySelection represents a text selection with start/end positions.
type CopySelection struct {
StartItemIdx int // Index of item where selection starts
StartLine int // Line within item where selection starts
StartCol int // Column where selection starts
EndItemIdx int // Index of item where selection ends
EndLine int // Line within item where selection ends
EndCol int // Column where selection ends
Active bool // Whether selection is currently active
}
// IsEmpty returns true if the selection has no content.
func (s CopySelection) IsEmpty() bool {
return !s.Active || (s.StartItemIdx == s.EndItemIdx && s.StartLine == s.EndLine && s.StartCol == s.EndCol)
}
// String returns a string representation for debugging.
func (s CopySelection) String() string {
return fmt.Sprintf("Selection{item:%d-%d, line:%d-%d, col:%d-%d, active:%v}",
s.StartItemIdx, s.EndItemIdx, s.StartLine, s.EndLine, s.StartCol, s.EndCol, s.Active)
}
+26
View File
@@ -0,0 +1,26 @@
package clipboard
import (
tea "charm.land/bubbletea/v2"
"github.com/atotto/clipboard"
)
// CopyToClipboard writes text to both the system clipboard and via OSC 52.
// Returns a tea.Cmd that can be used in Bubble Tea's Update flow.
func CopyToClipboard(text string) tea.Cmd {
if text == "" {
return nil
}
return tea.Sequence(
// Method 1: OSC 52 escape sequence (works in modern terminals)
tea.SetClipboard(text),
// Method 2: Native system clipboard (atotto/clipboard)
func() tea.Msg {
// Best effort - ignore errors
_ = clipboard.WriteAll(text)
return nil
},
)
}
@@ -1,4 +1,4 @@
package ui
package commands
import (
"slices"
@@ -7,6 +7,10 @@ import (
"github.com/mark3labs/kit/internal/models"
)
// ListThemesFunc is set by the ui package to provide theme name completion.
// This breaks the circular dependency between commands and ui packages.
var ListThemesFunc func() []string
// SlashCommand represents a user-invokable slash command with its metadata.
// Commands can have multiple aliases and are organized by category for better
// discoverability and help display.
@@ -99,7 +103,10 @@ var SlashCommands = []SlashCommand{
Description: "Switch color theme (e.g. /theme catppuccin)",
Category: "System",
Complete: func(prefix string) []string {
names := ListThemes()
if ListThemesFunc == nil {
return nil
}
names := ListThemesFunc()
if prefix == "" {
return names
}
@@ -112,6 +119,12 @@ var SlashCommands = []SlashCommand{
return matches
},
},
{
Name: "/reload-ext",
Description: "Hot-reload all extensions from disk",
Category: "System",
Aliases: []string{"/re"},
},
{
Name: "/quit",
Description: "Exit the application",
@@ -127,7 +140,7 @@ var SlashCommands = []SlashCommand{
},
{
Name: "/fork",
Description: "Branch from an earlier message",
Description: "Fork to new session from an earlier message",
Category: "Navigation",
},
{
@@ -1,4 +1,4 @@
package ui
package core
// ImageAttachment holds a clipboard image that will be sent alongside the
// user's text prompt to the LLM. The data is raw image bytes; MediaType is
@@ -10,9 +10,9 @@ type ImageAttachment struct {
MediaType string
}
// submitMsg is sent by the InputComponent when the user submits a text prompt.
// SubmitMsg is sent by the InputComponent when the user submits a text prompt.
// The parent model receives this and calls app.Run(Text) to start agent processing.
type submitMsg struct {
type SubmitMsg struct {
// Text is the user's input text to send to the agent.
Text string
// Images holds clipboard image attachments to send alongside the text.
@@ -20,10 +20,10 @@ type submitMsg struct {
Images []ImageAttachment
}
// cancelTimerExpiredMsg is sent by the tea.Tick command that starts when the user
// CancelTimerExpiredMsg is sent by the tea.Tick command that starts when the user
// presses ESC once during stateWorking. If this message arrives before the user
// presses ESC a second time, the canceling state is reset to false.
type cancelTimerExpiredMsg struct{}
type CancelTimerExpiredMsg struct{}
// --- Tree session events ---
@@ -42,14 +42,14 @@ type TreeNodeSelectedMsg struct {
// TreeCancelledMsg is sent when the user cancels the tree selector (ESC).
type TreeCancelledMsg struct{}
// shellCommandMsg is sent by the InputComponent when the user submits a
// ShellCommandMsg is sent by the InputComponent when the user submits a
// ! or !! prefixed command. The parent model intercepts this to execute
// the shell command directly instead of forwarding to the LLM.
//
// Matching pi's behavior:
// - !cmd → run shell command, output INCLUDED in LLM context
// - !!cmd → run shell command, output EXCLUDED from LLM context
type shellCommandMsg struct {
type ShellCommandMsg struct {
// Command is the shell command to execute (prefix stripped).
Command string
// ExcludeFromContext is true for !! (output excluded from LLM context),
@@ -57,9 +57,9 @@ type shellCommandMsg struct {
ExcludeFromContext bool
}
// shellCommandResultMsg carries the result of a shell command execution
// ShellCommandResultMsg carries the result of a shell command execution
// back to the parent model for display.
type shellCommandResultMsg struct {
type ShellCommandResultMsg struct {
// Command is the original shell command that was executed.
Command string
// Output is the combined stdout/stderr output.
@@ -68,6 +68,6 @@ type shellCommandResultMsg struct {
ExitCode int
// Err is non-nil if the command failed to start or timed out.
Err error
// ExcludeFromContext mirrors the flag from shellCommandMsg.
// ExcludeFromContext mirrors the flag from ShellCommandMsg.
ExcludeFromContext bool
}
+62
View File
@@ -0,0 +1,62 @@
package ui
// This file re-exports types from subpackages for backward compatibility.
// External importers can continue using ui.XXX without needing to import
// from subpackages directly.
import (
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/ui/core"
"github.com/mark3labs/kit/internal/ui/fileutil"
"github.com/mark3labs/kit/internal/ui/prefs"
"github.com/mark3labs/kit/internal/ui/style"
)
// Re-export from core package
type (
ImageAttachment = core.ImageAttachment
SubmitMsg = core.SubmitMsg
CancelTimerExpiredMsg = core.CancelTimerExpiredMsg
TreeNodeSelectedMsg = core.TreeNodeSelectedMsg
TreeCancelledMsg = core.TreeCancelledMsg
ShellCommandMsg = core.ShellCommandMsg
ShellCommandResultMsg = core.ShellCommandResultMsg
)
// Re-export from commands package
type (
SlashCommand = commands.SlashCommand
ExtensionCommand = commands.ExtensionCommand
)
// Re-export functions from fileutil package
var ProcessFileAttachments = fileutil.ProcessFileAttachments
// Re-export from prefs package
var (
LoadThemePreference = prefs.LoadThemePreference
SaveThemePreference = prefs.SaveThemePreference
LoadModelPreference = prefs.LoadModelPreference
SaveModelPreference = prefs.SaveModelPreference
LoadThinkingLevelPreference = prefs.LoadThinkingLevelPreference
SaveThinkingLevelPreference = prefs.SaveThinkingLevelPreference
)
// Re-export from style package
type (
Theme = style.Theme
MarkdownThemeColors = style.MarkdownThemeColors
)
var (
GetTheme = style.GetTheme
SetTheme = style.SetTheme
DefaultTheme = style.DefaultTheme
ApplyTheme = style.ApplyTheme
ApplyThemeWithoutSave = style.ApplyThemeWithoutSave
ListThemes = style.ListThemes
RegisterThemeFromConfig = style.RegisterThemeFromConfig
KitBanner = style.KitBanner
AdaptiveColor = style.AdaptiveColor
IsDarkBackground = style.IsDarkBackground
)
@@ -1,4 +1,4 @@
package ui
package fileutil
import (
"fmt"
+3 -1
View File
@@ -5,6 +5,8 @@ import (
"time"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/style"
)
// Renderer is the interface satisfied by MessageRenderer. It allows model.go
@@ -30,7 +32,7 @@ var _ Renderer = (*MessageRenderer)(nil)
// combined, styled output string with tags stripped.
//
// Shared by MessageRenderer.
func parseBashOutput(result string, theme Theme) string {
func parseBashOutput(result string, theme style.Theme) string {
var formattedResult strings.Builder
remaining := result
+5 -3
View File
@@ -2,20 +2,22 @@ package ui
import (
"strings"
"github.com/mark3labs/kit/internal/ui/commands"
)
// FuzzyMatch represents the result of a fuzzy string matching operation,
// containing the matched command and its relevance score. Higher scores
// indicate better matches.
type FuzzyMatch struct {
Command *SlashCommand
Command *commands.SlashCommand
Score int
}
// FuzzyMatchCommands performs fuzzy string matching on the provided slash commands
// based on the query string. Returns a slice of matches sorted by relevance score
// in descending order. An empty query returns all commands with zero scores.
func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
func FuzzyMatchCommands(query string, commands []commands.SlashCommand) []FuzzyMatch {
if query == "" || query == "/" {
// Return all commands when query is empty or just "/"
matches := make([]FuzzyMatch, len(commands))
@@ -57,7 +59,7 @@ func FuzzyMatchCommands(query string, commands []SlashCommand) []FuzzyMatch {
}
// fuzzyScore calculates the fuzzy match score for a command
func fuzzyScore(query string, cmd *SlashCommand) int {
func fuzzyScore(query string, cmd *commands.SlashCommand) int {
// Check exact match first
cmdName := strings.ToLower(strings.TrimPrefix(cmd.Name, "/"))
if cmdName == query {
+92 -63
View File
@@ -10,6 +10,9 @@ import (
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/clipboard"
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/ui/core"
"github.com/mark3labs/kit/internal/ui/style"
)
// InputComponent is the interactive text input field for the parent AppModel.
@@ -29,7 +32,7 @@ import (
// app.Run().
type InputComponent struct {
textarea textarea.Model
commands []SlashCommand
commands []commands.SlashCommand
showPopup bool
filtered []FuzzyMatch
selected int
@@ -42,17 +45,17 @@ type InputComponent struct {
// Argument completion state. When the user types "/cmd " followed by
// a partial argument and the command has a Complete function, the popup
// switches to argument-completion mode showing suggestions from Complete.
argMode bool // true when showing arg completions
argCommand string // command prefix for arg mode (e.g. "/bookmark")
argSynthCmds []SlashCommand // backing storage for synthetic arg entries
argMode bool // true when showing arg completions
argCommand string // command prefix for arg mode (e.g. "/bookmark")
argSynthCmds []commands.SlashCommand // backing storage for synthetic arg entries
// File completion state. When the user types @ followed by a partial
// file path, the popup shows file/directory suggestions from the cwd.
fileMode bool // true when showing @file completions
filePrefix string // current text after @ being matched
fileAtStartIdx int // byte offset of @ in the textarea value
fileSuggestions []FileSuggestion // backing storage for file entries
fileSynthCmds []SlashCommand // synthetic SlashCommands wrapping file entries
fileMode bool // true when showing @file completions
filePrefix string // current text after @ being matched
fileAtStartIdx int // byte offset of @ in the textarea value
fileSuggestions []FileSuggestion // backing storage for file entries
fileSynthCmds []commands.SlashCommand // synthetic commands.SlashCommands wrapping file entries
// cwd is the working directory used for @file path resolution and
// autocomplete suggestions. Set by the parent via SetCwd.
@@ -71,7 +74,7 @@ type InputComponent struct {
// pendingImages holds clipboard images attached to the next submission.
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
pendingImages []ImageAttachment
pendingImages []core.ImageAttachment
// history stores previously submitted prompts (most recent last).
// Limited to maxHistory entries; duplicates of the previous entry are
@@ -94,7 +97,7 @@ const maxHistory = 100
// clipboardImageMsg is the result of an async clipboard image read.
type clipboardImageMsg struct {
image *ImageAttachment
image *core.ImageAttachment
err error
}
@@ -119,7 +122,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
)
// Style the textarea using theme colors.
theme := GetTheme()
theme := style.GetTheme()
styles := ta.Styles()
styles.Focused.Base = lipgloss.NewStyle()
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
@@ -130,7 +133,7 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
return &InputComponent{
textarea: ta,
commands: SlashCommands,
commands: commands.SlashCommands,
width: width,
popupHeight: 7,
title: title,
@@ -329,7 +332,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s.filePrefix = prefix
s.fileAtStartIdx = atIdx
s.fileSuggestions = suggestions
s.fileSynthCmds = make([]SlashCommand, len(suggestions))
s.fileSynthCmds = make([]commands.SlashCommand, len(suggestions))
s.filtered = make([]FuzzyMatch, len(suggestions))
for i, fs := range suggestions {
name := fs.RelPath
@@ -337,7 +340,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if fs.IsDir {
desc = "directory"
}
s.fileSynthCmds[i] = SlashCommand{Name: name, Description: desc}
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
}
s.selected = 0
@@ -396,14 +399,14 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
cmd := strings.TrimSpace(trimmed[2:])
if cmd != "" {
return func() tea.Msg {
return shellCommandMsg{Command: cmd, ExcludeFromContext: true}
return core.ShellCommandMsg{Command: cmd, ExcludeFromContext: true}
}
}
} else if strings.HasPrefix(trimmed, "!") {
cmd := strings.TrimSpace(trimmed[1:])
if cmd != "" {
return func() tea.Msg {
return shellCommandMsg{Command: cmd, ExcludeFromContext: false}
return core.ShellCommandMsg{Command: cmd, ExcludeFromContext: false}
}
}
}
@@ -411,9 +414,9 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
// Resolve via canonical command lookup so aliases are handled uniformly.
// Only /quit is handled locally — all other slash commands (including
// /clear and /clear-queue) are forwarded to the parent model via
// submitMsg so the parent can update its own state (scrollback, queue
// submitMsg so the parent can update its own state (ScrollList, queue
// counts, etc.) in one place.
if sc := GetCommandByName(trimmed); sc != nil {
if sc := commands.GetCommandByName(trimmed); sc != nil {
switch sc.Name {
case "/quit":
return tea.Quit
@@ -426,7 +429,7 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
images := s.pendingImages
s.pendingImages = nil
return func() tea.Msg {
return submitMsg{Text: trimmed, Images: images}
return core.SubmitMsg{Text: trimmed, Images: images}
}
}
@@ -463,7 +466,7 @@ func (s *InputComponent) resetHistoryBrowsing() {
func (s *InputComponent) View() tea.View {
containerStyle := lipgloss.NewStyle()
theme := GetTheme()
theme := style.GetTheme()
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
titleStyle := lipgloss.NewStyle().
@@ -531,14 +534,7 @@ func (s *InputComponent) View() tea.View {
view.WriteString(helpStyle.Render(hint))
}
v := tea.NewView(containerStyle.Render(view.String()))
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
return v
return tea.NewView(containerStyle.Render(view.String()))
}
// renderPopup renders the autocomplete popup for slash command suggestions.
@@ -565,18 +561,39 @@ func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
// renderPopupWithOptions renders the popup content with optional center styling.
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
theme := GetTheme()
theme := style.GetTheme()
popupWidth := max(s.width-4, 20)
// Use the theme background for the popup - the full-width item backgrounds
// and primary-colored selection will provide sufficient contrast
popupBg := theme.Background
popupStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.MutedBorder).
BorderForeground(theme.Primary).
Background(popupBg).
Padding(1, 2).
Width(popupWidth).
MarginLeft(0)
MarginLeft(0).
MarginBottom(1) // Visual depth/shadow effect
// Inner content width: popup minus border (2) and horizontal padding (4).
innerWidth := max(popupWidth-6, 10)
// Item background styles for high contrast
normalItemBg := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.Text).
Width(innerWidth).
Padding(0, 1)
selectedItemBg := lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Width(innerWidth).
Padding(0, 1).
Bold(true)
var items []string
visibleItems := min(len(s.filtered), s.popupHeight)
@@ -590,44 +607,45 @@ func (s *InputComponent) renderPopupWithOptions(centered bool) string {
match := s.filtered[i]
sc := match.Command
// Choose the appropriate background style
itemStyle := normalItemBg
if i == s.selected {
itemStyle = selectedItemBg
}
// Build indicator with proper coloring
var indicator string
if i == s.selected {
indicator = lipgloss.NewStyle().Foreground(theme.Primary).Render("> ")
indicator = "> "
} else {
indicator = " "
}
nameStyle := lipgloss.NewStyle().Foreground(theme.Secondary).Bold(true)
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
if i == s.selected {
nameStyle = nameStyle.Foreground(theme.Primary)
descStyle = descStyle.Foreground(theme.Text)
}
// Build content with name and description
var content string
if s.fileMode {
// File mode: use full width for the path, show description
// (e.g. "directory") inline after a gap.
// File mode: use full width for the path, show description inline
maxNameLen := max(innerWidth-16, 8)
displayName := sc.Name
if len(displayName) > maxNameLen && maxNameLen > 3 {
displayName = displayName[:maxNameLen-3] + "..."
}
name := nameStyle.Render(displayName)
if sc.Description != "" && innerWidth > 30 {
items = append(items, indicator+name+" "+descStyle.Render(sc.Description))
content = indicator + displayName + " " + sc.Description
} else {
items = append(items, indicator+name)
content = indicator + displayName
}
} else {
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc.
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc
if innerWidth < 20 {
// Very narrow: show truncated name only, no fixed column.
// Very narrow: show truncated name only
displayName := sc.Name
maxName := max(innerWidth-2, 3)
if len(displayName) > maxName {
displayName = displayName[:maxName-1] + "…"
}
items = append(items, indicator+nameStyle.Render(displayName))
content = indicator + displayName
} else {
nameWidth := 15
if innerWidth < 25 {
@@ -638,33 +656,41 @@ func (s *InputComponent) renderPopupWithOptions(centered bool) string {
if len(displayName) > maxNameChars {
displayName = displayName[:maxNameChars-1] + "…"
}
name := nameStyle.Width(maxNameChars).Render(displayName)
// Description gets remaining space.
// Description gets remaining space
maxDescLen := max(innerWidth-nameWidth, 0)
desc := sc.Description
if maxDescLen < 4 {
items = append(items, indicator+name)
} else {
if maxDescLen >= 4 && desc != "" {
if len(desc) > maxDescLen {
desc = desc[:maxDescLen-3] + "..."
}
items = append(items, indicator+name+descStyle.Render(desc))
content = indicator + lipgloss.NewStyle().Width(maxNameChars).Render(displayName) + desc
} else {
content = indicator + displayName
}
}
}
items = append(items, itemStyle.Render(content))
}
// Add scroll indicators with background
scrollStyle := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Width(innerWidth).
Padding(0, 1)
if startIdx > 0 {
items = append([]string{lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↑ more above")}, items...)
items = append([]string{scrollStyle.Render(" ↑ more above")}, items...)
}
if endIdx < len(s.filtered) {
items = append(items, lipgloss.NewStyle().Foreground(theme.VeryMuted).Render(" ↓ more below"))
items = append(items, scrollStyle.Render(" ↓ more below"))
}
content := strings.Join(items, "\n")
// Adapt footer text to available width.
// Adapt footer text to available width with background
var footerText string
if innerWidth >= 50 {
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
@@ -673,7 +699,10 @@ func (s *InputComponent) renderPopupWithOptions(centered bool) string {
} else {
footerText = "↑↓ tab ↵ esc"
}
footer := lipgloss.NewStyle().Foreground(theme.VeryMuted).Italic(true).
footer := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Italic(true).
Render(footerText)
return popupStyle.Render(content + "\n\n" + footer)
@@ -703,10 +732,10 @@ func (s *InputComponent) completeArgs(line string) []FuzzyMatch {
s.argMode = true
s.argCommand = cmdName
s.argSynthCmds = make([]SlashCommand, len(suggestions))
s.argSynthCmds = make([]commands.SlashCommand, len(suggestions))
s.filtered = make([]FuzzyMatch, len(suggestions))
for i, sug := range suggestions {
s.argSynthCmds[i] = SlashCommand{Name: sug}
s.argSynthCmds[i] = commands.SlashCommand{Name: sug}
s.filtered[i] = FuzzyMatch{Command: &s.argSynthCmds[i]}
}
return s.filtered
@@ -714,7 +743,7 @@ func (s *InputComponent) completeArgs(line string) []FuzzyMatch {
// findCommandWithComplete looks up a command by name that has a non-nil
// Complete function.
func (s *InputComponent) findCommandWithComplete(name string) *SlashCommand {
func (s *InputComponent) findCommandWithComplete(name string) *commands.SlashCommand {
for i := range s.commands {
if s.commands[i].Name == name && s.commands[i].Complete != nil {
return &s.commands[i]
@@ -732,7 +761,7 @@ func readClipboardImageCmd() tea.Cmd {
return clipboardImageMsg{err: err}
}
return clipboardImageMsg{
image: &ImageAttachment{
image: &core.ImageAttachment{
Data: img.Data,
MediaType: img.MediaType,
},
@@ -742,7 +771,7 @@ func readClipboardImageCmd() tea.Cmd {
// ClearPendingImages removes all pending image attachments and returns them.
// Used by the parent model when consuming images for submission.
func (s *InputComponent) ClearPendingImages() []ImageAttachment {
func (s *InputComponent) ClearPendingImages() []core.ImageAttachment {
images := s.pendingImages
s.pendingImages = nil
return images
+20 -37
View File
@@ -6,6 +6,9 @@ import (
"time"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/render"
"github.com/mark3labs/kit/internal/ui/style"
)
// --------------------------------------------------------------------------
@@ -143,47 +146,20 @@ func (s *StreamingMessageItem) Render(width int) string {
return s.cachedRender
}
// Get renderer from context
renderer := newMessageRenderer(width, false)
var rendered string
if s.role == "reasoning" {
// Render as reasoning/thinking block with live duration counter
theme := GetTheme()
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
ty := createTypography(theme)
content := strings.TrimLeft(s.content, " \t\n")
var parts []string
parts = append(parts, mutedStyle.Render(ty.Italic(content)))
// Add live duration counter (updates on each render)
var duration time.Duration
// Calculate duration in milliseconds for render.ReasoningBlock
var durationMs int64
if s.finalDuration > 0 {
// Streaming complete, show frozen duration
duration = s.finalDuration
durationMs = s.finalDuration.Milliseconds()
} else if !s.startTime.IsZero() {
// Still streaming, show live duration
duration = time.Since(s.startTime)
durationMs = time.Since(s.startTime).Milliseconds()
}
if duration > 0 {
var durationStr string
if duration < time.Second {
durationStr = fmt.Sprintf("%dms", duration.Milliseconds())
} else {
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
}
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
parts = append(parts, label+durationStyled)
}
rendered = styleMarginBottom1.Render(strings.Join(parts, "\n"))
ty := createTypography(style.GetTheme())
rendered = render.ReasoningBlock(s.content, durationMs, ty, style.GetTheme())
} else {
// Render as assistant message
msg := renderer.RenderAssistantMessage(s.content, s.timestamp, s.modelName)
rendered = msg.Content
rendered = render.AssistantBlock(s.content, width, style.GetTheme())
}
// Cache and return (but reasoning is never cached due to live duration)
@@ -196,10 +172,17 @@ func (s *StreamingMessageItem) Render(width int) string {
// Height returns the number of lines.
func (s *StreamingMessageItem) Height() int {
if s.cachedRender == "" {
// For reasoning blocks, cachedRender is never populated (rendering is
// width-independent and includes a live timer). Fall back to Render(0)
// so callers always get the correct height.
rendered := s.cachedRender
if rendered == "" {
rendered = s.Render(0)
}
if rendered == "" {
return 0
}
return strings.Count(s.cachedRender, "\n") + 1
return strings.Count(rendered, "\n") + 1
}
// AppendChunk adds a content chunk and invalidates the render cache.
@@ -255,7 +238,7 @@ func (m *StreamingBashOutputItem) Render(width int) string {
return m.cachedRender
}
theme := GetTheme()
theme := style.GetTheme()
var parts []string
// Header with command
+22 -51
View File
@@ -9,6 +9,9 @@ import (
"charm.land/lipgloss/v2"
"github.com/indaco/herald"
"github.com/mark3labs/kit/internal/ui/render"
"github.com/mark3labs/kit/internal/ui/style"
)
// MessageType represents different categories of messages displayed in the UI,
@@ -138,7 +141,7 @@ func newMessageRenderer(width int, debug bool) *MessageRenderer {
return &MessageRenderer{
width: width,
debug: debug,
ty: createTypography(GetTheme()),
ty: createTypography(style.GetTheme()),
}
}
@@ -149,12 +152,7 @@ func (r *MessageRenderer) SetWidth(width int) {
// RenderUserMessage renders a user's input message using herald Tip alert
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
rendered := r.ty.Tip(content)
rendered = styleMarginBottom1.Render(rendered)
rendered := render.UserBlock(content, r.ty, style.GetTheme())
return UIMessage{
Type: UserMessage,
@@ -166,18 +164,7 @@ func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time)
// RenderAssistantMessage renders an AI assistant's response
func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.Time, modelName string) UIMessage {
if strings.TrimSpace(content) == "" {
return UIMessage{
Type: AssistantMessage,
Content: "",
Height: 0,
Timestamp: timestamp,
}
}
// Use markdown rendering with Chroma syntax highlighting
rendered := toMarkdown(content, r.width-4)
rendered = styleMarginBottom1.Render(rendered)
rendered := render.AssistantBlock(content, r.width, style.GetTheme())
return UIMessage{
Type: AssistantMessage,
@@ -191,23 +178,7 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
// as live streaming: muted italic text with margin. This is used when resuming
// sessions to display saved reasoning content.
func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Time) UIMessage {
if strings.TrimSpace(content) == "" {
return UIMessage{
Type: AssistantMessage,
Content: "",
Height: 0,
Timestamp: timestamp,
}
}
theme := GetTheme()
// Match live streaming styling: muted italic text
// Same as stream.go renderReasoningBlock()
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
contentStr := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
rendered := mutedStyle.Render(r.ty.Italic(contentStr))
rendered = styleMarginBottom1.Render(rendered)
rendered := render.ReasoningBlock(content, 0, r.ty, style.GetTheme())
return UIMessage{
Type: AssistantMessage,
@@ -219,12 +190,7 @@ func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Ti
// RenderSystemMessage renders KIT system messages using herald Note alert
func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Time) UIMessage {
if strings.TrimSpace(content) == "" {
content = "No content available"
}
rendered := r.ty.Note(content)
rendered = styleMarginBottom1.Render(rendered)
rendered := render.SystemBlock(content, r.ty, style.GetTheme())
return UIMessage{
Type: SystemMessage,
@@ -290,8 +256,7 @@ func (r *MessageRenderer) RenderDebugConfigMessage(config map[string]any, timest
// RenderErrorMessage renders error notifications
func (r *MessageRenderer) RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage {
rendered := r.ty.Caution(errorMsg)
rendered = styleMarginBottom1.Render(rendered)
rendered := render.ErrorBlock(errorMsg, r.ty, style.GetTheme())
return UIMessage{
Type: ErrorMessage,
@@ -323,16 +288,16 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
}
var icon string
iconColor := GetTheme().Success
iconColor := style.GetTheme().Success
if isError {
icon = "×"
iconColor = GetTheme().Error
iconColor = style.GetTheme().Error
} else {
icon = "✓"
}
// Style the tool name with color
theme := GetTheme()
theme := style.GetTheme()
nameColor := theme.Info
if isError {
nameColor = theme.Error
@@ -351,7 +316,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
if extRd != nil && extRd.RenderBody != nil {
body = extRd.RenderBody(toolResult, isError, r.width-8)
if body != "" && extRd.BodyMarkdown {
body = strings.TrimSuffix(toMarkdown(body, r.width-8), "\n")
body = strings.TrimSuffix(style.ToMarkdown(body, r.width-8), "\n")
}
}
if body == "" {
@@ -369,6 +334,12 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
body = r.ty.Italic("(no output)")
}
// Wrap all tool errors in a herald Caution alert so the error text
// renders inside a contained block instead of spilling into the layout.
if isError && strings.TrimSpace(body) != "" {
body = r.ty.Alert(herald.AlertCaution, body)
}
// Compose: icon + name + params, then body
fullContent := r.ty.Compose(
headerLine,
@@ -397,7 +368,7 @@ func (r *MessageRenderer) formatToolResult(toolName, result string) string {
if strings.Contains(toolName, "bash") || strings.Contains(toolName, "command") ||
strings.Contains(toolName, "shell") {
if strings.Contains(result, "<stdout>") || strings.Contains(result, "<stderr>") {
return parseBashOutput(result, GetTheme())
return parseBashOutput(result, style.GetTheme())
}
}
@@ -405,7 +376,7 @@ func (r *MessageRenderer) formatToolResult(toolName, result string) string {
}
// createTypography creates a typography instance from theme
func createTypography(theme Theme) *herald.Typography {
func createTypography(theme style.Theme) *herald.Typography {
return herald.New(
herald.WithPalette(herald.ColorPalette{
Primary: theme.Primary,
@@ -437,5 +408,5 @@ func createTypography(theme Theme) *herald.Typography {
// UpdateTheme refreshes the renderer's typography instance with colors from
// the current theme. This is called when the user changes themes via /theme.
func (r *MessageRenderer) UpdateTheme() {
r.ty = createTypography(GetTheme())
r.ty = createTypography(style.GetTheme())
}
+598 -339
View File
File diff suppressed because it is too large Load Diff
+98 -291
View File
@@ -5,9 +5,7 @@ import (
"sort"
"strings"
"charm.land/bubbles/v2/key"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/models"
)
@@ -29,16 +27,14 @@ type ModelSelectedMsg struct {
// ModelSelectorCancelledMsg is sent when the user cancels the selector.
type ModelSelectorCancelledMsg struct{}
// ModelSelectorComponent is a full-screen Bubble Tea component that displays
// a filterable list of available models. It follows the same pattern as
// TreeSelectorComponent: inline text search, scrolling list, and custom
// messages for result delivery.
// ModelSelectorComponent is a Bubble Tea component that displays a filterable
// list of available models as a centered overlay popup. It delegates rendering
// and keyboard navigation to PopupList and converts results into the
// ModelSelectedMsg / ModelSelectorCancelledMsg messages expected by AppModel.
type ModelSelectorComponent struct {
allModels []ModelEntry // all available models (pre-sorted)
filtered []ModelEntry // subset matching the current search
cursor int
search string
currentModel string // "provider/model" of the active model (for checkmark)
popup *PopupList
allModels []ModelEntry // kept for the custom filter callback
currentModel string // "provider/model" of the active model
width int
height int
active bool
@@ -61,7 +57,22 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
continue
}
// For the custom provider, skip the built-in "custom" stub when
// user-defined models are present — the stub is a fallback for
// --provider-url usage and would just clutter the list.
userDefinedCustomModels := 0
if providerID == "custom" {
for modelID := range modelsMap {
if modelID != "custom" {
userDefinedCustomModels++
}
}
}
for modelID, info := range modelsMap {
if providerID == "custom" && modelID == "custom" && userDefinedCustomModels > 0 {
continue
}
allModels = append(allModels, ModelEntry{
Provider: providerID,
ModelID: modelID,
@@ -80,24 +91,31 @@ func NewModelSelector(currentModel string, width, height int) *ModelSelectorComp
return allModels[i].ModelID < allModels[j].ModelID
})
ms := &ModelSelectorComponent{
// Build PopupItems from model entries.
items := make([]PopupItem, len(allModels))
for i, m := range allModels {
items[i] = PopupItem{
Label: m.ModelID,
Description: fmt.Sprintf("[%s]", m.Provider),
Active: m.Provider+"/"+m.ModelID == currentModel,
Meta: m,
}
}
popup := NewPopupList("Model Selector", items, width, height)
popup.Subtitle = "Only showing models with configured API keys"
popup.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
return filterModels(query, allItems)
}
return &ModelSelectorComponent{
popup: popup,
allModels: allModels,
filtered: allModels,
currentModel: currentModel,
width: width,
height: height,
active: true,
}
// Position cursor on the current model if found.
for i, m := range ms.filtered {
if m.Provider+"/"+m.ModelID == currentModel {
ms.cursor = i
break
}
}
return ms
}
// Init implements tea.Model.
@@ -111,241 +129,94 @@ func (ms *ModelSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.WindowSizeMsg:
ms.width = msg.Width
ms.height = msg.Height
ms.popup.SetSize(msg.Width, msg.Height)
return ms, nil
case tea.KeyPressMsg:
switch {
case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))):
if ms.cursor > 0 {
ms.cursor--
}
result := ms.popup.HandleKey(msg.String(), msg.Text)
case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))):
if ms.cursor < len(ms.filtered)-1 {
ms.cursor++
if result.Selected != nil {
ms.active = false
entry := result.Selected.Meta.(ModelEntry)
modelStr := entry.Provider + "/" + entry.ModelID
return ms, func() tea.Msg {
return ModelSelectedMsg{ModelString: modelStr}
}
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
ms.cursor -= ms.visibleHeight()
if ms.cursor < 0 {
ms.cursor = 0
}
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
ms.cursor += ms.visibleHeight()
if ms.cursor >= len(ms.filtered) {
ms.cursor = len(ms.filtered) - 1
}
if ms.cursor < 0 {
ms.cursor = 0
}
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
ms.cursor = 0
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
ms.cursor = max(len(ms.filtered)-1, 0)
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
if ms.cursor < len(ms.filtered) {
entry := ms.filtered[ms.cursor]
ms.active = false
return ms, func() tea.Msg {
return ModelSelectedMsg{
ModelString: entry.Provider + "/" + entry.ModelID,
}
}
}
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
if ms.search != "" {
ms.search = ""
ms.rebuildFiltered()
} else {
ms.active = false
return ms, func() tea.Msg {
return ModelSelectorCancelledMsg{}
}
}
default:
// Inline text search.
if msg.Text != "" && len(msg.Text) == 1 {
ch := msg.Text[0]
if ch >= 32 && ch < 127 {
ms.search += string(ch)
ms.rebuildFiltered()
}
}
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ms.search) > 0 {
ms.search = ms.search[:len(ms.search)-1]
ms.rebuildFiltered()
}
if result.Cancelled {
ms.active = false
return ms, func() tea.Msg {
return ModelSelectorCancelledMsg{}
}
}
}
return ms, nil
}
// View implements tea.Model.
// View implements tea.Model — not used for overlay rendering.
// Use RenderOverlay for the centered overlay approach.
func (ms *ModelSelectorComponent) View() tea.View {
theme := GetTheme()
headerStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
PaddingLeft(2)
helpStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
PaddingLeft(2)
infoStyle := lipgloss.NewStyle().
Foreground(theme.Warning).
PaddingLeft(2)
var b strings.Builder
// Header.
b.WriteString(headerStyle.Render("Model Selector"))
b.WriteString("\n")
// Adapt help text to terminal width.
if ms.width >= 56 {
b.WriteString(helpStyle.Render("↑/↓: move enter: select esc: cancel type to filter"))
} else if ms.width >= 35 {
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc type"))
} else {
b.WriteString(helpStyle.Render("↑↓ ↵ esc"))
}
b.WriteString("\n")
if ms.width >= 48 {
b.WriteString(infoStyle.Render("Only showing models with configured API keys"))
} else {
b.WriteString(infoStyle.Render("Models with API keys"))
}
b.WriteString("\n")
// Search input.
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(2)
if ms.search != "" {
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ms.search)))
} else {
b.WriteString(searchStyle.Render("> "))
}
b.WriteString("\n")
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
b.WriteString("\n")
if len(ms.filtered) == 0 {
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
if ms.search != "" {
b.WriteString(emptyStyle.Render("No models matching \"" + ms.search + "\""))
} else {
b.WriteString(emptyStyle.Render("No models available (check API keys)"))
}
b.WriteString("\n")
} else {
// Visible window.
visH := ms.visibleHeight()
startIdx := 0
if ms.cursor >= visH {
startIdx = ms.cursor - visH + 1
}
endIdx := min(startIdx+visH, len(ms.filtered))
for i := startIdx; i < endIdx; i++ {
entry := ms.filtered[i]
line := ms.renderEntry(entry, i == ms.cursor)
b.WriteString(line)
b.WriteString("\n")
}
}
// Footer.
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ms.width)))
b.WriteString("\n")
footerParts := []string{
fmt.Sprintf("(%d/%d)", ms.cursor+1, len(ms.filtered)),
}
if ms.cursor < len(ms.filtered) {
entry := ms.filtered[ms.cursor]
if entry.Name != "" {
footerParts = append(footerParts, fmt.Sprintf("Model Name: %s", entry.Name))
}
if entry.ContextLimit > 0 {
footerParts = append(footerParts, fmt.Sprintf("Context: %dK", entry.ContextLimit/1000))
}
}
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
b.WriteString(footerStyle.Render(strings.Join(footerParts, " ")))
v := tea.NewView(b.String())
// Fallback full-screen rendering (unused when rendered as overlay).
v := tea.NewView(ms.popup.RenderCentered(ms.width, ms.height))
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
return v
}
// RenderOverlay returns the popup as a centered overlay string, ready to be
// composited on top of the main content via overlayContent().
func (ms *ModelSelectorComponent) RenderOverlay(termWidth, termHeight int) string {
return ms.popup.RenderCentered(termWidth, termHeight)
}
// IsActive returns whether the selector is still accepting input.
func (ms *ModelSelectorComponent) IsActive() bool {
return ms.active
}
// --- Internal helpers ---
// --- Model-specific fuzzy filter ---
func (ms *ModelSelectorComponent) visibleHeight() int {
// Reserve: header(1) + help(1) + info(1) + search(1) + separator(1) + footer(2) = 7.
// Minimum 3 entries so the selector is still usable on short terminals.
return max(ms.height-7, 3)
}
// filterModels scores and filters PopupItems whose Meta is a ModelEntry.
func filterModels(query string, items []PopupItem) []PopupItem {
if query == "" {
return items
}
q := strings.ToLower(query)
func (ms *ModelSelectorComponent) rebuildFiltered() {
if ms.search == "" {
ms.filtered = ms.allModels
} else {
query := strings.ToLower(ms.search)
ms.filtered = ms.filtered[:0]
type scored struct {
item PopupItem
score int
}
var matches []scored
type scored struct {
entry ModelEntry
score int
for _, item := range items {
entry, ok := item.Meta.(ModelEntry)
if !ok {
continue
}
var matches []scored
for _, entry := range ms.allModels {
s := ms.fuzzyScoreModel(query, entry)
if s > 0 {
matches = append(matches, scored{entry: entry, score: s})
}
}
// Sort by score descending, then alphabetically.
sort.Slice(matches, func(i, j int) bool {
if matches[i].score != matches[j].score {
return matches[i].score > matches[j].score
}
return matches[i].entry.ModelID < matches[j].entry.ModelID
})
ms.filtered = make([]ModelEntry, len(matches))
for i, m := range matches {
ms.filtered[i] = m.entry
s := fuzzyScoreModelEntry(q, entry)
if s > 0 {
matches = append(matches, scored{item: item, score: s})
}
}
// Clamp cursor.
if ms.cursor >= len(ms.filtered) {
ms.cursor = max(len(ms.filtered)-1, 0)
sort.Slice(matches, func(i, j int) bool {
if matches[i].score != matches[j].score {
return matches[i].score > matches[j].score
}
a := matches[i].item.Meta.(ModelEntry)
b := matches[j].item.Meta.(ModelEntry)
return a.ModelID < b.ModelID
})
result := make([]PopupItem, len(matches))
for i, m := range matches {
result[i] = m.item
}
return result
}
// fuzzyScoreModel scores a model entry against the search query.
func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry) int {
// fuzzyScoreModelEntry scores a model entry against the search query.
func fuzzyScoreModelEntry(query string, entry ModelEntry) int {
modelID := strings.ToLower(entry.ModelID)
provider := strings.ToLower(entry.Provider)
name := strings.ToLower(entry.Name)
@@ -398,67 +269,3 @@ func (ms *ModelSelectorComponent) fuzzyScoreModel(query string, entry ModelEntry
return 0
}
func (ms *ModelSelectorComponent) renderEntry(entry ModelEntry, isCursor bool) string {
theme := GetTheme()
modelStr := entry.ModelID
providerStr := fmt.Sprintf("[%s]", entry.Provider)
// Cursor indicator.
var cursor string
if isCursor {
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("-> ")
} else {
cursor = " "
}
// Active model checkmark.
var active string
activeWidth := 0
if entry.Provider+"/"+entry.ModelID == ms.currentModel {
active = lipgloss.NewStyle().Foreground(theme.Success).Render(" \u2713")
activeWidth = 2 // " ✓"
}
// Truncate model ID and provider tag to fit terminal width.
// Layout: cursor(3) + model + " " + provider + active.
// Use rune length for display-width accuracy (the "…" suffix is 1 rune / 1 column).
const cursorWidth = 3
available := max(ms.width-cursorWidth-activeWidth-1, 10) // 1 for space between model and provider
provDisplayLen := len([]rune(providerStr))
modelDisplayLen := len([]rune(modelStr))
if modelDisplayLen+1+provDisplayLen > available {
// Prioritize model name — truncate it, but keep provider visible.
maxModel := max(available-provDisplayLen-1, 6)
if maxModel < modelDisplayLen {
if maxModel > 3 {
runes := []rune(modelStr)
modelStr = string(runes[:maxModel-1]) + "…"
} else {
runes := []rune(modelStr)
modelStr = string(runes[:maxModel])
}
}
// If provider itself is too long, drop it.
modelDisplayLen = len([]rune(modelStr))
if modelDisplayLen+1+provDisplayLen > available {
providerStr = ""
}
}
// Style the model ID.
modelStyle := lipgloss.NewStyle().Foreground(theme.Text)
if isCursor {
modelStyle = modelStyle.Bold(true).Foreground(theme.Accent)
}
// Style the provider tag.
providerStyle := lipgloss.NewStyle().Foreground(theme.Muted)
result := cursor + modelStyle.Render(modelStr)
if providerStr != "" {
result += " " + providerStyle.Render(providerStr)
}
return result + active
}
+104 -49
View File
@@ -7,6 +7,7 @@ import (
tea "charm.land/bubbletea/v2"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/ui/core"
kit "github.com/mark3labs/kit/pkg/kit"
)
@@ -80,6 +81,11 @@ func (s *stubAppController) Steer(prompt string) int {
return s.queueLen
}
func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) int {
s.runCalls = append(s.runCalls, prompt)
return s.queueLen
}
// --------------------------------------------------------------------------
// Stub child components
// --------------------------------------------------------------------------
@@ -87,7 +93,6 @@ func (s *stubAppController) Steer(prompt string) int {
// stubStreamComponent satisfies streamComponentIface without rendering anything.
type stubStreamComponent struct {
resetCalled int
height int
lastMsg tea.Msg
renderedContent string // returned by GetRenderedContent
}
@@ -99,9 +104,7 @@ func (s *stubStreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (s *stubStreamComponent) View() tea.View { return tea.NewView("") }
func (s *stubStreamComponent) Reset() { s.resetCalled++; s.renderedContent = "" }
func (s *stubStreamComponent) SetHeight(h int) { s.height = h }
func (s *stubStreamComponent) GetRenderedContent() string { return s.renderedContent }
func (s *stubStreamComponent) ConsumeOverflow() string { return "" }
func (s *stubStreamComponent) SpinnerView() string { return "" }
func (s *stubStreamComponent) SetThinkingVisible(bool) {}
func (s *stubStreamComponent) HasReasoning() bool { return false }
@@ -136,6 +139,8 @@ func newTestAppModel(ctrl AppController) (*AppModel, *stubStreamComponent, *stub
width: 80,
height: 24,
streamingBashMaxLines: 50, // Initialize buffer cap like NewAppModel does
scrollList: NewScrollList(80, 20),
messages: []MessageItem{},
}
return m, stream, input
}
@@ -168,7 +173,7 @@ func TestStateTransition_InputToWorking(t *testing.T) {
t.Fatalf("expected stateInput, got %v", m.state)
}
m = sendMsg(m, submitMsg{Text: "hello"})
m = sendMsg(m, core.SubmitMsg{Text: "hello"})
if m.state != stateWorking {
t.Fatalf("expected stateWorking after submitMsg, got %v", m.state)
@@ -356,7 +361,7 @@ func TestESCCancel_timerExpiry(t *testing.T) {
m.state = stateWorking
m.canceling = true
m = sendMsg(m, cancelTimerExpiredMsg{})
m = sendMsg(m, core.CancelTimerExpiredMsg{})
if m.canceling {
t.Fatal("expected canceling=false after timer expiry")
@@ -409,7 +414,7 @@ func TestQueuedMessages_storedOnQueuedSubmit(t *testing.T) {
m, _, _ := newTestAppModel(ctrl)
m.state = stateWorking
_, cmd := m.Update(submitMsg{Text: "queued prompt"})
_, cmd := m.Update(core.SubmitMsg{Text: "queued prompt"})
if len(m.queuedMessages) != 1 {
t.Fatalf("expected 1 queued message, got %d", len(m.queuedMessages))
@@ -417,7 +422,7 @@ func TestQueuedMessages_storedOnQueuedSubmit(t *testing.T) {
if m.queuedMessages[0] != "queued prompt" {
t.Fatalf("expected queued message text 'queued prompt', got %q", m.queuedMessages[0])
}
// Should NOT produce a tea.Println cmd (message is anchored, not in scrollback).
// Should NOT flush (message is anchored in ScrollList).
if cmd != nil {
t.Fatal("expected nil cmd for queued submit (message should not print to scrollback)")
}
@@ -507,19 +512,19 @@ func TestWindowResize_propagatesToStream(t *testing.T) {
// sets the stream height after a resize.
func TestWindowResize_distributeHeight(t *testing.T) {
ctrl := &stubAppController{}
m, stream, _ := newTestAppModel(ctrl)
m, _, _ := newTestAppModel(ctrl)
// With height=30, stream height = 30 - 1 (separator) - 9 (input) - 1 (statusBar) = 19
// With height=30, scroll height = 30 - 1 (separator) - 9 (input) - 1 (statusBar) = 19
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 30})
_ = m
if stream.height != 19 {
t.Fatalf("expected stream height=19, got %d", stream.height)
if m.scrollList.height != 19 {
t.Fatalf("expected scroll list height=19, got %d", m.scrollList.height)
}
}
// --------------------------------------------------------------------------
// tea.Println on step complete
// Step complete behavior
// --------------------------------------------------------------------------
// TestStepComplete_preservesStreamContent verifies that StepCompleteEvent
@@ -552,65 +557,87 @@ func TestStepComplete_noStreamContent_noCmd(t *testing.T) {
}
}
// TestSubmitMsg_printsUserMessage verifies that submitMsg produces a tea.Println
// cmd for the user message.
// TestSubmitMsg_printsUserMessage verifies that submitMsg adds the user message
// to the ScrollList messages and triggers a layout update.
func TestSubmitMsg_printsUserMessage(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
_, cmd := m.Update(submitMsg{Text: "user query"})
m = sendMsg(m, core.SubmitMsg{Text: "user query"})
if cmd == nil {
t.Fatal("expected non-nil cmd (tea.Println) for user message on submitMsg")
// In alt screen mode, user messages are added to the in-memory ScrollList
// rather than printed separately. Verify the message was added.
found := false
for _, msg := range m.messages {
if tm, ok := msg.(*TextMessageItem); ok && tm.role == "user" && tm.content == "user query" {
found = true
break
}
}
if !found {
t.Fatal("expected user message 'user query' in ScrollList messages")
}
}
// TestToolCallStarted_flushesOnly verifies that ToolCallStartedEvent flushes
// accumulated stream content but does NOT print a tool call block (the unified
// block is printed later on ToolResultEvent).
// TestToolCallStarted_flushesOnly verifies that ToolCallStartedEvent marks
// any active StreamingMessageItem as complete and resets the stream.
func TestToolCallStarted_flushesOnly(t *testing.T) {
ctrl := &stubAppController{}
m, stream, _ := newTestAppModel(ctrl)
m.state = stateWorking
// With no stream content, flush returns nil → cmd should be nil.
_, cmd := m.Update(app.ToolCallStartedEvent{
// With no stream content, nothing should change.
initialCount := len(m.messages)
m = sendMsg(m, app.ToolCallStartedEvent{
ToolName: "bash",
ToolArgs: `{"cmd":"ls"}`,
})
if cmd != nil {
t.Fatal("expected nil cmd on ToolCallStartedEvent with no stream content")
if len(m.messages) != initialCount {
t.Fatal("expected no new messages on ToolCallStartedEvent with no stream content")
}
// With stream content, flush returns tea.Println → cmd should be non-nil.
// Simulate a StreamingMessageItem already in messages (as if appendStreamingChunk was called)
// plus the stream component having rendered content.
streamItem := NewStreamingMessageItem("stream-1", "assistant", "test-model")
streamItem.AppendChunk("partial text")
m.messages = append(m.messages, streamItem)
stream.renderedContent = "partial text"
_, cmd = m.Update(app.ToolCallStartedEvent{
_ = sendMsg(m, app.ToolCallStartedEvent{
ToolName: "bash",
ToolArgs: `{"cmd":"ls"}`,
})
if cmd == nil {
t.Fatal("expected non-nil cmd on ToolCallStartedEvent with stream content to flush")
// The StreamingMessageItem should have been marked complete.
if streamItem.streaming {
t.Fatal("expected StreamingMessageItem to be marked complete after ToolCallStartedEvent")
}
// Stream should have been reset.
if stream.resetCalled == 0 {
t.Fatal("expected stream.Reset() to be called")
}
}
// TestToolResult_printsAndStartsSpinner verifies that ToolResultEvent produces
// a non-nil cmd and the stream receives a SpinnerEvent.
// TestToolResult_printsAndStartsSpinner verifies that ToolResultEvent adds
// the tool result to the ScrollList and the stream receives a SpinnerEvent.
func TestToolResult_printsAndStartsSpinner(t *testing.T) {
ctrl := &stubAppController{}
m, stream, _ := newTestAppModel(ctrl)
m.state = stateWorking
_, cmd := m.Update(app.ToolResultEvent{
initialCount := len(m.messages)
m = sendMsg(m, app.ToolResultEvent{
ToolName: "bash",
ToolArgs: "{}",
Result: "output",
IsError: false,
})
if cmd == nil {
t.Fatal("expected non-nil cmd on ToolResultEvent")
// Tool result should have been added to ScrollList messages.
if len(m.messages) <= initialCount {
t.Fatal("expected tool result message added to ScrollList")
}
// Stream should have received a SpinnerEvent to start spinner for next LLM call.
if stream.lastMsg == nil {
@@ -622,7 +649,7 @@ func TestToolResult_printsAndStartsSpinner(t *testing.T) {
}
// TestToolOutputEvent_accumulatesBashOutput verifies that ToolOutputEvent
// accumulates stdout and stderr lines into the streaming bash output buffers.
// accumulates stdout and stderr lines into a StreamingBashOutputItem in the ScrollList.
func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
@@ -636,11 +663,23 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
IsStderr: false,
})
if len(m.streamingBashOutput) != 1 || m.streamingBashOutput[0] != "line one\n" {
t.Fatalf("expected streamingBashOutput=['line one\\n'], got %v", m.streamingBashOutput)
// Should have created a StreamingBashOutputItem in messages.
var bashItem *StreamingBashOutputItem
for _, msg := range m.messages {
if item, ok := msg.(*StreamingBashOutputItem); ok {
bashItem = item
break
}
}
if len(m.streamingBashStderr) != 0 {
t.Fatalf("expected empty streamingBashStderr, got %v", m.streamingBashStderr)
if bashItem == nil {
t.Fatal("expected StreamingBashOutputItem in messages after ToolOutputEvent")
return
}
if len(bashItem.stdoutLines) != 1 || bashItem.stdoutLines[0] != "line one\n" {
t.Fatalf("expected stdout=['line one\\n'], got %v", bashItem.stdoutLines)
}
if len(bashItem.stderrLines) != 0 {
t.Fatalf("expected empty stderr, got %v", bashItem.stderrLines)
}
// Send another stdout chunk.
@@ -651,8 +690,15 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
IsStderr: false,
})
if len(m.streamingBashOutput) != 2 {
t.Fatalf("expected 2 stdout lines, got %d", len(m.streamingBashOutput))
// Re-find the bash item (same item, updated)
bashItem = nil
for _, msg := range m.messages {
if item, ok := msg.(*StreamingBashOutputItem); ok {
bashItem = item
}
}
if bashItem == nil || len(bashItem.stdoutLines) != 2 {
t.Fatalf("expected 2 stdout lines, got %d", len(bashItem.stdoutLines))
}
// Send stderr chunk.
@@ -663,11 +709,17 @@ func TestToolOutputEvent_accumulatesBashOutput(t *testing.T) {
IsStderr: true,
})
if len(m.streamingBashStderr) != 1 {
t.Fatalf("expected 1 stderr line, got %d", len(m.streamingBashStderr))
bashItem = nil
for _, msg := range m.messages {
if item, ok := msg.(*StreamingBashOutputItem); ok {
bashItem = item
}
}
if m.streamingBashStderr[0] != "error: something failed\n" {
t.Fatalf("expected stderr 'error: something failed\\n', got %q", m.streamingBashStderr[0])
if bashItem == nil || len(bashItem.stderrLines) != 1 {
t.Fatalf("expected 1 stderr line, got %d", len(bashItem.stderrLines))
}
if bashItem.stderrLines[0] != "error: something failed\n" {
t.Fatalf("expected stderr 'error: something failed\\n', got %q", bashItem.stderrLines[0])
}
}
@@ -749,16 +801,19 @@ func TestToolCallStarted_nonBashTool_doesNotSetCommand(t *testing.T) {
}
// TestStepError_printCmd verifies that StepErrorEvent with a non-nil error
// produces a non-nil cmd (the tea.Println call for the error message).
// adds an error message to the ScrollList.
func TestStepError_printCmd(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
m.state = stateWorking
_, cmd := m.Update(app.StepErrorEvent{Err: errors.New("agent failed")})
initialCount := len(m.messages)
if cmd == nil {
t.Fatal("expected non-nil cmd (tea.Println) on StepErrorEvent with error")
m = sendMsg(m, app.StepErrorEvent{Err: errors.New("agent failed")})
// Error should have been added to ScrollList messages.
if len(m.messages) <= initialCount {
t.Fatal("expected error message added to ScrollList on StepErrorEvent")
}
}
@@ -828,7 +883,7 @@ func TestSubmit_duringWorking_stays(t *testing.T) {
m, _, _ := newTestAppModel(ctrl)
m.state = stateWorking
m = sendMsg(m, submitMsg{Text: "queued prompt"})
m = sendMsg(m, core.SubmitMsg{Text: "queued prompt"})
if m.state != stateWorking {
t.Fatalf("expected stateWorking to persist after submitMsg during working, got %v", m.state)
+4 -2
View File
@@ -6,6 +6,8 @@ import (
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/style"
)
// ---------------------------------------------------------------------------
@@ -133,7 +135,7 @@ func (o *overlayDialog) handleKey(msg tea.KeyPressMsg) (*overlayResult, tea.Cmd)
// composition. The dialog is a bordered box centered (or anchored)
// horizontally within the terminal width.
func (o *overlayDialog) Render() string {
theme := GetTheme()
theme := style.GetTheme()
// Calculate dialog dimensions, clamped to terminal bounds.
termW := max(o.width, 10)
@@ -157,7 +159,7 @@ func (o *overlayDialog) Render() string {
// Render body text (potentially as markdown).
bodyText := o.content
if o.markdown {
bodyText = toMarkdown(bodyText, innerWidth)
bodyText = style.ToMarkdown(bodyText, innerWidth)
}
bodyText = strings.TrimRight(bodyText, "\n")
+501
View File
@@ -0,0 +1,501 @@
package ui
import (
"fmt"
"strings"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/style"
)
// PopupItem represents a single entry in a PopupList. The component renders
// Label as the primary text and Description as secondary text to its right.
// The Active flag renders a checkmark to indicate the currently-active item
// (e.g. the current model). Meta is opaque caller data returned on selection.
type PopupItem struct {
Label string // primary display text
Description string // secondary text (shown right of label)
Active bool // true → render checkmark indicator
Meta any // opaque data returned on selection
}
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
// rendered as a centered overlay on top of the normal TUI layout and can be
// reused by any feature that needs a selection popup (slash commands, model
// selector, session picker, extension-provided lists, etc.).
//
// The caller is responsible for:
// - Building the initial item list
// - Providing a fuzzy-filter callback (or nil for substring matching)
// - Handling the result when the user selects or cancels
//
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
type PopupList struct {
// Title shown at the top of the popup.
Title string
// Subtitle shown below the title (dimmed).
Subtitle string
// FooterHint overrides the default keyboard-hint footer.
FooterHint string
allItems []PopupItem // full unfiltered list
filtered []PopupItem // subset matching the current search
cursor int
search string
// FilterFunc is called with (query, allItems) and should return the
// filtered+scored subset. When nil, a default substring match is used.
FilterFunc func(query string, items []PopupItem) []PopupItem
width int
height int
maxVisible int // max items visible at once (0 = auto from height)
showSearch bool
}
// PopupResult is returned by HandleKey to tell the caller what happened.
type PopupResult struct {
// Selected is non-nil when the user pressed Enter on an item.
Selected *PopupItem
// Cancelled is true when the user pressed Esc with no search text.
Cancelled bool
// Changed is true when the search or cursor moved (caller should re-render).
Changed bool
}
// NewPopupList creates a new popup list with the given items and dimensions.
func NewPopupList(title string, items []PopupItem, width, height int) *PopupList {
p := &PopupList{
Title: title,
allItems: items,
filtered: items,
width: width,
height: height,
showSearch: true,
}
// Position cursor on the active item if one exists.
for i, item := range p.filtered {
if item.Active {
p.cursor = i
break
}
}
return p
}
// SetSize updates the popup dimensions (e.g. on window resize).
func (p *PopupList) SetSize(width, height int) {
p.width = width
p.height = height
}
// visibleCount returns the number of items visible at once.
func (p *PopupList) visibleCount() int {
if p.maxVisible > 0 {
return p.maxVisible
}
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
overhead := 8
if p.Subtitle != "" {
overhead++
}
if p.showSearch {
overhead += 2 // search line + separator
}
return max(p.height/2-overhead, 3)
}
// HandleKey processes a single key event and returns the result. The caller
// should inspect PopupResult to decide whether to re-render, close the popup,
// or act on a selection.
//
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
// keyText is the printable text for character keys (e.g. "a", "1").
func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
switch keyName {
case "up":
if p.cursor > 0 {
p.cursor--
return PopupResult{Changed: true}
}
return PopupResult{}
case "down":
if p.cursor < len(p.filtered)-1 {
p.cursor++
return PopupResult{Changed: true}
}
return PopupResult{}
case "pgup":
p.cursor -= p.visibleCount()
if p.cursor < 0 {
p.cursor = 0
}
return PopupResult{Changed: true}
case "pgdown":
p.cursor += p.visibleCount()
if p.cursor >= len(p.filtered) {
p.cursor = max(len(p.filtered)-1, 0)
}
return PopupResult{Changed: true}
case "home":
p.cursor = 0
return PopupResult{Changed: true}
case "end":
p.cursor = max(len(p.filtered)-1, 0)
return PopupResult{Changed: true}
case "enter":
if p.cursor < len(p.filtered) {
item := p.filtered[p.cursor]
return PopupResult{Selected: &item}
}
return PopupResult{}
case "esc":
if p.search != "" {
p.search = ""
p.rebuildFiltered()
return PopupResult{Changed: true}
}
return PopupResult{Cancelled: true}
case "backspace":
if len(p.search) > 0 {
p.search = p.search[:len(p.search)-1]
p.rebuildFiltered()
return PopupResult{Changed: true}
}
return PopupResult{}
default:
// Printable character → append to search.
if keyText != "" && len(keyText) == 1 {
ch := keyText[0]
if ch >= 32 && ch < 127 {
p.search += string(ch)
p.rebuildFiltered()
return PopupResult{Changed: true}
}
}
return PopupResult{}
}
}
// Render returns the styled popup content (bordered box) ready to be placed
// as a centered overlay via lipgloss.Place + overlayContent.
func (p *PopupList) Render() string {
theme := style.GetTheme()
popupWidth := max(min(p.width-4, 80), 20)
popupBg := theme.Background
popupStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(popupBg).
Padding(1, 2).
Width(popupWidth).
MarginBottom(1)
// Inner content width: popup minus border (2) and horizontal padding (4).
innerWidth := max(popupWidth-6, 10)
var b strings.Builder
// Title.
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
Background(popupBg).
Width(innerWidth)
b.WriteString(titleStyle.Render(p.Title))
b.WriteString("\n")
// Subtitle.
if p.Subtitle != "" {
subtitleStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg).
Width(innerWidth)
b.WriteString(subtitleStyle.Render(p.Subtitle))
b.WriteString("\n")
}
// Search input.
if p.showSearch {
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(popupBg).
Width(innerWidth)
if p.search != "" {
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
} else {
b.WriteString(searchStyle.Render("> "))
}
b.WriteString("\n")
// Separator.
sepStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg)
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
b.WriteString("\n")
}
// Item list.
normalItemBg := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.Text).
Width(innerWidth).
Padding(0, 1)
selectedItemBg := lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Width(innerWidth).
Padding(0, 1).
Bold(true)
scrollStyle := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Width(innerWidth).
Padding(0, 1)
vis := p.visibleCount()
var items []string
if len(p.filtered) == 0 {
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(popupBg).
Width(innerWidth).
Padding(0, 1)
if p.search != "" {
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
} else {
items = append(items, emptyStyle.Render("No items"))
}
} else {
startIdx := 0
if p.cursor >= vis {
startIdx = p.cursor - vis + 1
}
endIdx := min(startIdx+vis, len(p.filtered))
if startIdx > 0 {
items = append(items, scrollStyle.Render(" ↑ more above"))
}
for i := startIdx; i < endIdx; i++ {
entry := p.filtered[i]
isCursor := i == p.cursor
itemStyle := normalItemBg
if isCursor {
itemStyle = selectedItemBg
}
// Build indicator.
var indicator string
if isCursor {
indicator = "> "
} else {
indicator = " "
}
// Build content: indicator + label + description + active checkmark.
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
items = append(items, itemStyle.Render(content))
}
if endIdx < len(p.filtered) {
items = append(items, scrollStyle.Render(" ↓ more below"))
}
}
content := b.String() + strings.Join(items, "\n")
// Footer with count and keyboard hints.
var footerParts []string
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
footerHint := p.FooterHint
if footerHint == "" {
if innerWidth >= 50 {
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
} else if innerWidth >= 30 {
footerHint = "↑↓ nav • ↵ select • esc"
} else {
footerHint = "↑↓ ↵ esc"
}
}
footerParts = append(footerParts, footerHint)
footer := lipgloss.NewStyle().
Background(popupBg).
Foreground(theme.VeryMuted).
Italic(true).
Render(strings.Join(footerParts, " "))
return popupStyle.Render(content + "\n\n" + footer)
}
// RenderCentered returns the popup placed at the center of a termWidth×termHeight
// canvas, ready to be composed with overlayContent().
func (p *PopupList) RenderCentered(termWidth, termHeight int) string {
popupContent := p.Render()
return lipgloss.Place(
termWidth,
termHeight,
lipgloss.Center,
lipgloss.Center,
popupContent,
)
}
// IsSearching returns true when the search input is non-empty.
func (p *PopupList) IsSearching() bool {
return p.search != ""
}
// SelectedItem returns the item under the cursor, or nil if the list is empty.
func (p *PopupList) SelectedItem() *PopupItem {
if p.cursor < len(p.filtered) {
item := p.filtered[p.cursor]
return &item
}
return nil
}
// --- Internal helpers ---
func (p *PopupList) rebuildFiltered() {
if p.FilterFunc != nil {
p.filtered = p.FilterFunc(p.search, p.allItems)
} else {
p.filtered = defaultFilter(p.search, p.allItems)
}
// Clamp cursor.
if p.cursor >= len(p.filtered) {
p.cursor = max(len(p.filtered)-1, 0)
}
}
// defaultFilter is a simple case-insensitive substring + fuzzy character match.
func defaultFilter(query string, items []PopupItem) []PopupItem {
if query == "" {
return items
}
q := strings.ToLower(query)
type scored struct {
item PopupItem
score int
}
var matches []scored
for _, item := range items {
label := strings.ToLower(item.Label)
desc := strings.ToLower(item.Description)
var s int
switch {
case label == q:
s = 1000
case strings.HasPrefix(label, q):
s = 800 - len(label) + len(q)
case strings.Contains(label, q):
s = 600
case strings.Contains(desc, q):
s = 400
default:
s = fuzzyCharacterMatch(q, label)
}
if s > 0 {
matches = append(matches, scored{item: item, score: s})
}
}
// Sort by score descending, then alphabetically by label.
for i := 0; i < len(matches)-1; i++ {
for j := i + 1; j < len(matches); j++ {
if matches[j].score > matches[i].score ||
(matches[j].score == matches[i].score && matches[j].item.Label < matches[i].item.Label) {
matches[i], matches[j] = matches[j], matches[i]
}
}
}
result := make([]PopupItem, len(matches))
for i, m := range matches {
result[i] = m.item
}
return result
}
// renderItemContent builds the display string for a single item row.
func (p *PopupList) renderItemContent(indicator string, entry PopupItem, innerWidth int, isCursor bool) string {
theme := style.GetTheme()
// Reserve space: indicator(2) + potential checkmark(2)
activeWidth := 0
if entry.Active {
activeWidth = 2
}
available := max(innerWidth-2-activeWidth, 6) // 2 for indicator, already included
label := entry.Label
desc := entry.Description
if desc != "" {
// Two-column layout: label + description.
descWidth := len([]rune(desc)) + 1 // 1 space gap
labelMax := max(available-descWidth, available*2/3)
if len([]rune(label)) > labelMax && labelMax > 3 {
runes := []rune(label)
label = string(runes[:labelMax-1]) + "…"
}
labelDisplayLen := len([]rune(label))
// If label + desc don't fit, truncate or drop desc.
if labelDisplayLen+1+len([]rune(desc)) > available {
remaining := available - labelDisplayLen - 1
if remaining >= 4 {
runes := []rune(desc)
if len(runes) > remaining {
desc = string(runes[:remaining-1]) + "…"
}
} else {
desc = ""
}
}
} else {
// Single column: just the label.
if len([]rune(label)) > available && available > 3 {
runes := []rune(label)
label = string(runes[:available-1]) + "…"
}
}
result := indicator + label
if desc != "" {
descStyle := lipgloss.NewStyle().Foreground(theme.Muted)
if isCursor {
// When selected, use a dimmer foreground that still contrasts with Primary bg.
descStyle = lipgloss.NewStyle().Foreground(theme.Background)
}
result += " " + descStyle.Render(desc)
}
if entry.Active {
checkStyle := lipgloss.NewStyle().Foreground(theme.Success)
if isCursor {
checkStyle = lipgloss.NewStyle().Foreground(theme.Background)
}
result += checkStyle.Render(" ✓")
}
return result
}
+297
View File
@@ -0,0 +1,297 @@
package ui
import (
"strings"
"testing"
)
func TestPopupList_NewPositionsCursorOnActiveItem(t *testing.T) {
items := []PopupItem{
{Label: "alpha"},
{Label: "beta"},
{Label: "gamma", Active: true},
{Label: "delta"},
}
p := NewPopupList("Test", items, 80, 40)
if p.cursor != 2 {
t.Errorf("expected cursor on active item (index 2), got %d", p.cursor)
}
}
func TestPopupList_HandleKey_Navigation(t *testing.T) {
items := []PopupItem{
{Label: "alpha"},
{Label: "beta"},
{Label: "gamma"},
}
p := NewPopupList("Test", items, 80, 40)
// Initial cursor at 0.
if p.cursor != 0 {
t.Fatalf("expected cursor 0, got %d", p.cursor)
}
// Down → 1.
res := p.HandleKey("down", "")
if !res.Changed || p.cursor != 1 {
t.Errorf("down: changed=%v cursor=%d", res.Changed, p.cursor)
}
// Down → 2.
p.HandleKey("down", "")
if p.cursor != 2 {
t.Errorf("expected cursor 2, got %d", p.cursor)
}
// Down at end → stays at 2.
res = p.HandleKey("down", "")
if p.cursor != 2 {
t.Errorf("down at end: expected cursor 2, got %d", p.cursor)
}
// Up → 1.
res = p.HandleKey("up", "")
if !res.Changed || p.cursor != 1 {
t.Errorf("up: changed=%v cursor=%d", res.Changed, p.cursor)
}
// Home → 0.
p.HandleKey("home", "")
if p.cursor != 0 {
t.Errorf("home: expected cursor 0, got %d", p.cursor)
}
// End → 2.
p.HandleKey("end", "")
if p.cursor != 2 {
t.Errorf("end: expected cursor 2, got %d", p.cursor)
}
}
func TestPopupList_HandleKey_Search(t *testing.T) {
items := []PopupItem{
{Label: "apple"},
{Label: "banana"},
{Label: "cherry"},
}
p := NewPopupList("Test", items, 80, 40)
// Type "an" → should filter to banana.
p.HandleKey("a", "a")
p.HandleKey("n", "n")
if !p.IsSearching() {
t.Error("expected IsSearching() to be true")
}
if len(p.filtered) == 0 {
t.Fatal("expected at least one filtered result")
}
// banana should match (contains "an").
found := false
for _, item := range p.filtered {
if item.Label == "banana" {
found = true
break
}
}
if !found {
t.Error("expected 'banana' in filtered results")
}
// Backspace removes last char.
p.HandleKey("backspace", "")
if p.search != "a" {
t.Errorf("expected search 'a' after backspace, got %q", p.search)
}
// Esc clears search.
res := p.HandleKey("esc", "")
if res.Cancelled {
t.Error("esc with search should clear search, not cancel")
}
if p.search != "" {
t.Errorf("expected empty search after esc, got %q", p.search)
}
}
func TestPopupList_HandleKey_SelectAndCancel(t *testing.T) {
items := []PopupItem{
{Label: "alpha", Meta: "first"},
{Label: "beta", Meta: "second"},
}
p := NewPopupList("Test", items, 80, 40)
// Select first item.
res := p.HandleKey("enter", "")
if res.Selected == nil {
t.Fatal("expected a selection on enter")
}
if res.Selected.Label != "alpha" {
t.Errorf("expected 'alpha', got %q", res.Selected.Label)
}
if res.Selected.Meta != "first" {
t.Errorf("expected meta 'first', got %v", res.Selected.Meta)
}
// Cancel with esc (no search text).
p2 := NewPopupList("Test", items, 80, 40)
res = p2.HandleKey("esc", "")
if !res.Cancelled {
t.Error("expected Cancelled on esc with no search")
}
}
func TestPopupList_DefaultFilter(t *testing.T) {
items := []PopupItem{
{Label: "foo-bar"},
{Label: "baz-qux"},
{Label: "foobar"},
}
// Exact prefix.
result := defaultFilter("foo", items)
if len(result) < 2 {
t.Fatalf("expected at least 2 matches for 'foo', got %d", len(result))
}
// "foobar" should rank higher (shorter match) or equal to "foo-bar".
if result[0].Label != "foobar" && result[1].Label != "foobar" {
t.Error("expected 'foobar' in top results")
}
// No match.
result = defaultFilter("zzz", items)
if len(result) != 0 {
t.Errorf("expected 0 matches for 'zzz', got %d", len(result))
}
}
func TestPopupList_CustomFilterFunc(t *testing.T) {
items := []PopupItem{
{Label: "alpha"},
{Label: "beta"},
{Label: "gamma"},
}
p := NewPopupList("Test", items, 80, 40)
p.FilterFunc = func(query string, allItems []PopupItem) []PopupItem {
// Custom: only return items whose label starts with query.
var result []PopupItem
for _, item := range allItems {
if strings.HasPrefix(item.Label, query) {
result = append(result, item)
}
}
return result
}
p.HandleKey("b", "b")
if len(p.filtered) != 1 || p.filtered[0].Label != "beta" {
t.Errorf("expected ['beta'], got %v", p.filtered)
}
}
func TestPopupList_Render(t *testing.T) {
items := []PopupItem{
{Label: "alpha", Description: "[test]"},
{Label: "beta", Description: "[test]", Active: true},
}
p := NewPopupList("My List", items, 80, 40)
p.Subtitle = "Some subtitle"
rendered := p.Render()
if rendered == "" {
t.Fatal("expected non-empty rendered output")
}
// Strip ANSI escape sequences for content checking.
plain := stripAnsi(rendered)
if !strings.Contains(plain, "My List") {
t.Error("expected title 'My List' in rendered output")
}
if !strings.Contains(plain, "alpha") {
t.Error("expected 'alpha' in rendered output")
}
if !strings.Contains(plain, "beta") {
t.Error("expected 'beta' in rendered output")
}
if !strings.Contains(plain, "✓") {
t.Error("expected checkmark for active item")
}
}
func TestPopupList_RenderCentered(t *testing.T) {
items := []PopupItem{
{Label: "item1"},
}
p := NewPopupList("Test", items, 80, 40)
centered := p.RenderCentered(80, 40)
if centered == "" {
t.Fatal("expected non-empty centered output")
}
// Should contain newlines for vertical centering.
lines := strings.Split(centered, "\n")
if len(lines) < 10 {
t.Errorf("expected centered output to have many lines, got %d", len(lines))
}
}
func TestPopupList_EmptyItems(t *testing.T) {
p := NewPopupList("Empty", nil, 80, 40)
rendered := p.Render()
if !strings.Contains(rendered, "No items") {
t.Error("expected 'No items' for empty list")
}
// Navigate on empty list shouldn't panic.
p.HandleKey("down", "")
p.HandleKey("up", "")
res := p.HandleKey("enter", "")
if res.Selected != nil {
t.Error("enter on empty list should not select")
}
}
func TestPopupList_SearchNoResults(t *testing.T) {
items := []PopupItem{
{Label: "alpha"},
{Label: "beta"},
}
p := NewPopupList("Test", items, 80, 40)
// Type something that doesn't match.
p.HandleKey("z", "z")
p.HandleKey("z", "z")
p.HandleKey("z", "z")
rendered := p.Render()
if !strings.Contains(rendered, "No matches") {
t.Error("expected 'No matches' message for empty search results")
}
}
func TestPopupList_CursorClamping(t *testing.T) {
items := []PopupItem{
{Label: "alpha"},
{Label: "beta"},
{Label: "gamma"},
}
p := NewPopupList("Test", items, 80, 40)
// Move to last item.
p.HandleKey("end", "")
if p.cursor != 2 {
t.Fatalf("expected cursor 2, got %d", p.cursor)
}
// Search that reduces list to 1 item → cursor should clamp.
p.HandleKey("a", "a")
p.HandleKey("l", "l")
// Only "alpha" should match.
if p.cursor >= len(p.filtered) {
t.Errorf("cursor %d should be < filtered count %d", p.cursor, len(p.filtered))
}
}
// stripAnsi is defined in usage_tracker_render_test.go
@@ -1,4 +1,4 @@
package ui
package prefs
import (
"os"
+6 -4
View File
@@ -7,6 +7,8 @@ import (
"charm.land/bubbles/v2/textarea"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/ui/style"
)
// ---------------------------------------------------------------------------
@@ -204,7 +206,7 @@ func (p *promptOverlay) updateInput(msg tea.KeyPressMsg) (*promptResult, tea.Cmd
// AppModel layout. The prompt replaces the normal input area (below the
// separator and above the status bar) rather than taking over the full screen.
func (p *promptOverlay) Render() string {
theme := GetTheme()
theme := style.GetTheme()
var content string
switch p.mode {
@@ -224,7 +226,7 @@ func (p *promptOverlay) Render() string {
)
}
func (p *promptOverlay) viewSelect(theme Theme) string {
func (p *promptOverlay) viewSelect(theme style.Theme) string {
var lines []string
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
lines = append(lines, "")
@@ -247,7 +249,7 @@ func (p *promptOverlay) viewSelect(theme Theme) string {
return strings.Join(lines, "\n")
}
func (p *promptOverlay) viewConfirm(theme Theme) string {
func (p *promptOverlay) viewConfirm(theme style.Theme) string {
var lines []string
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
lines = append(lines, "")
@@ -272,7 +274,7 @@ func (p *promptOverlay) viewConfirm(theme Theme) string {
return strings.Join(lines, "\n")
}
func (p *promptOverlay) viewInput(theme Theme) string {
func (p *promptOverlay) viewInput(theme style.Theme) string {
var lines []string
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render(p.message))
lines = append(lines, "")
+127
View File
@@ -0,0 +1,127 @@
// Package render provides pure rendering functions for message blocks.
// These functions are stateless and can be used by both streaming and
// historical message rendering paths, eliminating code duplication.
package render
import (
"fmt"
"strings"
"charm.land/lipgloss/v2"
"github.com/indaco/herald"
"github.com/mark3labs/kit/internal/ui/style"
)
// UserBlock renders a user message with herald Tip styling.
func UserBlock(content string, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
rendered := ty.Tip(content)
return styleMarginBottom(theme, rendered)
}
// AssistantBlock renders an assistant message with markdown styling.
func AssistantBlock(content string, width int, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
return ""
}
rendered := style.ToMarkdown(content, width-4)
return styleMarginBottom(theme, rendered)
}
// ReasoningBlock renders a reasoning/thinking block with muted italic text.
// If duration > 0, shows "Thought for Xs" label. Otherwise shows just "Thought".
func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
return ""
}
// Match live streaming styling: muted italic text
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
contentStr := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
contentRendered := mutedStyle.Render(ty.Italic(contentStr))
// Build label based on duration
if duration > 0 {
var durationStr string
if duration < 1000 {
durationStr = fmt.Sprintf("%dms", duration)
} else {
durationStr = fmt.Sprintf("%.1fs", float64(duration)/1000)
}
labelPart := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
durationPart := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
label := labelPart + durationPart
rendered := contentRendered + "\n" + label
return styleMarginBottom(theme, rendered)
}
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought")
rendered := contentRendered + "\n" + label
return styleMarginBottom(theme, rendered)
}
// SystemBlock renders a system message with herald Note styling.
func SystemBlock(content string, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "No content available"
}
rendered := ty.Note(content)
return styleMarginBottom(theme, rendered)
}
// ErrorBlock renders an error message with herald Caution styling.
func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) string {
rendered := ty.Caution(errorMsg)
return styleMarginBottom(theme, rendered)
}
// ToolBlock renders a tool execution result with header and body.
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
var icon string
iconColor := theme.Success
if isError {
icon = "×"
iconColor = theme.Error
} else {
icon = "✓"
}
// Style the tool name with color
nameColor := theme.Info
if isError {
nameColor = theme.Error
}
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
// Build the content: icon + name + params on first line, then body
headerLine := styledIcon + " " + styledName
if params != "" {
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
}
if strings.TrimSpace(body) == "" {
body = ty.Italic("(no output)")
}
// Compose: icon + name + params, then body
fullContent := ty.Compose(
headerLine,
"",
body,
)
return styleMarginBottom(theme, fullContent)
}
// styleMarginBottom applies a 1-line margin bottom using the theme.
func styleMarginBottom(theme style.Theme, content string) string {
return lipgloss.NewStyle().MarginBottom(1).Render(content)
}
+308 -244
View File
@@ -2,25 +2,13 @@ package ui
import (
"strings"
"time"
"charm.land/lipgloss/v2"
xansi "github.com/charmbracelet/x/ansi"
"github.com/mark3labs/kit/internal/ui/selection"
)
// highlightStyle is lazily initialized to avoid creating it on every render
var highlightStyle lipgloss.Style
// initHighlightStyle creates the highlight style with proper colors
func initHighlightStyle() lipgloss.Style {
if highlightStyle.String() == "" {
theme := GetTheme()
highlightStyle = lipgloss.NewStyle().
Background(theme.Secondary).
Foreground(theme.Background).
Bold(true)
}
return highlightStyle
}
// MessageItem is the interface all scrollback messages must implement.
// This allows lazy rendering - messages are only rendered when visible.
type MessageItem interface {
@@ -36,8 +24,8 @@ type MessageItem interface {
}
// ScrollList manages a viewport over a list of MessageItems.
// It handles offset-based scrolling and lazy rendering. Only visible
// items are rendered on each View() call.
// It handles offset-based scrolling, lazy rendering, and character-level
// text selection (crush-style). Only visible items are rendered on each View() call.
type ScrollList struct {
items []MessageItem
offsetIdx int // Index of first visible item
@@ -46,15 +34,9 @@ type ScrollList struct {
height int // Viewport height in lines
autoScroll bool // Whether to auto-scroll to bottom on new content
itemGap int // Number of blank lines between items (0 = no gap)
focusedIdx int // Index of focused/selected item (-1 = none)
selectable bool // Whether items can be selected via mouse/keyboard
// Selection tracking for copy+paste (crush-style)
selection CopySelection // Current text selection
mouseDown bool // Whether mouse button is currently down
mouseDownX int // X coordinate where mouse was pressed
mouseDownY int // Y coordinate where mouse was pressed
mouseDownItem int // Item index where mouse was pressed
// Character-level text selection (crush-style).
sel selection.State
}
// NewScrollList creates a new ScrollList with the given dimensions.
@@ -65,7 +47,8 @@ func NewScrollList(width, height int) *ScrollList {
offsetLine: 0,
width: width,
height: height,
autoScroll: true, // Start with auto-scroll enabled
autoScroll: true,
sel: selection.NewState(),
}
}
@@ -101,118 +84,210 @@ func (s *ScrollList) ItemGap() int {
return s.itemGap
}
// SetSelectable enables or disables item selection.
func (s *ScrollList) SetSelectable(selectable bool) {
s.selectable = selectable
}
// --------------------------------------------------------------------------
// Mouse event handling — character-level text selection (crush-style)
// --------------------------------------------------------------------------
// FocusedIdx returns the currently focused item index (-1 if none).
func (s *ScrollList) FocusedIdx() int {
return s.focusedIdx
}
// SetFocused sets the focused item by index.
func (s *ScrollList) SetFocused(idx int) {
if idx < -1 {
s.focusedIdx = -1
} else if idx >= len(s.items) {
s.focusedIdx = len(s.items) - 1
} else {
s.focusedIdx = idx
}
}
// SelectItemAtY selects the item at the given Y coordinate (relative to viewport).
// Returns the selected item index or -1 if no item at that position.
func (s *ScrollList) SelectItemAtY(y int) int {
if !s.selectable || len(s.items) == 0 || y < 0 || y >= s.height {
return -1
}
// Calculate which item is at the given Y position
currentY := 0
for idx := s.offsetIdx; idx < len(s.items); idx++ {
item := s.items[idx]
itemHeight := item.Height()
// Check if y falls within this item
if y >= currentY && y < currentY+itemHeight {
s.focusedIdx = idx
return idx
}
currentY += itemHeight
// Add gap after item (except last)
if s.itemGap > 0 && idx < len(s.items)-1 {
currentY += s.itemGap
}
// Stop if we've passed the viewport
if currentY >= s.height {
break
}
}
return -1
}
// HandleMouseDown handles mouse button press for selection (crush-style).
// HandleMouseDown handles mouse button press. Detects single, double, and
// triple clicks for character, word, and line selection respectively.
// Returns true if the click was handled.
func (s *ScrollList) HandleMouseDown(x, y int) bool {
if !s.selectable || len(s.items) == 0 {
if len(s.items) == 0 {
return false
}
s.mouseDown = true
s.mouseDownX = x
s.mouseDownY = y
// Find which item and line was clicked
itemIdx, lineIdx := s.getItemAndLineAtY(y)
s.mouseDownItem = itemIdx
// Start a new selection at click position
if itemIdx >= 0 {
s.selection = CopySelection{
StartItemIdx: itemIdx,
StartLine: lineIdx,
StartCol: x,
EndItemIdx: itemIdx,
EndLine: lineIdx,
EndCol: x,
Active: true,
}
return true
}
return false
}
// HandleMouseDrag handles mouse drag for selection (crush-style).
// Updates the selection end point. Returns true if selection changed.
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
if !s.mouseDown || !s.selectable {
return false
}
// Find which item and line we're dragging over
itemIdx, lineIdx := s.getItemAndLineAtY(y)
if itemIdx < 0 {
return false
}
// Update selection end point
s.selection.EndItemIdx = itemIdx
s.selection.EndLine = lineIdx
s.selection.EndCol = x
s.selection.Active = true
// Multi-click detection (crush-style).
now := time.Now()
if now.Sub(s.sel.LastClickTime) <= selection.DoubleClickThreshold &&
abs(x-s.sel.LastClickX) <= selection.ClickTolerance &&
abs(y-s.sel.LastClickY) <= selection.ClickTolerance {
s.sel.ClickCount++
} else {
s.sel.ClickCount = 1
}
s.sel.LastClickTime = now
s.sel.LastClickX = x
s.sel.LastClickY = y
switch s.sel.ClickCount {
case 1:
// Single click: start character-level drag selection.
s.sel.MouseDown = true
s.sel.MouseDownItemIdx = itemIdx
s.sel.MouseDownLineIdx = lineIdx
s.sel.MouseDownCol = x
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = x
case 2:
// Double click: select word at position.
s.selectWord(itemIdx, lineIdx, x)
case 3:
// Triple click: select entire line.
s.selectLine(itemIdx, lineIdx)
s.sel.ClickCount = 0 // Reset after triple
}
return true
}
// getItemAndLineAtY converts a Y coordinate to item index and line index within that item.
// HandleMouseDrag handles mouse motion while button is held.
// Updates the selection endpoint for character-level precision.
// Returns true if selection was updated.
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
if !s.sel.MouseDown {
return false
}
if len(s.items) == 0 {
return false
}
itemIdx, lineIdx := s.getItemAndLineAtY(y)
if itemIdx < 0 {
return false
}
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = x
return true
}
// HandleMouseUp handles mouse button release.
// Returns true if there was an active selection.
func (s *ScrollList) HandleMouseUp() bool {
if !s.sel.MouseDown {
return false
}
s.sel.MouseDown = false
return s.sel.HasSelection()
}
// HasSelection returns true if there is a non-empty active selection.
func (s *ScrollList) HasSelection() bool {
return s.sel.HasSelection()
}
// ClearSelection clears the current text selection.
func (s *ScrollList) ClearSelection() {
s.sel.Clear()
}
// ExtractSelectedText returns the plain text content of the current selection
// by walking through selected items and extracting text at the character level
// using the ultraviolet cell buffer (ANSI-aware).
func (s *ScrollList) ExtractSelectedText() string {
r := s.sel.GetRange()
if r.IsEmpty() {
return ""
}
var sb strings.Builder
for itemIdx := r.StartItemIdx; itemIdx <= r.EndItemIdx && itemIdx < len(s.items); itemIdx++ {
item := s.items[itemIdx]
content := item.Render(s.width)
contentLines := strings.Split(content, "\n")
for lineIdx, line := range contentLines {
inRange, startCol, endCol := selection.IsLineInRange(r, itemIdx, lineIdx)
if !inRange {
continue
}
text := selection.ExtractText(line, startCol, endCol)
if text != "" {
if sb.Len() > 0 {
sb.WriteString("\n")
}
sb.WriteString(text)
}
}
}
return sb.String()
}
// selectWord selects the word at the given position using UAX#29 word
// segmentation and display-width-aware column calculations.
func (s *ScrollList) selectWord(itemIdx, lineIdx, x int) {
if itemIdx < 0 || itemIdx >= len(s.items) {
return
}
item := s.items[itemIdx]
content := item.Render(s.width)
lines := strings.Split(content, "\n")
if lineIdx < 0 || lineIdx >= len(lines) {
return
}
// Strip ANSI codes for word boundary detection.
plainLine := xansi.Strip(lines[lineIdx])
startCol, endCol := selection.FindWordBoundaries(plainLine, x)
if startCol == endCol {
// No word at this position — set up single-click drag state.
s.sel.MouseDown = true
s.sel.MouseDownItemIdx = itemIdx
s.sel.MouseDownLineIdx = lineIdx
s.sel.MouseDownCol = x
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = x
return
}
// Set selection to the word boundaries.
s.sel.MouseDown = true
s.sel.MouseDownItemIdx = itemIdx
s.sel.MouseDownLineIdx = lineIdx
s.sel.MouseDownCol = startCol
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = endCol
}
// selectLine selects the entire line at the given position.
func (s *ScrollList) selectLine(itemIdx, lineIdx int) {
if itemIdx < 0 || itemIdx >= len(s.items) {
return
}
item := s.items[itemIdx]
content := item.Render(s.width)
lines := strings.Split(content, "\n")
if lineIdx < 0 || lineIdx >= len(lines) {
return
}
lineWidth := xansi.StringWidth(lines[lineIdx])
s.sel.MouseDown = true
s.sel.MouseDownItemIdx = itemIdx
s.sel.MouseDownLineIdx = lineIdx
s.sel.MouseDownCol = 0
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = lineWidth
}
// getItemAndLineAtY converts a viewport-relative Y coordinate to item index
// and line index within that item. Accounts for scroll offset and item gaps.
// Returns (-1, -1) if Y is outside the viewport or beyond all items.
//
// IMPORTANT: Uses Render()+line counting (not Height()) to compute item height,
// because Height() on some MessageItem implementations (e.g. StreamingMessageItem
// for reasoning blocks) may return 0 when the render cache is empty.
func (s *ScrollList) getItemAndLineAtY(y int) (itemIdx, lineIdx int) {
if y < 0 || y >= s.height || len(s.items) == 0 {
return -1, -1
@@ -221,21 +296,27 @@ func (s *ScrollList) getItemAndLineAtY(y int) (itemIdx, lineIdx int) {
currentY := 0
for idx := s.offsetIdx; idx < len(s.items); idx++ {
item := s.items[idx]
itemHeight := item.Height()
// Compute height the same way View() does: render, then count lines.
itemHeight := s.renderedHeight(item)
// Account for partial visibility of the first item.
startLine := 0
if idx == s.offsetIdx {
startLine = s.offsetLine
itemHeight -= s.offsetLine
}
// Check if y falls within this item
if y >= currentY && y < currentY+itemHeight {
return idx, y - currentY
return idx, (y - currentY) + startLine
}
currentY += itemHeight
// Add gap after item (except last)
// Add gap after item (except last).
if s.itemGap > 0 && idx < len(s.items)-1 {
currentY += s.itemGap
}
// Stop if we've passed the viewport
if currentY >= s.height {
break
}
@@ -244,38 +325,9 @@ func (s *ScrollList) getItemAndLineAtY(y int) (itemIdx, lineIdx int) {
return -1, -1
}
// HandleMouseUp handles mouse button release (crush-style).
// Finalizes selection and returns true if there was an active selection.
func (s *ScrollList) HandleMouseUp(x, y int) bool {
if !s.mouseDown {
return false
}
s.mouseDown = false
// Check if we have a valid selection
if s.selection.Active && !s.selection.IsEmpty() {
return true
}
return false
}
// GetSelection returns the current text selection.
func (s *ScrollList) GetSelection() CopySelection {
return s.selection
}
// ClearSelection clears the current text selection.
func (s *ScrollList) ClearSelection() {
s.selection = CopySelection{}
s.mouseDown = false
}
// HasSelection returns true if there is an active non-empty selection.
func (s *ScrollList) HasSelection() bool {
return s.selection.Active && !s.selection.IsEmpty()
}
// --------------------------------------------------------------------------
// Scrolling
// --------------------------------------------------------------------------
// ScrollBy scrolls the viewport by the given number of lines.
// Positive = scroll down, negative = scroll up.
@@ -363,8 +415,9 @@ func (s *ScrollList) GotoBottom() {
// Calculate total height including gaps
totalHeight := 0
for i, item := range s.items {
totalHeight += item.Height()
// Add gap after each item except the last
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
totalHeight += itemHeight
if s.itemGap > 0 && i < len(s.items)-1 {
totalHeight += s.itemGap
}
@@ -380,14 +433,14 @@ func (s *ScrollList) GotoBottom() {
// Otherwise, position viewport at bottom
remaining := totalHeight - s.height
for idx := 0; idx < len(s.items); idx++ {
itemHeight := s.items[idx].Height()
rendered := s.items[idx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if remaining < itemHeight {
s.offsetIdx = idx
s.offsetLine = remaining
return
}
remaining -= itemHeight
// Subtract gap after item (except last)
if s.itemGap > 0 && idx < len(s.items)-1 {
remaining -= s.itemGap
}
@@ -410,11 +463,11 @@ func (s *ScrollList) AtBottom() bool {
return true
}
// Calculate visible height from current position including gaps
visibleHeight := 0
for idx := s.offsetIdx; idx < len(s.items); idx++ {
item := s.items[idx]
itemHeight := item.Height()
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if idx == s.offsetIdx {
visibleHeight += itemHeight - s.offsetLine
@@ -422,7 +475,6 @@ func (s *ScrollList) AtBottom() bool {
visibleHeight += itemHeight
}
// Add gap after item (except last)
if s.itemGap > 0 && idx < len(s.items)-1 {
visibleHeight += s.itemGap
}
@@ -440,19 +492,28 @@ func (s *ScrollList) AtTop() bool {
return s.offsetIdx == 0 && s.offsetLine == 0
}
// --------------------------------------------------------------------------
// Rendering
// --------------------------------------------------------------------------
// View renders the visible portion of the scrollback.
// Only items that fit within the viewport height are rendered.
// ALWAYS returns exactly s.height lines (padded with empty lines if needed)
// to ensure the input/footer stay fixed at the bottom.
//
// When an active selection exists, character-level highlighting is applied
// using ultraviolet ScreenBuffer for ANSI-aware cell manipulation.
func (s *ScrollList) View() string {
if s.height <= 0 {
return ""
}
selRange := s.sel.GetRange()
hasSelection := !selRange.IsEmpty()
var lines []string
remainingHeight := s.height
// Render visible items
if len(s.items) > 0 {
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
item := s.items[idx]
@@ -464,25 +525,22 @@ func (s *ScrollList) View() string {
startLine = s.offsetLine
}
// Check if this item is focused (for visual indicator)
isFocused := idx == s.focusedIdx
for i := startLine; i < len(contentLines) && remainingHeight > 0; i++ {
line := contentLines[i]
// Apply selection highlighting if this line is within selection
if s.selection.Active && s.isLineInSelection(idx, i) {
line = s.applyHighlight(line)
} else if isFocused && s.selectable {
// Apply subtle focus indicator when item is focused but not in selection
line = s.applyFocusIndicator(line)
// Apply character-level selection highlighting.
if hasSelection {
inRange, startCol, endCol := selection.IsLineInRange(selRange, idx, i)
if inRange {
line = selection.HighlightLine(line, startCol, endCol)
}
}
lines = append(lines, line)
remainingHeight--
}
// Add gap lines between items (but not after the last visible item)
// Add gap lines between items.
if remainingHeight > 0 && idx < len(s.items)-1 && s.itemGap > 0 {
for g := 0; g < s.itemGap && remainingHeight > 0; g++ {
lines = append(lines, "")
@@ -492,8 +550,7 @@ func (s *ScrollList) View() string {
}
}
// Pad with empty lines to ensure exactly s.height lines
// This keeps the input/footer fixed at the bottom of the screen
// Pad with empty lines to ensure exactly s.height lines.
for remainingHeight > 0 {
lines = append(lines, "")
remainingHeight--
@@ -502,65 +559,6 @@ func (s *ScrollList) View() string {
return strings.Join(lines, "\n")
}
// isLineInSelection checks if a specific line within an item is part of the current selection.
func (s *ScrollList) isLineInSelection(itemIdx, lineIdx int) bool {
if !s.selection.Active {
return false
}
// Normalize selection (start <= end)
startItem := s.selection.StartItemIdx
startLine := s.selection.StartLine
endItem := s.selection.EndItemIdx
endLine := s.selection.EndLine
if startItem > endItem || (startItem == endItem && startLine > endLine) {
startItem, endItem = endItem, startItem
startLine, endLine = endLine, startLine
}
// Check if item is within selection range
if itemIdx < startItem || itemIdx > endItem {
return false
}
// For single item selection
if startItem == endItem {
return itemIdx == startItem && lineIdx >= startLine && lineIdx <= endLine
}
// For multi-item selection
if itemIdx == startItem {
return lineIdx >= startLine
}
if itemIdx == endItem {
return lineIdx <= endLine
}
// Middle items are fully selected
return itemIdx > startItem && itemIdx < endItem
}
// applyHighlight applies the highlight style to a line.
// Uses the theme's Highlight color for the background.
func (s *ScrollList) applyHighlight(line string) string {
if line == "" {
return line
}
// Apply background/foreground color change for selection
style := initHighlightStyle()
return style.Render(line)
}
// applyFocusIndicator applies a subtle visual indicator for focused items.
func (s *ScrollList) applyFocusIndicator(line string) string {
if line == "" {
return line
}
// Just return the line as-is - no visual indicator for focus
// The selection highlighting is enough
return line
}
// ScrollPercent returns the current scroll position as a percentage (0.0-1.0).
// 0.0 = at top, 1.0 = at bottom. Useful for scroll indicators.
func (s *ScrollList) ScrollPercent() float64 {
@@ -574,10 +572,9 @@ func (s *ScrollList) ScrollPercent() float64 {
}
if totalHeight <= s.height {
return 1.0 // All content fits, consider it "at bottom"
return 1.0
}
// Calculate how many lines are above the viewport
linesAbove := 0
for i := 0; i < s.offsetIdx && i < len(s.items); i++ {
linesAbove += s.items[i].Height()
@@ -608,7 +605,6 @@ func (s *ScrollList) clampOffset() {
return
}
// Clamp offsetIdx
if s.offsetIdx >= len(s.items) {
s.offsetIdx = len(s.items) - 1
}
@@ -616,9 +612,9 @@ func (s *ScrollList) clampOffset() {
s.offsetIdx = 0
}
// Clamp offsetLine
if s.offsetIdx < len(s.items) {
itemHeight := s.items[s.offsetIdx].Height()
rendered := s.items[s.offsetIdx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if s.offsetLine >= itemHeight {
s.offsetLine = max(0, itemHeight-1)
}
@@ -626,4 +622,72 @@ func (s *ScrollList) clampOffset() {
if s.offsetLine < 0 {
s.offsetLine = 0
}
// Prevent scrolling past the bottom
totalHeight := 0
for i, item := range s.items {
rendered := item.Render(s.width)
totalHeight += strings.Count(rendered, "\n") + 1
if s.itemGap > 0 && i < len(s.items)-1 {
totalHeight += s.itemGap
}
}
if totalHeight <= s.height {
s.offsetIdx = 0
s.offsetLine = 0
return
}
linesAbove := 0
for i := 0; i < s.offsetIdx; i++ {
rendered := s.items[i].Render(s.width)
linesAbove += strings.Count(rendered, "\n") + 1
if s.itemGap > 0 && i < len(s.items)-1 {
linesAbove += s.itemGap
}
}
linesAbove += s.offsetLine
linesFromCurrentToEnd := totalHeight - linesAbove
if linesFromCurrentToEnd < s.height {
targetLine := totalHeight - s.height
currentLine := 0
for idx := 0; idx < len(s.items); idx++ {
rendered := s.items[idx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if currentLine+itemHeight > targetLine {
s.offsetIdx = idx
s.offsetLine = targetLine - currentLine
return
}
currentLine += itemHeight
if s.itemGap > 0 && idx < len(s.items)-1 {
currentLine += s.itemGap
}
}
}
}
// renderedHeight returns the height of a message item in lines by actually
// rendering it. This is the single source of truth for item height — it
// matches exactly what View() produces, unlike item.Height() which may
// return stale/zero values for uncached items (e.g. reasoning blocks).
func (s *ScrollList) renderedHeight(item MessageItem) int {
rendered := item.Render(s.width)
if rendered == "" {
return 0
}
return strings.Count(rendered, "\n") + 1
}
// abs returns the absolute value of x.
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
+324
View File
@@ -0,0 +1,324 @@
// Package selection provides character-level text selection for terminal UIs.
//
// It handles converting mouse coordinates (in terminal cells) to character
// positions within rendered ANSI-styled text, supporting multi-byte characters,
// wide characters (CJK, emoji), and word/line selection via double/triple click.
//
// The approach is modeled after Charm's crush: all coordinate calculations use
// display columns (terminal cells), not byte offsets or rune counts. The
// ultraviolet ScreenBuffer provides the bridge between rendered ANSI strings
// and individual character cells.
package selection
import (
"image"
"strings"
"time"
uv "github.com/charmbracelet/ultraviolet"
xansi "github.com/charmbracelet/x/ansi"
"github.com/clipperhouse/displaywidth"
"github.com/clipperhouse/uax29/v2/words"
)
// DoubleClickThreshold is the maximum time between clicks for multi-click.
const DoubleClickThreshold = 400 * time.Millisecond
// ClickTolerance is the pixel/cell tolerance for multi-click detection.
const ClickTolerance = 2
// State tracks the full state of a mouse text selection.
type State struct {
// Whether a mouse button is currently held down.
MouseDown bool
// Position where mouse was first pressed (viewport-relative).
MouseDownItemIdx int
MouseDownLineIdx int
MouseDownCol int
// Current drag position (viewport-relative).
DragItemIdx int
DragLineIdx int
DragCol int
// Multi-click detection.
LastClickTime time.Time
LastClickX int
LastClickY int
ClickCount int
}
// Range represents a normalized (start <= end) selection range.
type Range struct {
StartItemIdx int
StartLine int
StartCol int
EndItemIdx int
EndLine int
EndCol int
}
// IsEmpty returns true if the range selects nothing.
func (r Range) IsEmpty() bool {
return r.StartItemIdx < 0 || r.EndItemIdx < 0 ||
(r.StartItemIdx == r.EndItemIdx && r.StartLine == r.EndLine && r.StartCol == r.EndCol)
}
// NewState creates a new empty selection state.
func NewState() State {
return State{
MouseDownItemIdx: -1,
DragItemIdx: -1,
}
}
// Clear resets all selection state.
func (s *State) Clear() {
s.MouseDown = false
s.MouseDownItemIdx = -1
s.MouseDownLineIdx = 0
s.MouseDownCol = 0
s.DragItemIdx = -1
s.DragLineIdx = 0
s.DragCol = 0
s.LastClickTime = time.Time{}
s.LastClickX = 0
s.LastClickY = 0
s.ClickCount = 0
}
// HasSelection returns true if there is a non-empty active selection.
func (s *State) HasSelection() bool {
return s.MouseDownItemIdx >= 0 && s.DragItemIdx >= 0 && !s.GetRange().IsEmpty()
}
// GetRange returns the normalized selection range (start <= end).
func (s *State) GetRange() Range {
if s.MouseDownItemIdx < 0 || s.DragItemIdx < 0 {
return Range{StartItemIdx: -1, EndItemIdx: -1}
}
downItem := s.MouseDownItemIdx
downLine := s.MouseDownLineIdx
downCol := s.MouseDownCol
dragItem := s.DragItemIdx
dragLine := s.DragLineIdx
dragCol := s.DragCol
// Determine if dragging forward or backward.
forward := dragItem > downItem ||
(dragItem == downItem && dragLine > downLine) ||
(dragItem == downItem && dragLine == downLine && dragCol >= downCol)
if forward {
return Range{
StartItemIdx: downItem,
StartLine: downLine,
StartCol: downCol,
EndItemIdx: dragItem,
EndLine: dragLine,
EndCol: dragCol,
}
}
return Range{
StartItemIdx: dragItem,
StartLine: dragLine,
StartCol: dragCol,
EndItemIdx: downItem,
EndLine: downLine,
EndCol: downCol,
}
}
// IsLineInRange checks if a specific line within an item falls inside the
// selection range. Returns (inRange, startCol, endCol) where startCol == -1
// means the entire line is selected. startCol == endCol means no selection
// on this line.
func IsLineInRange(r Range, itemIdx, lineIdx int) (bool, int, int) {
if r.IsEmpty() {
return false, 0, 0
}
// Outside item range entirely.
if itemIdx < r.StartItemIdx || itemIdx > r.EndItemIdx {
return false, 0, 0
}
// Single-item selection.
if r.StartItemIdx == r.EndItemIdx {
if itemIdx != r.StartItemIdx {
return false, 0, 0
}
if lineIdx < r.StartLine || lineIdx > r.EndLine {
return false, 0, 0
}
if r.StartLine == r.EndLine {
// Single line: specific column range.
return true, r.StartCol, r.EndCol
}
if lineIdx == r.StartLine {
return true, r.StartCol, -1 // from startCol to end of line
}
if lineIdx == r.EndLine {
return true, 0, r.EndCol // from start of line to endCol
}
return true, -1, -1 // full line (middle of multi-line selection)
}
// Multi-item selection.
if itemIdx == r.StartItemIdx {
if lineIdx < r.StartLine {
return false, 0, 0
}
if lineIdx == r.StartLine {
return true, r.StartCol, -1
}
return true, -1, -1 // full line
}
if itemIdx == r.EndItemIdx {
if lineIdx > r.EndLine {
return false, 0, 0
}
if lineIdx == r.EndLine {
return true, 0, r.EndCol
}
return true, -1, -1 // full line
}
// Middle item: fully selected.
return true, -1, -1
}
// FindWordBoundaries finds the start and end column of the word at the given
// column position in a plain-text line (ANSI codes already stripped).
// Returns (startCol, endCol) where endCol is exclusive.
// Uses UAX#29 word segmentation and display-width-aware column tracking.
func FindWordBoundaries(line string, col int) (startCol, endCol int) {
if line == "" || col < 0 {
return 0, 0
}
// Segment the line into words using UAX#29.
lineCol := 0
iter := words.FromString(line)
for iter.Next() {
token := iter.Value()
tokenWidth := displaywidth.String(token)
graphemeStart := lineCol
graphemeEnd := lineCol + tokenWidth
lineCol += tokenWidth
// If clicked before this token, no word here.
if col < graphemeStart {
return col, col
}
// If clicked within this token, return its boundaries.
if col >= graphemeStart && col < graphemeEnd {
// Whitespace tokens produce empty selection.
if strings.TrimSpace(token) == "" {
return col, col
}
return graphemeStart, graphemeEnd
}
}
return col, col
}
// HighlightLine applies reverse-video highlighting to a portion of a rendered
// line (which may contain ANSI escape codes). startCol/endCol are in display
// columns. If startCol == -1, the entire line is highlighted. If startCol ==
// endCol, returns the line unchanged.
//
// Uses ultraviolet ScreenBuffer for cell-level ANSI manipulation.
func HighlightLine(line string, startCol, endCol int) string {
if line == "" {
return line
}
lineWidth := xansi.StringWidth(line)
if lineWidth == 0 {
return line
}
// Full-line highlight.
if startCol == -1 {
startCol = 0
endCol = lineWidth
}
if startCol >= endCol || startCol >= lineWidth {
return line
}
if endCol > lineWidth {
endCol = lineWidth
}
// Parse the styled line into a cell buffer.
area := image.Rect(0, 0, lineWidth, 1)
buf := uv.NewScreenBuffer(lineWidth, 1)
styled := uv.NewStyledString(line)
styled.Draw(&buf, area)
// Apply reverse attribute to cells in the selection range.
if buf.Height() > 0 {
bufLine := buf.Line(0)
for x := startCol; x < endCol && x < len(bufLine); x++ {
cell := bufLine.At(x)
if cell != nil {
cell.Style.Attrs |= uv.AttrReverse
}
}
}
return buf.Render()
}
// ExtractText extracts plain text from a rendered ANSI string within the given
// column range on a single line. Uses ultraviolet to parse ANSI and extract
// character content.
func ExtractText(line string, startCol, endCol int) string {
if line == "" {
return ""
}
lineWidth := xansi.StringWidth(line)
if lineWidth == 0 {
return ""
}
// Full-line extraction.
if startCol == -1 {
startCol = 0
endCol = lineWidth
}
if startCol >= endCol || startCol >= lineWidth {
return ""
}
if endCol > lineWidth {
endCol = lineWidth
}
// Parse to cell buffer.
area := image.Rect(0, 0, lineWidth, 1)
buf := uv.NewScreenBuffer(lineWidth, 1)
styled := uv.NewStyledString(line)
styled.Draw(&buf, area)
var sb strings.Builder
if buf.Height() > 0 {
bufLine := buf.Line(0)
for x := startCol; x < endCol && x < len(bufLine); x++ {
cell := bufLine.At(x)
if cell != nil && cell.Content != "" {
sb.WriteString(cell.Content)
}
}
}
return sb.String()
}
+400
View File
@@ -0,0 +1,400 @@
package selection
import (
"testing"
"time"
)
func TestNewState(t *testing.T) {
s := NewState()
if s.MouseDownItemIdx != -1 {
t.Errorf("expected MouseDownItemIdx -1, got %d", s.MouseDownItemIdx)
}
if s.DragItemIdx != -1 {
t.Errorf("expected DragItemIdx -1, got %d", s.DragItemIdx)
}
if s.MouseDown {
t.Error("expected MouseDown false")
}
if s.HasSelection() {
t.Error("expected no selection on new state")
}
}
func TestClear(t *testing.T) {
s := NewState()
s.MouseDown = true
s.MouseDownItemIdx = 2
s.DragItemIdx = 3
s.ClickCount = 2
s.Clear()
if s.MouseDown {
t.Error("expected MouseDown false after clear")
}
if s.MouseDownItemIdx != -1 {
t.Errorf("expected MouseDownItemIdx -1 after clear, got %d", s.MouseDownItemIdx)
}
if s.DragItemIdx != -1 {
t.Errorf("expected DragItemIdx -1 after clear, got %d", s.DragItemIdx)
}
if s.ClickCount != 0 {
t.Errorf("expected ClickCount 0 after clear, got %d", s.ClickCount)
}
}
func TestGetRange_Forward(t *testing.T) {
s := NewState()
s.MouseDownItemIdx = 0
s.MouseDownLineIdx = 1
s.MouseDownCol = 5
s.DragItemIdx = 0
s.DragLineIdx = 3
s.DragCol = 10
r := s.GetRange()
if r.StartItemIdx != 0 || r.StartLine != 1 || r.StartCol != 5 {
t.Errorf("unexpected start: item=%d line=%d col=%d", r.StartItemIdx, r.StartLine, r.StartCol)
}
if r.EndItemIdx != 0 || r.EndLine != 3 || r.EndCol != 10 {
t.Errorf("unexpected end: item=%d line=%d col=%d", r.EndItemIdx, r.EndLine, r.EndCol)
}
}
func TestGetRange_Backward(t *testing.T) {
s := NewState()
s.MouseDownItemIdx = 2
s.MouseDownLineIdx = 5
s.MouseDownCol = 20
s.DragItemIdx = 0
s.DragLineIdx = 1
s.DragCol = 3
r := s.GetRange()
// Should be normalized: drag position becomes start
if r.StartItemIdx != 0 || r.StartLine != 1 || r.StartCol != 3 {
t.Errorf("unexpected start: item=%d line=%d col=%d", r.StartItemIdx, r.StartLine, r.StartCol)
}
if r.EndItemIdx != 2 || r.EndLine != 5 || r.EndCol != 20 {
t.Errorf("unexpected end: item=%d line=%d col=%d", r.EndItemIdx, r.EndLine, r.EndCol)
}
}
func TestGetRange_SameLine(t *testing.T) {
s := NewState()
s.MouseDownItemIdx = 1
s.MouseDownLineIdx = 2
s.MouseDownCol = 10
s.DragItemIdx = 1
s.DragLineIdx = 2
s.DragCol = 20
r := s.GetRange()
if r.IsEmpty() {
t.Error("expected non-empty range")
}
if r.StartCol != 10 || r.EndCol != 20 {
t.Errorf("expected cols 10-20, got %d-%d", r.StartCol, r.EndCol)
}
}
func TestRangeIsEmpty(t *testing.T) {
// Same point
r := Range{StartItemIdx: 0, StartLine: 0, StartCol: 5, EndItemIdx: 0, EndLine: 0, EndCol: 5}
if !r.IsEmpty() {
t.Error("expected same-point range to be empty")
}
// Negative item idx
r = Range{StartItemIdx: -1, EndItemIdx: -1}
if !r.IsEmpty() {
t.Error("expected negative item idx range to be empty")
}
// Valid range
r = Range{StartItemIdx: 0, StartLine: 0, StartCol: 0, EndItemIdx: 0, EndLine: 0, EndCol: 5}
if r.IsEmpty() {
t.Error("expected valid range to not be empty")
}
}
func TestHasSelection(t *testing.T) {
s := NewState()
if s.HasSelection() {
t.Error("new state should have no selection")
}
// Set up a valid selection
s.MouseDownItemIdx = 0
s.MouseDownLineIdx = 0
s.MouseDownCol = 0
s.DragItemIdx = 0
s.DragLineIdx = 0
s.DragCol = 10
if !s.HasSelection() {
t.Error("expected selection to exist")
}
// Same point = no selection
s.DragCol = 0
if s.HasSelection() {
t.Error("same point should not be a selection")
}
}
func TestIsLineInRange_SingleItem_SingleLine(t *testing.T) {
r := Range{
StartItemIdx: 1, StartLine: 2, StartCol: 5,
EndItemIdx: 1, EndLine: 2, EndCol: 15,
}
// Exact line
ok, sc, ec := IsLineInRange(r, 1, 2)
if !ok || sc != 5 || ec != 15 {
t.Errorf("expected (true, 5, 15), got (%v, %d, %d)", ok, sc, ec)
}
// Wrong line
ok, _, _ = IsLineInRange(r, 1, 0)
if ok {
t.Error("line 0 should not be in range")
}
// Wrong item
ok, _, _ = IsLineInRange(r, 0, 2)
if ok {
t.Error("item 0 should not be in range")
}
}
func TestIsLineInRange_SingleItem_MultiLine(t *testing.T) {
r := Range{
StartItemIdx: 0, StartLine: 1, StartCol: 5,
EndItemIdx: 0, EndLine: 4, EndCol: 10,
}
// Start line
ok, sc, ec := IsLineInRange(r, 0, 1)
if !ok || sc != 5 || ec != -1 {
t.Errorf("start line: expected (true, 5, -1), got (%v, %d, %d)", ok, sc, ec)
}
// Middle line
ok, sc, ec = IsLineInRange(r, 0, 2)
if !ok || sc != -1 || ec != -1 {
t.Errorf("middle line: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
}
// End line
ok, sc, ec = IsLineInRange(r, 0, 4)
if !ok || sc != 0 || ec != 10 {
t.Errorf("end line: expected (true, 0, 10), got (%v, %d, %d)", ok, sc, ec)
}
}
func TestIsLineInRange_MultiItem(t *testing.T) {
r := Range{
StartItemIdx: 0, StartLine: 3, StartCol: 5,
EndItemIdx: 2, EndLine: 1, EndCol: 10,
}
// First item, start line
ok, sc, ec := IsLineInRange(r, 0, 3)
if !ok || sc != 5 || ec != -1 {
t.Errorf("first item start: expected (true, 5, -1), got (%v, %d, %d)", ok, sc, ec)
}
// First item, line after start
ok, sc, ec = IsLineInRange(r, 0, 5)
if !ok || sc != -1 || ec != -1 {
t.Errorf("first item after: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
}
// Middle item, any line
ok, sc, ec = IsLineInRange(r, 1, 0)
if !ok || sc != -1 || ec != -1 {
t.Errorf("middle item: expected (true, -1, -1), got (%v, %d, %d)", ok, sc, ec)
}
// Last item, end line
ok, sc, ec = IsLineInRange(r, 2, 1)
if !ok || sc != 0 || ec != 10 {
t.Errorf("last item end: expected (true, 0, 10), got (%v, %d, %d)", ok, sc, ec)
}
// Last item, line after end
ok, _, _ = IsLineInRange(r, 2, 5)
if ok {
t.Error("line after end in last item should not be in range")
}
}
func TestFindWordBoundaries(t *testing.T) {
tests := []struct {
name string
line string
col int
wantStart int
wantEnd int
}{
{
name: "simple word",
line: "hello world",
col: 2,
wantStart: 0,
wantEnd: 5,
},
{
name: "second word",
line: "hello world",
col: 7,
wantStart: 6,
wantEnd: 11,
},
{
name: "on space",
line: "hello world",
col: 5,
wantStart: 5,
wantEnd: 5,
},
{
name: "empty line",
line: "",
col: 0,
wantStart: 0,
wantEnd: 0,
},
{
name: "negative col",
line: "hello",
col: -1,
wantStart: 0,
wantEnd: 0,
},
{
name: "past end",
line: "hello",
col: 10,
wantStart: 10,
wantEnd: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start, end := FindWordBoundaries(tt.line, tt.col)
if start != tt.wantStart || end != tt.wantEnd {
t.Errorf("FindWordBoundaries(%q, %d) = (%d, %d), want (%d, %d)",
tt.line, tt.col, start, end, tt.wantStart, tt.wantEnd)
}
})
}
}
func TestExtractText_PlainText(t *testing.T) {
line := "Hello, World!"
text := ExtractText(line, 0, 5)
if text != "Hello" {
t.Errorf("expected 'Hello', got %q", text)
}
text = ExtractText(line, 7, 12)
if text != "World" {
t.Errorf("expected 'World', got %q", text)
}
}
func TestExtractText_FullLine(t *testing.T) {
line := "Hello"
text := ExtractText(line, -1, -1)
if text != "Hello" {
t.Errorf("expected 'Hello', got %q", text)
}
}
func TestExtractText_Empty(t *testing.T) {
text := ExtractText("", 0, 5)
if text != "" {
t.Errorf("expected empty string, got %q", text)
}
}
func TestExtractText_OutOfBounds(t *testing.T) {
line := "Hi"
text := ExtractText(line, 5, 10)
if text != "" {
t.Errorf("expected empty string for out of bounds, got %q", text)
}
}
func TestHighlightLine_PlainText(t *testing.T) {
line := "Hello, World!"
result := HighlightLine(line, 0, 5)
// Should produce a non-empty result different from input (has ANSI codes)
if result == "" {
t.Error("expected non-empty result")
}
// Should still contain the text content
if len(result) < len(line) {
t.Error("result should be at least as long as input (ANSI codes add length)")
}
}
func TestHighlightLine_Empty(t *testing.T) {
result := HighlightLine("", 0, 5)
if result != "" {
t.Errorf("expected empty for empty input, got %q", result)
}
}
func TestHighlightLine_NoSelection(t *testing.T) {
line := "Hello"
result := HighlightLine(line, 3, 3)
// Same startCol and endCol = no change
if result != line {
t.Errorf("expected no change for zero-width selection, got %q", result)
}
}
// TestMultiClickDetection verifies the click counting logic.
func TestMultiClickDetection(t *testing.T) {
s := NewState()
now := time.Now()
// First click
s.LastClickTime = now
s.LastClickX = 10
s.LastClickY = 5
s.ClickCount = 1
// Second click within threshold
later := now.Add(200 * time.Millisecond)
if later.Sub(s.LastClickTime) <= DoubleClickThreshold {
if abs(10-s.LastClickX) <= ClickTolerance && abs(5-s.LastClickY) <= ClickTolerance {
s.ClickCount++
}
}
if s.ClickCount != 2 {
t.Errorf("expected click count 2, got %d", s.ClickCount)
}
// Third click
s.LastClickTime = later
later2 := later.Add(200 * time.Millisecond)
if later2.Sub(s.LastClickTime) <= DoubleClickThreshold {
if abs(10-s.LastClickX) <= ClickTolerance && abs(5-s.LastClickY) <= ClickTolerance {
s.ClickCount++
}
}
if s.ClickCount != 3 {
t.Errorf("expected click count 3, got %d", s.ClickCount)
}
}
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
+134 -63
View File
@@ -12,6 +12,7 @@ import (
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/ui/style"
)
// SessionSelectedMsg is sent when the user selects a session from the picker.
@@ -158,12 +159,12 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
switch {
case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))):
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
if ss.cursor > 0 {
ss.cursor--
}
case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))):
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
if ss.cursor < len(ss.filtered)-1 {
ss.cursor++
}
@@ -250,58 +251,108 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// View implements tea.Model.
func (ss *SessionSelectorComponent) View() tea.View {
theme := GetTheme()
w := ss.width
var b strings.Builder
theme := style.GetTheme()
// Full-screen bordered container - uses entire terminal width and height
maxWidth := ss.width - 2 // Small margin on each side
if maxWidth < 20 {
maxWidth = ss.width
}
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
if maxHeight < 10 {
maxHeight = ss.height
}
horizontalPadding := 1
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
// Container style with border - full width/height like a framed panel
containerStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(theme.Background).
Padding(1, horizontalPadding).
Width(maxWidth).
Height(maxHeight)
var contentBuilder strings.Builder
// ── Header: title + scope badges ─────────────────────────────
titleStyle := lipgloss.NewStyle().Bold(true).Foreground(theme.Accent).PaddingLeft(1)
b.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
b.WriteString("\n")
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
Background(theme.Background)
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
contentBuilder.WriteString("\n")
// ── Help / keybindings ───────────────────────────────────────
helpStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(1)
if w >= 75 {
b.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
} else if w >= 50 {
b.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
helpStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
if innerWidth >= 75 {
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
} else if innerWidth >= 50 {
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
} else {
b.WriteString(helpStyle.Render("tab N D esc"))
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
}
b.WriteString("\n")
contentBuilder.WriteString("\n")
// ── Search (only shown when active) ──────────────────────────
if ss.search != "" {
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(1)
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
b.WriteString("\n")
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(theme.Background)
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
contentBuilder.WriteString("\n")
}
b.WriteString("\n")
// Separator line
sepWidth := innerWidth
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// ── Delete confirmation ──────────────────────────────────────
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
warnStyle := lipgloss.NewStyle().Foreground(theme.Error).Bold(true).PaddingLeft(1)
warnStyle := lipgloss.NewStyle().
Foreground(theme.Error).
Bold(true).
Background(theme.Background)
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
b.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
b.WriteString("\n")
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
contentBuilder.WriteString("\n")
}
// ── Session list ─────────────────────────────────────────────
if len(ss.filtered) == 0 {
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
if ss.search != "" {
b.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
} else if ss.filter == SessionFilterNamed {
b.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
} else if ss.scope == SessionScopeCwd {
b.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
} else {
b.WriteString(emptyStyle.Render("No sessions found"))
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
}
b.WriteString("\n")
contentBuilder.WriteString("\n")
} else {
visH := ss.visibleHeight()
// Compute visible window based on inner container height
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
chromeLines := 5
if ss.search != "" {
chromeLines++
}
if ss.confirmDelete >= 0 {
chromeLines++
}
visH := max(innerHeight-chromeLines, 3)
// Center the cursor in the visible window.
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
@@ -312,26 +363,41 @@ func (ss *SessionSelectorComponent) View() tea.View {
isCursor := i == ss.cursor
isCurrent := info.Path == ss.currentPath
isDeleting := i == ss.confirmDelete
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, w)
b.WriteString(line)
b.WriteString("\n")
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
contentBuilder.WriteString(line)
contentBuilder.WriteString("\n")
}
// Scroll position indicator.
if len(ss.filtered) > visH {
posStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
b.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
b.WriteString("\n")
posStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
contentBuilder.WriteString("\n")
}
}
v := tea.NewView(b.String())
// Footer separator
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// Footer with filter info
footerStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
// Apply the bordered container
content := contentBuilder.String()
borderedContent := containerStyle.Render(content)
v := tea.NewView(borderedContent)
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
return v
}
@@ -410,12 +476,12 @@ func removeByPath(sessions []session.SessionInfo, path string) []session.Session
// renderEntry renders a single session line with right-aligned metadata.
// Layout: [cursor 2] [message ...variable...] [padding] [count age] [cwd?]
func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCursor, isCurrent, isDeleting bool, width int) string {
theme := GetTheme()
theme := style.GetTheme()
// ── Cursor indicator (2 chars) ───────────────────────────────
cursorStr := " "
if isCursor {
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render(" ")
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
}
const cursorW = 2
@@ -443,45 +509,50 @@ func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCurs
msgW := utf8.RuneCountInString(displayText)
// ── Style the message ────────────────────────────────────────
msgStyle := lipgloss.NewStyle()
var msgStyle lipgloss.Style
switch {
case isDeleting:
msgStyle = msgStyle.Foreground(theme.Error)
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
case isCurrent:
msgStyle = msgStyle.Foreground(theme.Accent)
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
case info.Name != "":
msgStyle = msgStyle.Foreground(theme.Warning)
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
default:
msgStyle = msgStyle.Foreground(theme.Text)
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
}
if isCursor {
msgStyle = msgStyle.Bold(true)
}
styledMsg := msgStyle.Render(displayText)
// ── Style the right part ─────────────────────────────────────
rightColor := theme.Muted
if isDeleting {
rightColor = theme.Error
}
styledRight := lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
var styledRight string
// ── Assemble with spacing ────────────────────────────────────
spacing := max(width-cursorW-msgW-rightW, 1)
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
// ── Background highlight for selected row ────────────────────
// If selected, use inverted colors like PopupList
if isCursor {
// Use a subtle background highlight. We apply it by wrapping the
// full line in a style with a background color.
bgStyle := lipgloss.NewStyle().
Background(theme.Highlight).
Width(width)
line = bgStyle.Render(line)
// Inverted colors for selected item
msgStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
styledRight = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(rightColor).
Render(rightPart)
cursorStr = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Accent).
Render("> ")
} else {
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
}
styledMsg := msgStyle.Render(displayText)
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
return line
}
+31 -215
View File
@@ -2,25 +2,15 @@ package ui
import (
"fmt"
"regexp"
"strings"
"time"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/indaco/herald"
"github.com/mark3labs/kit/internal/app"
)
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
// reasoning content in. Used to strip these tags from streaming text content.
// The (?s) flag makes . match newlines.
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
// thinkTagOpen and thinkTagClose are the opening and closing think tag strings.
const (
thinkTagOpen = "<think>"
thinkTagClose = "</think>"
"github.com/mark3labs/kit/internal/ui/style"
)
// knightRiderFrames generates a KITT-style scanning animation where a bright
@@ -31,7 +21,7 @@ func knightRiderFrames() []string {
const numDots = 8
const dot = "▪"
theme := GetTheme()
theme := style.GetTheme()
bright := lipgloss.NewStyle().Foreground(theme.Primary)
med := lipgloss.NewStyle().Foreground(theme.Muted)
@@ -131,13 +121,13 @@ const (
// alongside streaming text until the step completes and Reset() is called.
//
// Tool calls, tool results, user messages, and other non-streaming content
// are printed immediately by the parent AppModel via tea.Println(). The
// StreamComponent only handles the live streaming text and spinner display.
// are added to the ScrollList by the parent AppModel. The StreamComponent
// only handles the live streaming text and spinner display.
//
// Lifecycle is managed entirely by the parent AppModel:
// - Parent calls Reset() between agent steps to clear state.
// - Parent emits completed responses above the BT region via tea.Println()
// then calls Reset(); StreamComponent never calls tea.Quit.
// - Content is displayed via StreamingMessageItem in the ScrollList.
// - StreamComponent never calls tea.Quit.
//
// Events handled:
// - app.SpinnerEvent{Show:true} → start spinner tick loop
@@ -196,23 +186,6 @@ type StreamComponent struct {
// ticks from a previous step can be discarded.
flushGeneration uint64
// renderCache holds the last rendered output string. Reused by View()
// between flush ticks to avoid redundant markdown re-parsing.
renderCache string
// renderDirty is true when committed content has changed since the
// last render. Set on flush tick; cleared after render() rebuilds
// the cache.
renderDirty bool
// scrollbackFlushedLines is the number of lines from the top of the
// rendered content that have already been emitted to the terminal
// scrollback buffer. On each flush, lines that overflow the allocated
// height and haven't been pushed yet are emitted via tea.Println so
// they appear in the terminal's real scrollback (scrollable with the
// terminal's own scroll mechanism).
scrollbackFlushedLines int
// thinkingVisible controls whether reasoning blocks are expanded or collapsed.
thinkingVisible bool
@@ -222,10 +195,6 @@ type StreamComponent struct {
// reasoningDuration holds the total reasoning time, frozen when streaming text begins.
reasoningDuration time.Duration
// inThinkTag tracks whether we're currently inside a section
// from models that wrap reasoning in XML-like tags (Qwen, DeepSeek).
inThinkTag bool
// renderer renders streaming assistant text.
renderer Renderer
@@ -272,9 +241,6 @@ func (s *StreamComponent) SetHeight(h int) {
}
if s.height != h {
s.height = h
// Invalidate cache — height clamp affects output.
s.renderCache = ""
s.renderDirty = true
}
}
@@ -293,59 +259,23 @@ func (s *StreamComponent) Reset() {
s.pendingReasoning.Reset()
s.flushPending = false
s.flushGeneration++
s.renderCache = ""
s.renderDirty = false
s.timestamp = time.Time{}
s.reasoningStartTime = time.Time{}
s.reasoningDuration = 0
s.scrollbackFlushedLines = 0
}
// ConsumeOverflow returns any lines from the rendered stream content that have
// overflowed the allocated height and have not yet been pushed to the terminal
// scrollback buffer. It advances the internal flushed-line pointer so
// subsequent calls only return newly overflowed lines.
//
// Returns "" when there is no overflow or height is unconstrained (0).
// The caller should emit the returned string via tea.Println so the content
// appears in the terminal's real scrollback (not just discarded).
// ConsumeOverflow is a no-op in alt screen mode. Overflow is handled by the
// ScrollList viewport. Retained to satisfy streamComponentIface.
func (s *StreamComponent) ConsumeOverflow() string {
if s.height <= 0 {
return ""
}
content := s.render()
if content == "" {
return ""
}
lines := strings.Split(content, "\n")
totalLines := len(lines)
// Number of lines that overflow the viewable height.
overflowLines := totalLines - s.height
if overflowLines <= 0 {
return ""
}
// How many overflow lines are new (not yet flushed to scrollback).
newOverflow := overflowLines - s.scrollbackFlushedLines
if newOverflow <= 0 {
return ""
}
// The new overflow is lines [s.scrollbackFlushedLines .. overflowLines).
start := s.scrollbackFlushedLines
end := overflowLines
s.scrollbackFlushedLines = overflowLines
return strings.Join(lines[start:end], "\n")
return ""
}
// GetRenderedContent returns the rendered assistant message from the accumulated
// streaming text. Returns empty string if no text has been accumulated. Used by
// the parent AppModel to flush content via tea.Println() before resetting.
// the parent AppModel to flush stream content before resetting.
//
// This commits any pending chunks first so the output includes all received
// content, not just what has been flushed by the tick.
//
// Lines already pushed to the terminal scrollback buffer via ConsumeOverflow
// are skipped so that callers do not re-emit content that is already visible
// in the terminal's real scrollback.
func (s *StreamComponent) GetRenderedContent() string {
// Commit any pending chunks so the final output is complete.
s.commitPending()
@@ -366,35 +296,19 @@ func (s *StreamComponent) GetRenderedContent() string {
if len(sections) == 0 {
return ""
}
fullContent := strings.Join(sections, "\n")
// Skip lines already emitted to the terminal scrollback via ConsumeOverflow
// so the caller doesn't re-print content that is already there.
if s.scrollbackFlushedLines > 0 {
lines := strings.Split(fullContent, "\n")
if s.scrollbackFlushedLines >= len(lines) {
return "" // everything already in scrollback
}
return strings.Join(lines[s.scrollbackFlushedLines:], "\n")
}
return fullContent
return strings.Join(sections, "\n")
}
// commitPending moves any pending chunks to the committed content builders.
// Called before reading content for scrollback output or on flush tick.
// Called before reading content for output or on flush tick.
func (s *StreamComponent) commitPending() {
if s.pendingStream.Len() > 0 {
// Strip ... tags that some models wrap reasoning in
cleanedText := thinkTagRegex.ReplaceAllString(s.pendingStream.String(), "")
s.streamContent.WriteString(cleanedText)
s.streamContent.WriteString(s.pendingStream.String())
s.pendingStream.Reset()
s.renderDirty = true
}
if s.pendingReasoning.Len() > 0 {
s.reasoningContent.WriteString(s.pendingReasoning.String())
s.pendingReasoning.Reset()
s.renderDirty = true
}
}
@@ -417,9 +331,6 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if s.renderer != nil {
s.renderer.SetWidth(s.width)
}
// Invalidate render cache — width change affects wrapping/styling.
s.renderCache = ""
s.renderDirty = true
case streamSpinnerTickMsg:
// Only continue the tick loop if this tick belongs to the current
@@ -472,6 +383,17 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return s, streamFlushTickCmd(s.flushGeneration)
}
case app.ReasoningCompleteEvent:
// Freeze reasoning duration when reasoning finishes (before text streaming starts).
if s.reasoningDuration == 0 && !s.reasoningStartTime.IsZero() {
s.reasoningDuration = time.Since(s.reasoningStartTime)
}
// Flush any remaining pending reasoning content.
if s.pendingReasoning.Len() > 0 {
s.reasoningContent.WriteString(s.pendingReasoning.String())
s.pendingReasoning.Reset()
}
case app.StreamChunkEvent:
s.phase = streamPhaseActive
if s.timestamp.IsZero() {
@@ -482,43 +404,9 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s.reasoningDuration = time.Since(s.reasoningStartTime)
}
// Handle models that wrap reasoning in tags (Qwen, DeepSeek)
// Filter out all content between and tags
content := msg.Content
// Check for opening tag
if strings.Contains(content, thinkTagOpen) {
parts := strings.SplitN(content, thinkTagOpen, 2)
// Content before the tag can be written
if !s.inThinkTag && parts[0] != "" {
s.pendingStream.WriteString(parts[0])
}
s.inThinkTag = true
// Content after the opening tag is reasoning - don't write it
if len(parts) > 1 && parts[1] != "" {
// Check if the same chunk contains the closing tag
if strings.Contains(parts[1], thinkTagClose) {
innerParts := strings.SplitN(parts[1], thinkTagClose, 2)
s.inThinkTag = false
// Content after closing tag can be written
if len(innerParts) > 1 && innerParts[1] != "" {
s.pendingStream.WriteString(innerParts[1])
}
}
}
} else if strings.Contains(content, thinkTagClose) {
// Closing tag found
parts := strings.SplitN(content, thinkTagClose, 2)
s.inThinkTag = false
// Content after closing tag can be written
if len(parts) > 1 && parts[1] != "" {
s.pendingStream.WriteString(parts[1])
}
} else if !s.inThinkTag {
// Normal content, not inside think tags
s.pendingStream.WriteString(content)
}
// else: inside think tag, don't write this content
// <think> tag filtering is handled at the agent layer — chunks here
// are already clean text.
s.pendingStream.WriteString(msg.Content)
if !s.flushPending && s.pendingStream.Len() > 0 {
s.flushPending = true
@@ -559,79 +447,10 @@ func (s *StreamComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return s, nil
}
// View implements tea.Model. Renders the current stream region content.
// View implements tea.Model. Returns an empty view since rendering is handled
// by StreamingMessageItem in the ScrollList. Retained to satisfy tea.Model.
func (s *StreamComponent) View() tea.View {
fullContent := s.render()
visibleContent := s.viewContent(fullContent)
v := tea.NewView(visibleContent)
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
return v
}
// --------------------------------------------------------------------------
// Internal rendering
// --------------------------------------------------------------------------
// render builds the full content string for the stream region. Uses a render
// cache to avoid redundant markdown re-parsing between flush ticks. The cache
// is invalidated when committed content changes (flush tick), terminal width
// changes, or height/thinking visibility changes.
func (s *StreamComponent) render() string {
if s.phase == streamPhaseIdle {
return ""
}
// Return cached render if committed content hasn't changed.
if !s.renderDirty {
return s.renderCache
}
var sections []string
// Render reasoning/thinking block above the main text if present.
if reasoning := s.reasoningContent.String(); reasoning != "" {
sections = append(sections, s.renderReasoningBlock(reasoning))
}
// Render streaming text only. The spinner is rendered in the status bar
// by the parent so it never changes the stream region height.
text := s.streamContent.String()
if text != "" {
sections = append(sections, s.renderStreamingText(text))
}
if len(sections) == 0 {
s.renderCache = ""
s.renderDirty = false
return ""
}
content := strings.Join(sections, "\n")
// Cache FULL content without height clamping.
// Height clamping is applied in View() for display only.
s.renderCache = content
s.renderDirty = false
return content
}
// viewContent returns the visible portion of content based on height constraint.
// This is called by View() to get the slice that fits in the terminal.
func (s *StreamComponent) viewContent(fullContent string) string {
if s.height > 0 && fullContent != "" {
lines := strings.Split(fullContent, "\n")
if len(lines) > s.height {
// Keep only the last h lines so the most recent output is visible.
lines = lines[len(lines)-s.height:]
return strings.Join(lines, "\n")
}
}
return fullContent
return tea.NewView("")
}
// renderReasoningBlock renders the reasoning/thinking content using blockquote.
@@ -692,9 +511,6 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
func (s *StreamComponent) SetThinkingVisible(visible bool) {
if s.thinkingVisible != visible {
s.thinkingVisible = visible
// Invalidate cache — thinking visibility affects rendered output.
s.renderCache = ""
s.renderDirty = true
}
}
@@ -1,4 +1,4 @@
package ui
package style
import (
"fmt"
@@ -294,3 +294,28 @@ func ApplyGradient(text string, colorA, colorB color.Color) string {
return result.String()
}
// KitBanner returns the KIT ASCII art title with KITT scanner lights,
// rendered with a KITT red gradient.
func KitBanner() string {
kittDark := lipgloss.Color("#8B0000")
kittBright := lipgloss.Color("#FF2200")
lines := []string{
" ██╗ ██╗ ██╗ ████████╗",
" ██║ ██╔╝ ██║ ╚══██╔══╝",
" █████╔╝ ██║ ██║",
" ██╔═██╗ ██║ ██║",
" ██║ ██╗ ██║ ██║",
" ╚═╝ ╚═╝ ╚═╝ ╚═╝",
"░░ ░░ ░░ ▒▒ ▒▒ ▓▓ ▓▓ ████ ▓▓ ▓▓ ▒▒ ▒▒ ░░ ░░ ░░",
}
var result strings.Builder
for i, line := range lines {
if i > 0 {
result.WriteString("\n")
}
result.WriteString(ApplyGradient(line, kittDark, kittBright))
}
return result.String()
}
@@ -1,4 +1,4 @@
package ui
package style
import (
"charm.land/lipgloss/v2"
@@ -85,10 +85,10 @@ func GetMarkdownTypography() *herald.Typography {
return ty
}
// toMarkdown renders markdown content using herald-md.
// ToMarkdown renders markdown content using herald-md.
// The width parameter is currently unused as herald handles wrapping
// based on terminal width internally.
func toMarkdown(content string, width int) string {
func ToMarkdown(content string, width int) string {
ty := GetMarkdownTypography()
rendered := heraldmd.Render(ty, []byte(content))
return rendered
@@ -1,4 +1,4 @@
package ui
package style
import (
"encoding/json"
@@ -11,6 +11,8 @@ import (
"strings"
"gopkg.in/yaml.v3"
"github.com/mark3labs/kit/internal/ui/prefs"
)
// ---------------------------------------------------------------------------
@@ -410,10 +412,10 @@ func initThemeRegistry() {
}
// 2. User themes from ~/.config/kit/themes/
scanThemesDir(userThemesDir())
scanThemesDir(UserThemesDir())
// 3. Project-local themes from .kit/themes/
scanThemesDir(projectThemesDir())
scanThemesDir(ProjectThemesDir())
sortRegistry()
}
@@ -461,7 +463,7 @@ func removeFromRegistry(name string) {
}
// userThemesDir returns ~/.config/kit/themes, creating it if needed.
func userThemesDir() string {
func UserThemesDir() string {
cfgDir, err := os.UserConfigDir()
if err != nil {
return ""
@@ -473,7 +475,7 @@ func userThemesDir() string {
// projectThemesDir returns .kit/themes/ relative to the working directory.
// Returns "" if the directory doesn't exist (does NOT create it).
func projectThemesDir() string {
func ProjectThemesDir() string {
dir := filepath.Join(".kit", "themes")
info, err := os.Stat(dir)
if err != nil || !info.IsDir() {
@@ -525,7 +527,7 @@ func ApplyTheme(name string) error {
return err
}
SetTheme(t)
_ = SaveThemePreference(name)
_ = prefs.SaveThemePreference(name)
return nil
}
@@ -1,4 +1,4 @@
package ui
package style
import (
"testing"
+2 -12
View File
@@ -83,17 +83,8 @@ func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (t *ToolApprovalInput) View() tea.View {
v := tea.NewView("")
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
if t.done {
v.Content = "we are done"
return v
return tea.NewView("we are done")
}
containerStyle := lipgloss.NewStyle()
@@ -145,6 +136,5 @@ func (t *ToolApprovalInput) View() tea.View {
}
view.WriteString(yesText + "/" + noText + "\n")
v.Content = containerStyle.Render(inputBoxStyle.Render(view.String()))
return v
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
}
+277 -32
View File
@@ -28,10 +28,10 @@ const (
maxLsLines = 20 // lines for Ls directory listings
)
// isShellTool reports if the tool name matches a shell-like tool (bash, grep, find, or
// isShellTool reports if the tool name matches a shell-like tool (bash or
// tools with "shell"/"command" in the name). Used by renderToolBody.
func isShellTool(toolName string) bool {
return toolName == "bash" || toolName == "grep" || toolName == "find" ||
return toolName == "bash" ||
strings.Contains(toolName, "shell") || strings.Contains(toolName, "command")
}
@@ -55,8 +55,16 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
if body := renderWriteBody(toolArgs, toolResult, width); body != "" {
return body
}
case toolName == "find":
if body := renderFindBody(toolResult, width); body != "" {
return body
}
case toolName == "grep":
if body := renderGrepBody(toolResult, width); body != "" {
return body
}
case isShellTool(toolName):
if body := renderBashBody(toolResult, width); body != "" {
if body := renderBashBody(toolArgs, toolResult, width); body != "" {
return body
}
case toolName == "subagent":
@@ -337,6 +345,148 @@ func renderDiffBlock(before, after string, startLine int, width int) string {
// Ls tool — simple list without gutter
// ---------------------------------------------------------------------------
// renderFindBody renders find output as a plain list with code background.
// Similar to ls but with results-specific caption.
func renderFindBody(toolResult string, width int) string {
content := strings.TrimSpace(toolResult)
if content == "" {
return ""
}
lines := strings.Split(content, "\n")
totalResults := len(lines)
// Truncate to maxLsLines for display
var hiddenCount int
if len(lines) > maxLsLines {
hiddenCount = len(lines) - maxLsLines
lines = lines[:maxLsLines]
}
const lineIndent = " "
codeWidth := max(width-len(lineIndent), 20)
theme := GetTheme()
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
var rendered []string
for _, line := range lines {
// Truncate before styling to prevent wrapping.
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
styled := codeStyle.Width(codeWidth).Render(line)
rendered = append(rendered, styled)
}
content = strings.Join(rendered, "\n")
// Build caption with results info
var captionParts []string
if totalResults == 1 {
captionParts = append(captionParts, "1 result")
} else {
captionParts = append(captionParts, fmt.Sprintf("%d results", totalResults))
}
if hiddenCount > 0 {
captionParts = append(captionParts, fmt.Sprintf("%d more", hiddenCount))
}
if len(captionParts) > 1 || hiddenCount > 0 {
ty := herald.New(herald.WithTheme(herald.Theme{
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
FigureCaptionPosition: herald.CaptionBottom,
}))
caption := strings.Join(captionParts, " • ")
result := ty.Figure(content, caption)
// Indent entire block (content + caption) to match other tools
const blockIndent = " "
resultLines := strings.Split(result, "\n")
for i, line := range resultLines {
resultLines[i] = blockIndent + line
}
return strings.Join(resultLines, "\n")
}
// Single result with no truncation - just return indented content
const blockIndent = " "
contentLines := strings.Split(content, "\n")
for i, line := range contentLines {
contentLines[i] = blockIndent + line
}
return strings.Join(contentLines, "\n")
}
// renderGrepBody renders grep output as a plain list with code background.
// Similar to find but with match-specific caption terminology.
func renderGrepBody(toolResult string, width int) string {
content := strings.TrimSpace(toolResult)
if content == "" {
return ""
}
lines := strings.Split(content, "\n")
totalMatches := len(lines)
// Truncate to maxLsLines for display
var hiddenCount int
if len(lines) > maxLsLines {
hiddenCount = len(lines) - maxLsLines
lines = lines[:maxLsLines]
}
const lineIndent = " "
codeWidth := max(width-len(lineIndent), 20)
theme := GetTheme()
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
var rendered []string
for _, line := range lines {
// Truncate before styling to prevent wrapping.
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
styled := codeStyle.Width(codeWidth).Render(line)
rendered = append(rendered, styled)
}
content = strings.Join(rendered, "\n")
// Build caption with match info
var captionParts []string
if totalMatches == 1 {
captionParts = append(captionParts, "1 match")
} else {
captionParts = append(captionParts, fmt.Sprintf("%d matches", totalMatches))
}
if hiddenCount > 0 {
captionParts = append(captionParts, fmt.Sprintf("%d more", hiddenCount))
}
if len(captionParts) > 1 || hiddenCount > 0 {
ty := herald.New(herald.WithTheme(herald.Theme{
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
FigureCaptionPosition: herald.CaptionBottom,
}))
caption := strings.Join(captionParts, " • ")
result := ty.Figure(content, caption)
// Indent entire block (content + caption) to match other tools
const blockIndent = " "
resultLines := strings.Split(result, "\n")
for i, line := range resultLines {
resultLines[i] = blockIndent + line
}
return strings.Join(resultLines, "\n")
}
// Single match with no truncation - just return indented content
const blockIndent = " "
contentLines := strings.Split(content, "\n")
for i, line := range contentLines {
contentLines[i] = blockIndent + line
}
return strings.Join(contentLines, "\n")
}
// renderLsBody renders ls output as a plain list with code background and no
// line-number gutter.
func renderLsBody(toolResult string, width int) string {
@@ -354,28 +504,47 @@ func renderLsBody(toolResult string, width int) string {
lines = lines[:maxLsLines]
}
const indent = " "
codeWidth := max(width-len(indent), 20)
const lineIndent = " "
codeWidth := max(width-len(lineIndent), 20)
theme := GetTheme()
codeStyle := lipgloss.NewStyle().Background(theme.CodeBg).PaddingLeft(1)
var result []string
var rendered []string
for _, line := range lines {
// Truncate before styling to prevent wrapping.
line = truncateLine(line, codeWidth-1) // account for PaddingLeft(1)
styled := codeStyle.Width(codeWidth).Render(line)
result = append(result, indent+styled)
rendered = append(rendered, styled)
}
content = strings.Join(rendered, "\n")
// Build caption with hidden entries info
if hiddenCount > 0 {
hint := fmt.Sprintf("...(%d more entries)", hiddenCount)
hintContent := codeStyle.Width(codeWidth).
Foreground(theme.Muted).Italic(true).Render(hint)
result = append(result, indent+hintContent)
ty := herald.New(herald.WithTheme(herald.Theme{
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
FigureCaptionPosition: herald.CaptionBottom,
}))
caption := fmt.Sprintf("%d more entries", hiddenCount)
result := ty.Figure(content, caption)
// Indent entire block (content + caption) to match other tools
const blockIndent = " "
resultLines := strings.Split(result, "\n")
for i, line := range resultLines {
resultLines[i] = blockIndent + line
}
return strings.Join(resultLines, "\n")
}
return strings.Join(result, "\n")
// No caption - just return indented content
const blockIndent = " "
contentLines := strings.Split(content, "\n")
for i, line := range contentLines {
contentLines[i] = blockIndent + line
}
return strings.Join(contentLines, "\n")
}
// ---------------------------------------------------------------------------
@@ -461,19 +630,50 @@ func renderReadBody(toolArgs, toolResult string, width int) string {
)
// Render the code block
result := ty.CodeBlock(codeContent, lang)
codeBlock := ty.CodeBlock(codeContent, lang)
// Add truncation hint if needed
// Herald's codeBlockWithLineNumbers() hardcodes PaddingTop(1) and
// PaddingBottom(1), adding invisible blank lines with background color
// above and below the code. These interfere with mouse selection
// (off-by-one) because the padding line looks blank but occupies a
// line index in the rendered item. Strip them since the Compose
// separator above and Figure caption below already provide spacing.
codeBlock = stripCodeBlockPadding(codeBlock)
// Parse total lines from footer if available (e.g., "[showing lines 1-100 of 407 total...]")
totalLines := totalCodeLines
for _, footer := range footerLines {
if matches := regexp.MustCompile(`of (\d+) total`).FindStringSubmatch(footer); len(matches) > 1 {
if t, _ := strconv.Atoi(matches[1]); t > totalLines {
totalLines = t
}
}
}
// Build caption with file metadata
var captionParts []string
if fileName != "" {
captionParts = append(captionParts, filepath.Base(fileName))
}
if len(codeLines) > 0 {
endLine := offset + len(codeLines) - 1
captionParts = append(captionParts, fmt.Sprintf("lines %d-%d of %d", offset, endLine, totalLines))
}
if codeHiddenCount > 0 {
hint := fmt.Sprintf("...(%d more lines)", codeHiddenCount)
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Italic(true).Render(hint)
nextOffset := offset + len(codeLines)
captionParts = append(captionParts, fmt.Sprintf("offset=%d to continue", nextOffset))
}
// Add any footer lines
if len(footerLines) > 0 {
footer := strings.Join(footerLines, "\n")
result += "\n" + lipgloss.NewStyle().Foreground(GetTheme().Muted).Render(footer)
caption := strings.Join(captionParts, " • ")
// Use Figure with caption below content (default behavior)
// Apply theme to ensure caption is positioned below
figTheme := herald.Theme{
FigureCaption: lipgloss.NewStyle().Foreground(GetTheme().Muted),
FigureCaptionPosition: herald.CaptionBottom,
}
tyFig := herald.New(herald.WithTheme(figTheme))
result := tyFig.Figure(codeBlock, caption)
// Indent entire block to match Write/Edit tools (2 spaces)
const blockIndent = " "
@@ -582,7 +782,7 @@ func renderWriteBlock(content, fileName string, width int) string {
// renderBashBody renders bash output with per-line background and stderr
// in error color.
func renderBashBody(toolResult string, width int) string {
func renderBashBody(toolArgs, toolResult string, width int) string {
if strings.TrimSpace(toolResult) == "" {
return ""
}
@@ -609,6 +809,7 @@ func renderBashBody(toolResult string, width int) string {
maxLineChars := lineWidth - 1
var rendered []string
exitCode := -1 // -1 means not found
inStderr := false
for _, line := range lines {
line = truncateLine(line, maxLineChars)
@@ -617,30 +818,55 @@ func renderBashBody(toolResult string, width int) string {
inStderr = true
continue
}
// Exit code line
// Exit code line - extract it for caption
if strings.HasPrefix(line, "Exit code:") {
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
rendered = append(rendered, lineIndent+styled)
continue
_, _ = fmt.Sscanf(line, "Exit code: %d", &exitCode)
continue // Don't render exit code inline, it goes in caption
}
if inStderr {
styled := stderrStyle.Width(width - len(lineIndent)).Render(line)
rendered = append(rendered, lineIndent+styled)
rendered = append(rendered, styled)
} else {
styled := outputStyle.Width(width - len(lineIndent)).Render(line)
rendered = append(rendered, lineIndent+styled)
rendered = append(rendered, styled)
}
}
// Build caption with status info
var captionParts []string
if hiddenCount > 0 {
truncMsg := fmt.Sprintf("...(%d more lines)", hiddenCount)
hint := outputStyle.Width(width - len(lineIndent)).
Foreground(theme.Muted).Italic(true).Render(truncMsg)
rendered = append(rendered, lineIndent+hint)
captionParts = append(captionParts, fmt.Sprintf("%d more lines", hiddenCount))
}
if exitCode >= 0 {
captionParts = append(captionParts, fmt.Sprintf("exit code %d", exitCode))
}
return strings.Join(rendered, "\n")
content := strings.Join(rendered, "\n")
if len(captionParts) > 0 {
ty := herald.New(herald.WithTheme(herald.Theme{
FigureCaption: lipgloss.NewStyle().Foreground(theme.Muted),
FigureCaptionPosition: herald.CaptionBottom,
}))
caption := strings.Join(captionParts, " • ")
result := ty.Figure(content, caption)
// Indent entire block (content + caption) to match other tools
const blockIndent = " "
lines := strings.Split(result, "\n")
for i, line := range lines {
lines[i] = blockIndent + line
}
return strings.Join(lines, "\n")
}
// No caption - just return indented content
const blockIndent = " "
contentLines := strings.Split(content, "\n")
for i, line := range contentLines {
contentLines[i] = blockIndent + line
}
return strings.Join(contentLines, "\n")
}
// ---------------------------------------------------------------------------
@@ -724,6 +950,25 @@ func padRight(s string, width int) string {
return s + strings.Repeat(" ", width-w)
}
// stripCodeBlockPadding removes the top and bottom padding lines that herald's
// codeBlockWithLineNumbers() hardcodes via PaddingTop(1)/PaddingBottom(1).
// These padding lines are blank lines with background color that look invisible
// but occupy line indices, causing mouse selection to be off by one row.
func stripCodeBlockPadding(block string) string {
lines := strings.Split(block, "\n")
if len(lines) < 3 {
return block
}
// The first and last lines are padding (blank with bg color).
// Strip them only if they contain no visible text.
first := xansi.Strip(lines[0])
last := xansi.Strip(lines[len(lines)-1])
if strings.TrimSpace(first) == "" && strings.TrimSpace(last) == "" {
return strings.Join(lines[1:len(lines)-1], "\n")
}
return block
}
// truncateLine truncates a line to maxWidth visual characters, adding "…"
// if truncated. This is ANSI-aware: escape codes are preserved and wide
// characters are measured correctly.
+177 -60
View File
@@ -10,6 +10,7 @@ import (
"charm.land/lipgloss/v2"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/ui/core"
)
// TreeFilterMode controls which entries are visible in the tree selector.
@@ -88,6 +89,28 @@ func NewTreeSelector(tm *session.TreeManager, width, height int) *TreeSelectorCo
return ts
}
// NewTreeSelectorForFork creates a tree selector for the /fork command.
// It shows only user messages (flat list) matching Pi's fork behavior.
func NewTreeSelectorForFork(tm *session.TreeManager, width, height int) *TreeSelectorComponent {
ts := &TreeSelectorComponent{
tm: tm,
filter: TreeFilterUserOnly,
leafID: tm.GetLeafID(),
width: width,
height: height,
active: true,
}
ts.rebuildFlatList()
// Position cursor at the last user message before the leaf.
for i := len(ts.flatNodes) - 1; i >= 0; i-- {
if ts.isUserMessage(ts.flatNodes[i].Entry) {
ts.cursor = i
break
}
}
return ts
}
// Init implements tea.Model.
func (ts *TreeSelectorComponent) Init() tea.Cmd {
return nil
@@ -103,12 +126,12 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyPressMsg:
switch {
case key.Matches(msg, key.NewBinding(key.WithKeys("up", "k"))):
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
if ts.cursor > 0 {
ts.cursor--
}
case key.Matches(msg, key.NewBinding(key.WithKeys("down", "j"))):
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
if ts.cursor < len(ts.flatNodes)-1 {
ts.cursor++
}
@@ -138,7 +161,7 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
ts.selectedID = ts.flatNodes[ts.cursor].ID
ts.active = false
return ts, func() tea.Msg {
return TreeNodeSelectedMsg{
return core.TreeNodeSelectedMsg{
ID: ts.selectedID,
Entry: ts.flatNodes[ts.cursor].Entry,
IsUser: ts.isUserMessage(ts.flatNodes[ts.cursor].Entry),
@@ -155,7 +178,7 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
ts.cancelled = true
ts.active = false
return ts, func() tea.Msg {
return TreeCancelledMsg{}
return core.TreeCancelledMsg{}
}
}
@@ -203,46 +226,92 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (ts *TreeSelectorComponent) View() tea.View {
theme := GetTheme()
// Full-screen bordered container - uses entire terminal width and height
maxWidth := ts.width - 2 // Small margin on each side
if maxWidth < 20 {
maxWidth = ts.width
}
maxHeight := ts.height - 2 // Small margin top/bottom to prevent overflow
if maxHeight < 10 {
maxHeight = ts.height
}
horizontalPadding := 1
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
// Container style with border - full width/height like a framed panel
containerStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Primary).
Background(theme.Background).
Padding(1, horizontalPadding).
Width(maxWidth).
Height(maxHeight)
// Header style with background highlight (like PopupList title)
headerStyle := lipgloss.NewStyle().
Bold(true).
Foreground(theme.Accent).
PaddingLeft(2)
Background(theme.Background)
// Help text style
helpStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
PaddingLeft(2)
Background(theme.Background)
var b strings.Builder
var contentBuilder strings.Builder
// Header.
b.WriteString(headerStyle.Render("Session Tree"))
b.WriteString("\n")
// Adapt help text to terminal width.
// Header row with title and help
headerRow := headerStyle.Render("Session Tree")
contentBuilder.WriteString(headerRow)
contentBuilder.WriteString("\n")
// Help text - adapt to terminal width
var helpText string
if ts.width >= 70 {
b.WriteString(helpStyle.Render("↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"))
helpText = "↑/↓: move ←/→: page enter: select esc: cancel ^O: cycle filter"
} else if ts.width >= 45 {
b.WriteString(helpStyle.Render("↑↓ move ↵ select esc cancel ^O filter"))
helpText = "↑↓ move ↵ select esc cancel ^O filter"
} else {
b.WriteString(helpStyle.Render("↑↓ ↵ esc ^O"))
helpText = "↑↓ ↵ esc ^O"
}
b.WriteString("\n")
contentBuilder.WriteString(helpStyle.Render(helpText))
contentBuilder.WriteString("\n")
// Search display (if active)
if ts.search != "" {
searchStyle := lipgloss.NewStyle().Foreground(theme.Info).PaddingLeft(2)
b.WriteString(searchStyle.Render(fmt.Sprintf("Search: %s", ts.search)))
b.WriteString("\n")
searchStyle := lipgloss.NewStyle().
Foreground(theme.Info).
Background(theme.Background)
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ts.search)))
contentBuilder.WriteString("\n")
}
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ts.width)))
b.WriteString("\n")
// Separator line - full width
sepWidth := innerWidth
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
// Tree content
if len(ts.flatNodes) == 0 {
emptyStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
b.WriteString(emptyStyle.Render("No entries in session"))
b.WriteString("\n")
emptyStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
contentBuilder.WriteString(emptyStyle.Render("No entries in session"))
contentBuilder.WriteString("\n")
} else {
// Compute visible window.
visH := ts.visibleHeight()
// Compute visible window based on inner container height
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
chromeLines := 5
if ts.search != "" {
chromeLines++
}
visH := max(innerHeight-chromeLines, 3)
startIdx := 0
if ts.cursor >= visH {
startIdx = ts.cursor - visH + 1
@@ -251,27 +320,33 @@ func (ts *TreeSelectorComponent) View() tea.View {
for i := startIdx; i < endIdx; i++ {
node := ts.flatNodes[i]
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID)
b.WriteString(line)
b.WriteString("\n")
line := ts.renderNode(node, i == ts.cursor, node.ID == ts.leafID, innerWidth)
contentBuilder.WriteString(line)
contentBuilder.WriteString("\n")
}
}
// Footer.
b.WriteString(lipgloss.NewStyle().Foreground(theme.Muted).Render(strings.Repeat("─", ts.width)))
b.WriteString("\n")
// Footer separator
contentBuilder.WriteString(
lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background).
Render(strings.Repeat("─", sepWidth)))
contentBuilder.WriteString("\n")
footerStyle := lipgloss.NewStyle().Foreground(theme.Muted).PaddingLeft(2)
// Footer with count and filter
footerStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Background(theme.Background)
footer := fmt.Sprintf("(%d/%d) [%s]", ts.cursor+1, len(ts.flatNodes), ts.filter)
b.WriteString(footerStyle.Render(footer))
contentBuilder.WriteString(footerStyle.Render(footer))
v := tea.NewView(b.String())
// Apply the bordered container - full width, no centering
content := contentBuilder.String()
borderedContent := containerStyle.Render(content)
v := tea.NewView(borderedContent)
v.AltScreen = true
v.MouseMode = tea.MouseModeCellMotion
v.ReportFocus = true
v.KeyboardEnhancements = tea.KeyboardEnhancements{
ReportEventTypes: true,
}
return v
}
@@ -402,21 +477,23 @@ func (ts *TreeSelectorComponent) passesFilter(node *session.TreeNode) bool {
}
}
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool) string {
func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool, innerWidth int) string {
theme := GetTheme()
maxWidth := max(ts.width-4, 10)
// Cursor indicator.
// Cursor indicator - use ">" for selected (like PopupList)
var cursor string
if isCursor {
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render(" ")
cursor = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
} else {
cursor = " "
}
// Role-colored content.
// Role-colored content with background support for selection
text := ts.entryDisplayText(node.Entry)
available := maxWidth - len(node.Prefix) - 10
// Calculate available width accounting for cursor, prefix, and markers
prefixLen := len(node.Prefix)
available := innerWidth - prefixLen - 4 // 4 for cursor and some padding
if available > 3 && len(text) > available {
trimLen := max(available-3, 1)
if trimLen < len(text) {
@@ -424,48 +501,88 @@ func (ts *TreeSelectorComponent) renderNode(node FlatNode, isCursor, isLeaf bool
}
}
var style lipgloss.Style
// Build the full line style
var lineStyle lipgloss.Style
var textStyle lipgloss.Style
// Base text color based on role
switch e := node.Entry.(type) {
case *session.MessageEntry:
switch e.Role {
case "user":
style = lipgloss.NewStyle().Foreground(theme.Accent)
textStyle = lipgloss.NewStyle().Foreground(theme.Accent)
case "assistant":
style = lipgloss.NewStyle().Foreground(theme.Success)
textStyle = lipgloss.NewStyle().Foreground(theme.Success)
default:
style = lipgloss.NewStyle().Foreground(theme.Muted)
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
}
case *session.BranchSummaryEntry:
style = lipgloss.NewStyle().Foreground(theme.Warning).Italic(true)
textStyle = lipgloss.NewStyle().Foreground(theme.Warning).Italic(true)
case *session.CompactionEntry:
style = lipgloss.NewStyle().Foreground(theme.Info).Italic(true)
textStyle = lipgloss.NewStyle().Foreground(theme.Info).Italic(true)
default:
style = lipgloss.NewStyle().Foreground(theme.Muted)
textStyle = lipgloss.NewStyle().Foreground(theme.Muted)
}
// Apply selection highlighting (like PopupList)
if isCursor {
style = style.Bold(true)
// Inverted colors for selected item - matches PopupList style
lineStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
textStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Background).
Bold(true)
}
content := style.Render(text)
// Render components
content := textStyle.Render(text)
// Label badge.
var labelBadge string
if node.Label != "" {
labelBadge = " " + lipgloss.NewStyle().Foreground(theme.Warning).Render("["+node.Label+"]")
labelStyle := lipgloss.NewStyle().Foreground(theme.Warning)
if isCursor {
labelStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Warning)
}
labelBadge = " " + labelStyle.Render("["+node.Label+"]")
}
// Active marker.
// Active marker - use Success color for better visibility
var activeMarker string
if isLeaf {
activeMarker = lipgloss.NewStyle().Foreground(theme.Accent).Bold(true).Render(" ← active")
markerStyle := lipgloss.NewStyle().Foreground(theme.Success).Bold(true)
if isCursor {
markerStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.Success).
Bold(true)
}
activeMarker = markerStyle.Render(" ← active")
}
// Prefix (tree lines).
prefixStyle := lipgloss.NewStyle().Foreground(theme.Muted)
// Prefix (tree lines) - use MutedBorder for subtler appearance
prefixStyle := lipgloss.NewStyle().Foreground(theme.MutedBorder)
if isCursor {
prefixStyle = lipgloss.NewStyle().
Background(theme.Primary).
Foreground(theme.MutedBorder)
}
renderedPrefix := prefixStyle.Render(node.Prefix)
return cursor + renderedPrefix + content + labelBadge + activeMarker
// Combine all parts
line := cursor + renderedPrefix + content + labelBadge + activeMarker
// If selected, apply the background to the entire line
if isCursor {
return lineStyle.Render(line)
}
return line
}
func (ts *TreeSelectorComponent) entryDisplayText(entry any) string {
+259
View File
@@ -0,0 +1,259 @@
// Package watcher provides a general-purpose file watcher that monitors
// directories for changes to files matching specified extensions. It uses
// fsnotify for kernel-level notifications with debouncing to coalesce
// rapid editor writes.
package watcher
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/fsnotify/fsnotify"
)
// ContentWatcher monitors directories for file changes matching a set of
// extensions and triggers a reload callback when changes are detected.
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
type ContentWatcher struct {
watcher *fsnotify.Watcher
onReload func()
extensions []string // e.g. [".md", ".txt"]
label string // for logging (e.g. "prompts", "skills")
debounce time.Duration
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
}
// Options configures a ContentWatcher.
type Options struct {
// Dirs are the directories to watch.
Dirs []string
// Extensions are the file extensions to watch for (e.g. ".md", ".txt").
// Include the leading dot.
Extensions []string
// OnReload is called when a matching file changes (after debouncing).
OnReload func()
// Label is a human-readable name for logging (e.g. "prompts", "skills").
Label string
// Debounce is the debounce duration. Defaults to 300ms if zero.
Debounce time.Duration
}
// New creates a ContentWatcher that monitors the given directories for
// file changes matching the specified extensions. When a change is detected
// (after debouncing), onReload is called. The watcher must be started with
// Start() and stopped with Close().
func New(opts Options) (*ContentWatcher, error) {
if len(opts.Dirs) == 0 {
return nil, fmt.Errorf("no directories to watch")
}
fsw, err := fsnotify.NewWatcher()
if err != nil {
return nil, fmt.Errorf("creating file watcher: %w", err)
}
for _, dir := range opts.Dirs {
if err := fsw.Add(dir); err != nil {
continue
}
// Also watch immediate subdirectories (for skill/SKILL.md pattern).
entries, err := os.ReadDir(dir)
if err != nil {
continue
}
for _, entry := range entries {
if entry.IsDir() {
subdir := filepath.Join(dir, entry.Name())
_ = fsw.Add(subdir)
}
}
}
debounce := opts.Debounce
if debounce == 0 {
debounce = 300 * time.Millisecond
}
return &ContentWatcher{
watcher: fsw,
onReload: opts.OnReload,
extensions: opts.Extensions,
label: opts.Label,
debounce: debounce,
done: make(chan struct{}),
}, nil
}
// Start begins watching for file changes. It blocks until the context
// is cancelled or Close() is called. Typically called in a goroutine.
func (w *ContentWatcher) Start(ctx context.Context) {
w.mu.Lock()
ctx, w.cancel = context.WithCancel(ctx)
w.mu.Unlock()
defer close(w.done)
var timer *time.Timer
var timerC <-chan time.Time
for {
select {
case <-ctx.Done():
if timer != nil {
timer.Stop()
}
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
// When a new subdirectory is created, start watching it so
// that files added inside (e.g. new-skill/SKILL.md) trigger
// reload events. Also schedule a reload in case the directory
// was created with matching files already inside.
if event.Op&fsnotify.Create != 0 {
if info, err := os.Stat(event.Name); err == nil && info.IsDir() {
if addErr := w.watcher.Add(event.Name); addErr == nil {
// Check if the new directory already contains matching files.
if w.dirContainsMatchingFiles(event.Name) {
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
}
}
continue
}
}
// Only care about files matching our extensions.
if !w.matchesExtension(event.Name) {
continue
}
// React to write, create, remove, rename events.
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
continue
}
// Debounce: reset timer on each event.
if timer != nil {
timer.Stop()
}
timer = time.NewTimer(w.debounce)
timerC = timer.C
case <-timerC:
timerC = nil
timer = nil
w.onReload()
case err, ok := <-w.watcher.Errors:
if !ok {
return
}
_ = err
}
}
}
// Close stops the watcher and releases resources.
func (w *ContentWatcher) Close() error {
w.mu.Lock()
cancel := w.cancel
w.mu.Unlock()
if cancel != nil {
cancel()
}
// Wait for the event loop to finish.
<-w.done
return w.watcher.Close()
}
// matchesExtension returns true if the file name ends with one of the
// watched extensions.
func (w *ContentWatcher) matchesExtension(name string) bool {
for _, ext := range w.extensions {
if strings.HasSuffix(name, ext) {
return true
}
}
return false
}
// dirContainsMatchingFiles returns true if the directory contains at least
// one file matching the watched extensions. Used to detect cases where a
// directory is created with files already inside (e.g. cp -r).
func (w *ContentWatcher) dirContainsMatchingFiles(dir string) bool {
entries, err := os.ReadDir(dir)
if err != nil {
return false
}
for _, entry := range entries {
if !entry.IsDir() && w.matchesExtension(entry.Name()) {
return true
}
}
return false
}
// CollectDirs returns the directories to watch for a given set of standard
// directories and extra paths. Directories are deduplicated by absolute path
// and verified to exist. For explicit file paths, the parent directory is
// watched instead.
func CollectDirs(standardDirs []string, extraPaths []string) []string {
var dirs []string
seen := make(map[string]bool)
add := func(dir string) {
abs, err := filepath.Abs(dir)
if err != nil {
return
}
if seen[abs] {
return
}
// Verify the directory exists.
info, err := os.Stat(abs)
if err != nil || !info.IsDir() {
return
}
seen[abs] = true
dirs = append(dirs, abs)
}
for _, d := range standardDirs {
add(d)
}
for _, p := range extraPaths {
info, err := os.Stat(p)
if err != nil {
continue
}
if info.IsDir() {
add(p)
} else {
// For explicit files, watch the parent directory.
add(filepath.Dir(p))
}
}
return dirs
}
+307
View File
@@ -0,0 +1,307 @@
package watcher
import (
"os"
"path/filepath"
"sync/atomic"
"testing"
"time"
)
func TestContentWatcher_ReloadsOnMatchingFile(t *testing.T) {
dir := t.TempDir()
// Write an initial file so the directory isn't empty.
initial := filepath.Join(dir, "existing.md")
if err := os.WriteFile(initial, []byte("# Hello"), 0644); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Wait for watcher to be ready.
time.Sleep(100 * time.Millisecond)
// Modify the file.
if err := os.WriteFile(initial, []byte("# Updated"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_IgnoresNonMatchingFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write a non-matching file.
if err := os.WriteFile(filepath.Join(dir, "readme.txt"), []byte("hello"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 0 {
t.Errorf("expected 0 reloads for non-matching file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_MultipleExtensions(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md", ".txt"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write a .txt file — should trigger.
if err := os.WriteFile(filepath.Join(dir, "notes.txt"), []byte("notes"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload for .txt file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_Debounces(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 100 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Rapid-fire writes — should debounce into 1 reload.
for i := range 5 {
if err := os.WriteFile(filepath.Join(dir, "test.md"), []byte("v"+string(rune('0'+i))), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(30 * time.Millisecond)
}
time.Sleep(300 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 debounced reload, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesSubdirectories(t *testing.T) {
dir := t.TempDir()
// Create a subdirectory (simulates skill-name/SKILL.md pattern).
subdir := filepath.Join(dir, "my-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Write to subdirectory.
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Skill"), 0644); err != nil {
t.Fatal(err)
}
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got != 1 {
t.Errorf("expected 1 reload for subdirectory file, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectory(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
// Wait for watcher to be ready.
time.Sleep(100 * time.Millisecond)
// Create a NEW subdirectory after the watcher started (the bug scenario).
subdir := filepath.Join(dir, "new-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
// Give fsnotify time to pick up the new directory.
time.Sleep(100 * time.Millisecond)
// Write a matching file inside the new subdirectory.
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# New Skill"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(200 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for file in new subdirectory, got %d", got)
}
_ = w.Close()
}
func TestContentWatcher_WatchesNewSubdirectoryWithExistingFiles(t *testing.T) {
dir := t.TempDir()
var reloadCount atomic.Int32
w, err := New(Options{
Dirs: []string{dir},
Extensions: []string{".md"},
OnReload: func() { reloadCount.Add(1) },
Label: "test",
Debounce: 50 * time.Millisecond,
})
if err != nil {
t.Fatal(err)
}
go w.Start(t.Context())
time.Sleep(100 * time.Millisecond)
// Create a subdirectory with a matching file already inside (simulates cp -r).
subdir := filepath.Join(dir, "copied-skill")
if err := os.MkdirAll(subdir, 0755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(subdir, "SKILL.md"), []byte("# Copied"), 0644); err != nil {
t.Fatal(err)
}
// Wait for debounce + processing.
time.Sleep(300 * time.Millisecond)
if got := reloadCount.Load(); got < 1 {
t.Errorf("expected at least 1 reload for copied subdirectory with files, got %d", got)
}
_ = w.Close()
}
func TestCollectDirs_Deduplicates(t *testing.T) {
dir := t.TempDir()
dirs := CollectDirs([]string{dir, dir}, nil)
if len(dirs) != 1 {
t.Errorf("expected 1 deduplicated dir, got %d", len(dirs))
}
}
func TestCollectDirs_FileParent(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "test.md")
if err := os.WriteFile(file, []byte("test"), 0644); err != nil {
t.Fatal(err)
}
dirs := CollectDirs(nil, []string{file})
if len(dirs) != 1 {
t.Fatalf("expected 1 dir, got %d", len(dirs))
}
abs, _ := filepath.Abs(dir)
if dirs[0] != abs {
t.Errorf("expected %s, got %s", abs, dirs[0])
}
}
func TestCollectDirs_SkipsNonexistent(t *testing.T) {
dirs := CollectDirs([]string{"/nonexistent/dir"}, nil)
if len(dirs) != 0 {
t.Errorf("expected 0 dirs for nonexistent path, got %d", len(dirs))
}
}
+29
View File
@@ -0,0 +1,29 @@
{
"name": "@mark3labs/kit",
"version": "0.0.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@mark3labs/kit",
"version": "0.0.0",
"cpu": [
"x64",
"arm64"
],
"hasInstallScript": true,
"license": "MIT",
"os": [
"darwin",
"linux",
"win32"
],
"bin": {
"kit": "bin/kit"
},
"engines": {
"node": ">=16"
}
}
}
}
+2
View File
@@ -202,6 +202,7 @@ func Init(api ext.API) {
footer := harness.Context().GetFooter()
if footer == nil {
t.Fatal("expected footer to be set")
return
}
if footer.Content.Text != "Status: OK" {
t.Errorf("expected footer text 'Status: OK', got %q", footer.Content.Text)
@@ -258,6 +259,7 @@ func Init(api ext.API) {
if result == nil {
t.Fatal("expected non-nil result")
return
}
if !result.Block {
+24 -2
View File
@@ -68,8 +68,12 @@ host, err := kit.New(ctx, &kit.Options{
NoSession: true, // Ephemeral mode
// Tool options
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
Tools: []kit.Tool{kit.NewBashTool()}, // Replace default tool set
ExtraTools: []kit.Tool{myTool}, // Add alongside defaults
DisableCoreTools: true, // Use no core tools (0 tools)
// Configuration
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
// Compaction
AutoCompact: true, // Auto-compact near context limit
@@ -172,6 +176,24 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
- `GetSessionID()` - Get session UUID
- `Close()` - Clean up resources
### Options
Key `Options` fields for SDK usage:
| Field | Description |
|-------|-------------|
| `Model` | Override model (e.g., "anthropic/claude-sonnet-4-5-20250929") |
| `SystemPrompt` | Override system prompt |
| `ConfigFile` | Load specific config file (empty = search defaults) |
| `SkipConfig` | Skip `.kit.yml` loading (defaults + env vars still apply) |
| `Tools` | Replace core tools with custom set |
| `ExtraTools` | Add tools alongside defaults |
| `DisableCoreTools` | Use no core tools (0 tools, for chat-only) |
| `NoSession` | Ephemeral mode (no session persistence) |
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `Debug` | Enable debug logging |
## Environment Variables
All CLI environment variables work with the SDK:
+231
View File
@@ -0,0 +1,231 @@
package kit
import (
"strings"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/session"
)
// treeManagerAdapter adapts TreeManager to SessionManager interface.
// This is unexported - users don't interact with it directly.
type treeManagerAdapter struct {
inner *session.TreeManager
}
// NewTreeManagerAdapter creates an adapter (exported for use in New function).
// This is used by the SDK when no custom SessionManager is provided.
func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
return &treeManagerAdapter{inner: tm}
}
// AppendMessage implements SessionManager.
func (a *treeManagerAdapter) AppendMessage(msg LLMMessage) (string, error) {
// LLMMessage is just an alias for fantasy.Message, so no conversion needed
return a.inner.AppendLLMMessage(msg)
}
// GetMessages implements SessionManager.
func (a *treeManagerAdapter) GetMessages() []LLMMessage {
// LLMMessage is just an alias for fantasy.Message
return a.inner.GetLLMMessages()
}
// BuildContext implements SessionManager.
func (a *treeManagerAdapter) BuildContext() ([]LLMMessage, string, string) {
msgs, provider, modelID := a.inner.BuildContext()
return msgs, provider, modelID
}
// Branch implements SessionManager.
func (a *treeManagerAdapter) Branch(entryID string) error {
return a.inner.Branch(entryID)
}
// GetCurrentBranch implements SessionManager.
func (a *treeManagerAdapter) GetCurrentBranch() []BranchEntry {
branch := a.inner.GetBranch("")
var result []BranchEntry
for _, entry := range branch {
be := a.convertEntry(entry)
if be != nil {
result = append(result, *be)
}
}
return result
}
// GetChildren implements SessionManager.
func (a *treeManagerAdapter) GetChildren(parentID string) []string {
return a.inner.GetChildren(parentID)
}
// GetEntry implements SessionManager.
func (a *treeManagerAdapter) GetEntry(entryID string) *BranchEntry {
entry := a.inner.GetEntry(entryID)
if entry == nil {
return nil
}
return a.convertEntry(entry)
}
// GetSessionID implements SessionManager.
func (a *treeManagerAdapter) GetSessionID() string {
return a.inner.GetSessionID()
}
// GetSessionName implements SessionManager.
func (a *treeManagerAdapter) GetSessionName() string {
return a.inner.GetSessionName()
}
// SetSessionName implements SessionManager.
func (a *treeManagerAdapter) SetSessionName(name string) error {
_, err := a.inner.AppendSessionInfo(name)
return err
}
// GetCreatedAt implements SessionManager.
func (a *treeManagerAdapter) GetCreatedAt() time.Time {
return a.inner.GetHeader().Timestamp
}
// IsPersisted implements SessionManager.
func (a *treeManagerAdapter) IsPersisted() bool {
return a.inner.IsPersisted()
}
// AppendCompaction implements SessionManager.
func (a *treeManagerAdapter) AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
return a.inner.AppendCompaction(summary, firstKeptEntryID,
tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
}
// GetLastCompaction implements SessionManager.
func (a *treeManagerAdapter) GetLastCompaction() *CompactionEntry {
c := a.inner.GetLastCompaction()
if c == nil {
return nil
}
return &CompactionEntry{
ID: c.ID,
Summary: c.Summary,
FirstKeptEntryID: c.FirstKeptEntryID,
TokensBefore: c.TokensBefore,
TokensAfter: c.TokensAfter,
MessagesRemoved: c.MessagesRemoved,
ReadFiles: c.ReadFiles,
ModifiedFiles: c.ModifiedFiles,
Timestamp: c.Timestamp,
}
}
// AppendExtensionData implements SessionManager.
func (a *treeManagerAdapter) AppendExtensionData(extType, data string) (string, error) {
return a.inner.AppendExtensionData(extType, data)
}
// GetExtensionData implements SessionManager.
func (a *treeManagerAdapter) GetExtensionData(extType string) []ExtensionDataEntry {
entries := a.inner.GetExtensionData(extType)
var result []ExtensionDataEntry
for _, e := range entries {
result = append(result, ExtensionDataEntry{
ID: e.ID,
ExtType: e.ExtType,
Data: e.Data,
Timestamp: e.Timestamp,
})
}
return result
}
// AppendModelChange implements SessionManager.
func (a *treeManagerAdapter) AppendModelChange(provider, modelID string) (string, error) {
return a.inner.AppendModelChange(provider, modelID)
}
// GetContextEntryIDs implements SessionManager.
func (a *treeManagerAdapter) GetContextEntryIDs() []string {
return a.inner.GetContextEntryIDs()
}
// Close implements SessionManager.
func (a *treeManagerAdapter) Close() error {
return a.inner.Close()
}
// Helper: Convert internal entry types to BranchEntry
func (a *treeManagerAdapter) convertEntry(entry any) *BranchEntry {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// Build content text from parts
var content strings.Builder
for _, part := range msg.Parts {
if textPart, ok := part.(TextContent); ok {
content.WriteString(textPart.Text)
}
}
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeMessage,
Role: string(msg.Role),
Content: content.String(),
Model: e.Model,
Provider: e.Provider,
Timestamp: e.Timestamp,
RawParts: msg.Parts,
}
case *session.BranchSummaryEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeBranchSummary,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ModelChangeEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeModelChange,
Content: "Model changed to " + e.Provider + "/" + e.ModelID,
Model: e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp,
}
case *session.CompactionEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeCompaction,
Content: e.Summary,
Timestamp: e.Timestamp,
}
case *session.ExtensionDataEntry:
return &BranchEntry{
ID: e.ID,
ParentID: e.ParentID,
Type: EntryTypeExtensionData,
Content: "Extension data: " + e.ExtType,
Timestamp: e.Timestamp,
}
default:
return nil
}
}
// convertKitMessagesToFantasy converts kit LLM messages to fantasy messages.
// Since LLMMessage is an alias for fantasy.Message, this is a no-op.
func convertKitMessagesToFantasy(msgs []LLMMessage) []fantasy.Message {
// LLMMessage is just an alias for fantasy.Message, so we can type convert
return msgs
}
+12 -12
View File
@@ -21,9 +21,9 @@ type ContextStats struct {
const defaultReserveTokens = 16384
// EstimateContextTokens returns the estimated token count of the current
// conversation based on tree session messages.
// conversation based on session messages.
func (m *Kit) EstimateContextTokens() int {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
return compaction.EstimateMessageTokens(messages)
}
@@ -42,8 +42,8 @@ func (m *Kit) ShouldCompact() bool {
reserveTokens = m.compactionOpts.ReserveTokens
}
messages := m.treeSession.GetLLMMessages()
return compaction.ShouldCompact(messages, info.Limit.Context, reserveTokens)
messages := m.session.GetMessages()
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
}
// GetContextStats returns current context usage statistics including
@@ -55,7 +55,7 @@ func (m *Kit) ShouldCompact() bool {
// because it includes system prompts, tool definitions, and other overhead
// that the heuristic cannot account for.
func (m *Kit) GetContextStats() ContextStats {
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
// Prefer the real API-reported input token count when available.
m.lastInputTokensMu.RLock()
@@ -114,7 +114,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
}
}
messages := m.treeSession.GetLLMMessages()
messages := m.session.GetMessages()
if len(messages) < 2 {
return nil, fmt.Errorf("cannot compact: need at least 2 messages")
}
@@ -145,7 +145,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Carry forward file tracking from previous compaction.
var prev *compaction.PreviousCompaction
if lastCompaction := m.treeSession.GetLastCompaction(); lastCompaction != nil {
if lastCompaction := m.session.GetLastCompaction(); lastCompaction != nil {
prev = &compaction.PreviousCompaction{
ReadFiles: lastCompaction.ReadFiles,
ModifiedFiles: lastCompaction.ModifiedFiles,
@@ -171,7 +171,7 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// Non-destructive: append a CompactionEntry to the session tree instead
// of clearing and rewriting messages.
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if result.CutPoint >= 0 && result.CutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[result.CutPoint]
@@ -188,9 +188,9 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// custom summary. It still determines the cut point and persists a
// CompactionEntry.
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
originalTokens := compaction.EstimateMessageTokens(messages)
originalTokens := compaction.EstimateMessageTokens(convertKitMessagesToFantasy(messages))
cutPoint := compaction.FindCutPoint(messages, opts.KeepRecentTokens)
cutPoint := compaction.FindCutPoint(convertKitMessagesToFantasy(messages), opts.KeepRecentTokens)
if cutPoint == 0 {
cutPoint = len(messages) - 1
if cutPoint < 1 {
@@ -198,7 +198,7 @@ func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts
}
}
entryIDs := m.treeSession.GetContextEntryIDs()
entryIDs := m.session.GetContextEntryIDs()
firstKeptEntryID := ""
if cutPoint >= 0 && cutPoint < len(entryIDs) {
firstKeptEntryID = entryIDs[cutPoint]
@@ -234,7 +234,7 @@ func (m *Kit) persistAndEmitCompaction(
originalTokens, compactedTokens, messagesRemoved int,
readFiles, modifiedFiles []string,
) error {
if _, err := m.treeSession.AppendCompaction(
if _, err := m.session.AppendCompaction(
summary,
firstKeptEntryID,
originalTokens,
+2
View File
@@ -48,6 +48,8 @@ func setSDKDefaults() {
viper.SetDefault("temperature", 0.7)
viper.SetDefault("top-p", 0.95)
viper.SetDefault("top-k", 40)
viper.SetDefault("frequency-penalty", 0.0)
viper.SetDefault("presence-penalty", 0.0)
viper.SetDefault("stream", true)
viper.SetDefault("thinking-level", "off")
viper.SetDefault("num-gpu-layers", -1)
+10
View File
@@ -39,6 +39,9 @@ const (
EventCompaction EventType = "compaction"
// EventReasoningDelta fires for each streaming reasoning/thinking chunk.
EventReasoningDelta EventType = "reasoning_delta"
// EventReasoningComplete fires when reasoning/thinking is finished,
// after the last reasoning token has been processed.
EventReasoningComplete EventType = "reasoning_complete"
// EventToolOutput fires when a tool produces streaming output chunks.
EventToolOutput EventType = "tool_output"
EventStepUsage EventType = "step_usage"
@@ -149,6 +152,13 @@ type ReasoningDeltaEvent struct {
// EventType implements Event.
func (e ReasoningDeltaEvent) EventType() EventType { return EventReasoningDelta }
// ReasoningCompleteEvent fires when reasoning/thinking is finished, after the
// last reasoning token has been processed.
type ReasoningCompleteEvent struct{}
// EventType implements Event.
func (e ReasoningCompleteEvent) EventType() EventType { return EventReasoningComplete }
// ToolOutputEvent fires when a tool produces streaming output chunks (e.g., bash output).
type ToolOutputEvent struct {
ToolCallID string
+3
View File
@@ -177,6 +177,7 @@ func TestEventTypes(t *testing.T) {
{ResponseEvent{}, EventResponse},
{CompactionEvent{}, EventCompaction},
{ReasoningDeltaEvent{}, EventReasoningDelta},
{ReasoningCompleteEvent{}, EventReasoningComplete},
{ToolOutputEvent{}, EventToolOutput},
{StepUsageEvent{}, EventStepUsage},
{SteerConsumedEvent{}, EventSteerConsumed},
@@ -224,6 +225,7 @@ func TestEventOrdering(t *testing.T) {
EventMessageStart,
EventMessageUpdate,
EventReasoningDelta,
EventReasoningComplete,
EventToolOutput,
EventToolCall,
EventToolExecutionStart,
@@ -242,6 +244,7 @@ func TestEventOrdering(t *testing.T) {
bus.emit(MessageStartEvent{})
bus.emit(MessageUpdateEvent{Chunk: "hello"})
bus.emit(ReasoningDeltaEvent{Delta: "thinking..."})
bus.emit(ReasoningCompleteEvent{})
bus.emit(ToolOutputEvent{ToolName: "bash", Chunk: "output"})
bus.emit(ToolCallEvent{ToolName: "bash"})
bus.emit(ToolExecutionStartEvent{ToolName: "bash"})
+34 -11
View File
@@ -227,28 +227,51 @@ func (e *extensionAPI) GetMessageRenderer(name string) *extensions.MessageRender
// Session data
func (e *extensionAPI) GetSessionMessages() []extensions.SessionMessage {
return iterBranchMessages(e.kit.treeSession, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if e.kit.session == nil {
return nil
}
// Try to use the legacy iterBranchMessages for backward compatibility
// with the default TreeManager adapter
if adapter, ok := e.kit.session.(*treeManagerAdapter); ok {
return iterBranchMessages(adapter.inner, func(me *session.MessageEntry, msg message.Message) extensions.SessionMessage {
return extensions.SessionMessage{
ID: me.ID,
Role: string(msg.Role),
Content: msg.Content(),
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
}
})
}
// For custom SessionManagers, use the public interface
branch := e.kit.session.GetCurrentBranch()
var result []extensions.SessionMessage
for _, entry := range branch {
if entry.Type == EntryTypeMessage {
result = append(result, extensions.SessionMessage{
ID: entry.ID,
Role: entry.Role,
Content: entry.Content,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
})
}
return result
}
func (e *extensionAPI) AppendEntry(extType, data string) (string, error) {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return "", fmt.Errorf("no session available")
}
return e.kit.treeSession.AppendExtensionData(extType, data)
return e.kit.session.AppendExtensionData(extType, data)
}
func (e *extensionAPI) GetEntries(extType string) []extensions.ExtensionEntry {
if e.kit.treeSession == nil {
if e.kit.session == nil {
return nil
}
entries := e.kit.treeSession.GetExtensionData(extType)
entries := e.kit.session.GetExtensionData(extType)
result := make([]extensions.ExtensionEntry, 0, len(entries))
for _, e := range entries {
result = append(result, extensions.ExtensionEntry{
+349 -136
View File
@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
@@ -11,7 +12,6 @@ import (
"time"
"charm.land/fantasy"
charmlog "github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/agent"
"github.com/mark3labs/kit/internal/config"
@@ -39,7 +39,7 @@ type ContextFile struct {
// agents, sessions, and model configurations.
type Kit struct {
agent *agent.Agent
treeSession *session.TreeManager
session SessionManager
modelString string
events *eventBus
autoCompact bool
@@ -48,6 +48,8 @@ type Kit struct {
skills []*skills.Skill
extRunner *extensions.Runner
bufferedLogger *tools.BufferedDebugLogger
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
opts *Options // stored for reload operations (skills, etc.)
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
@@ -80,8 +82,8 @@ type Kit struct {
// the running agent turn via the LLM library's PrepareStep. Created fresh for
// each generate() call and set to nil when idle. Protected by steerMu.
steerMu sync.Mutex
steerCh chan string
leftoverSteer []string // unconsumed steer messages from the last turn
steerCh chan agent.SteerMessage
leftoverSteer []agent.SteerMessage // unconsumed steer messages from the last turn
}
// Subscribe registers an EventListener that will be called for every lifecycle
@@ -112,15 +114,32 @@ func (m *Kit) GetLoadingMessage() string {
}
// GetLoadedServerNames returns the names of successfully loaded MCP servers.
// If MCP servers are still loading in the background, this returns only the
// servers that have completed loading so far.
func (m *Kit) GetLoadedServerNames() []string {
return m.agent.GetLoadedServerNames()
}
// GetMCPToolCount returns the number of tools loaded from external MCP servers.
// If MCP servers are still loading in the background, this returns the count
// of tools loaded so far (may be 0).
func (m *Kit) GetMCPToolCount() int {
return m.agent.GetMCPToolCount()
}
// WaitForMCPTools blocks until background MCP tool loading completes.
// Returns nil if no MCP servers are configured or if loading succeeded.
// Returns the loading error if all servers failed. Safe to call multiple times.
func (m *Kit) WaitForMCPTools() error {
return m.agent.WaitForMCPTools()
}
// MCPToolsReady returns true if MCP tool loading has completed (or was never
// started). This is a non-blocking check useful for UI status display.
func (m *Kit) MCPToolsReady() bool {
return m.agent.MCPToolsReady()
}
// GetExtensionToolCount returns the number of tools registered by extensions.
func (m *Kit) GetExtensionToolCount() int {
return m.agent.GetExtensionToolCount()
@@ -153,27 +172,39 @@ type StructuredMessage struct {
// flattens all content to a single text string, this preserves tool calls,
// tool results, reasoning blocks, and finish markers as distinct typed parts.
func (m *Kit) GetStructuredMessages() []StructuredMessage {
return iterBranchMessages(m.treeSession, func(me *session.MessageEntry, msg message.Message) StructuredMessage {
return StructuredMessage{
ID: me.ID,
ParentID: me.ParentID,
Role: msg.Role,
Parts: msg.Parts,
Model: msg.Model,
Provider: msg.Provider,
Timestamp: me.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
if m.session == nil {
return nil
}
branch := m.session.GetCurrentBranch()
var results []StructuredMessage
for _, entry := range branch {
if entry.Type != EntryTypeMessage {
continue
}
})
results = append(results, StructuredMessage{
ID: entry.ID,
ParentID: entry.ParentID,
Role: MessageRole(entry.Role),
Parts: entry.RawParts,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format("2006-01-02T15:04:05Z07:00"),
})
}
return results
}
// iterBranchMessages iterates over the current branch's MessageEntry items,
// converting each to a message.Message and calling fn to build the result.
// Returns nil if there is no tree session. Skips entries that are not
// Returns nil if there is no session. Skips entries that are not
// MessageEntry or that fail conversion.
// Deprecated: Use SessionManager.GetCurrentBranch() directly.
func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.MessageEntry, message.Message) T) []T {
if tm == nil {
return nil
}
branch := tm.GetBranch("")
var results []T
for _, entry := range branch {
@@ -224,6 +255,10 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
config.TopP = &topP
topK := int32(viper.GetInt("top-k"))
config.TopK = &topK
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
config.FrequencyPenalty = &frequencyPenalty
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
config.PresencePenalty = &presencePenalty
if err := m.agent.SetModel(ctx, config); err != nil {
return err
@@ -268,8 +303,8 @@ func (m *Kit) GetAvailableModels() []extensions.ModelInfoEntry {
}
// ReloadExtensions hot-reloads all extensions from disk. Event handlers,
// commands, renderers, and shortcuts update immediately. Extension-defined
// tools are NOT updated (they are baked into the agent at creation time).
// commands, renderers, shortcuts, and extension-defined tools all update
// immediately.
func (m *Kit) ReloadExtensions() error {
if m.extRunner == nil {
return fmt.Errorf("no extensions loaded")
@@ -290,6 +325,12 @@ func (m *Kit) ReloadExtensions() error {
// Swap extensions on the runner (clears dynamic state).
m.extRunner.Reload(loaded)
// Update extension tools on the agent so the LLM sees changes.
if m.agent != nil {
extTools := extensions.ExtensionToolsAsFantasy(m.extRunner.RegisteredTools(), m.extRunner)
m.agent.SetExtraTools(extTools)
}
// Re-set context and emit SessionStart.
ctx := m.extRunner.GetContext()
m.extRunner.SetContext(ctx)
@@ -416,6 +457,17 @@ type Options struct {
Tools []Tool // Custom tool set. If empty, AllTools() is used.
ExtraTools []Tool // Additional tools added alongside core/MCP/extension tools.
// SkipConfig, when true, skips loading .kit.yml configuration files.
// Viper defaults (setSDKDefaults) and environment variables (KIT_*)
// are still applied. Use this for fully programmatic configuration.
SkipConfig bool
// DisableCoreTools, when true, prevents loading any core tools.
// Use with Tools or ExtraTools to provide only custom tools.
// If both DisableCoreTools is true and Tools is empty, the agent
// will have no tools (useful for simple chat completions).
DisableCoreTools bool
// Session configuration
SessionDir string // Base directory for session discovery (default: cwd)
SessionPath string // Open a specific session file by path
@@ -433,8 +485,32 @@ type Options struct {
// Debug enables debug logging for the SDK.
Debug bool
// MCPAuthHandler handles OAuth authorization for remote MCP servers.
// When set, remote transports (streamable HTTP, SSE) are configured with
// OAuth support. If the server returns a 401, the handler is invoked to
// let the user authorize via browser.
//
// If nil, a [DefaultMCPAuthHandler] is created automatically — opening the
// system browser and listening on a local callback server.
//
// Set to a custom implementation to control the authorization UX (e.g.
// display a URL in a custom UI, redirect to a web app, etc.).
MCPAuthHandler MCPAuthHandler
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading during Kit initialization. The callback receives the server name,
// tool count, and any error. Called from a background goroutine; safe to
// call app.NotifyMCPServerLoaded() from within the callback to display
// real-time progress in the TUI.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// CLI is optional CLI-specific configuration. SDK users leave this nil.
CLI *CLIOptions
// SessionManager allows custom session storage backends.
// If nil (default), Kit uses the built-in file-based TreeManager.
// When provided, SessionPath, Continue, and NoSession options are ignored.
SessionManager SessionManager
}
// CLIOptions holds fields only relevant to the CLI binary. SDK users should
@@ -499,85 +575,126 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
opts = &Options{}
}
viperInitMu.Lock()
defer viperInitMu.Unlock()
// All viper writes (SetSDKDefaults, InitConfig, Set calls, system-prompt
// composition) happen under viperInitMu. We also call BuildProviderConfig
// here — it's fast (just reads) — so we can capture the full config
// snapshot before releasing the lock. The expensive work (MCP loading,
// provider creation, session init) then runs outside the lock, allowing
// parallel subagent spawns to proceed concurrently.
var (
providerConfig *models.ProviderConfig
modelString string
cwd string
contextFiles []*ContextFile
loadedSkills []*Skill
mcpConfig *config.Config
debug bool
noExtensions bool
maxSteps int
streaming bool
)
// Set CLI-equivalent defaults for viper. When used as an SDK (without
// cobra), these defaults are not registered via flag bindings.
setSDKDefaults()
if err := func() error {
viperInitMu.Lock()
defer viperInitMu.Unlock()
// Initialize config (loads config files and env vars).
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
// Check if model is already set, which indicates config was loaded.
if viper.GetString("model") == "" {
if err := InitConfig(opts.ConfigFile, false); err != nil {
return nil, fmt.Errorf("failed to initialize config: %w", err)
}
}
// Set CLI-equivalent defaults for viper. When used as an SDK (without
// cobra), these defaults are not registered via flag bindings.
setSDKDefaults()
// Handle CLI debug mode.
if opts.Debug {
viper.Set("debug", true)
}
// Override viper settings with options.
if opts.Model != "" {
viper.Set("model", opts.Model)
}
if opts.SystemPrompt != "" {
viper.Set("system-prompt", opts.SystemPrompt)
}
if opts.MaxSteps > 0 {
viper.Set("max-steps", opts.MaxSteps)
}
viper.Set("stream", opts.Streaming)
// Resolve working directory for context/skill discovery.
cwd := opts.SessionDir
if cwd == "" {
cwd, _ = os.Getwd()
}
// Load context files (AGENTS.md) from the project root.
contextFiles := loadContextFiles(cwd)
// Load skills — either from explicit paths or via auto-discovery.
loadedSkills, err := loadSkills(opts)
if err != nil {
return nil, fmt.Errorf("failed to load skills: %w", err)
}
// Always compose the system prompt with runtime context: base prompt +
// AGENTS.md context + skills metadata + date/cwd.
{
basePrompt := viper.GetString("system-prompt")
pb := skills.NewPromptBuilder(basePrompt)
// Inject AGENTS.md content as project context.
for _, cf := range contextFiles {
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
// Initialize config (loads config files and env vars).
// Only initialize if not already done (e.g., by CLI's cobra.OnInitialize).
// Check if model is already set, which indicates config was loaded.
// SkipConfig bypasses .kit.yml file loading (viper defaults and env vars still apply).
if !opts.SkipConfig && viper.GetString("model") == "" {
if err := InitConfig(opts.ConfigFile, false); err != nil {
return fmt.Errorf("failed to initialize config: %w", err)
}
}
// Inject skills metadata (name + description + location).
if len(loadedSkills) > 0 {
pb.WithSkills(loadedSkills)
// Handle CLI debug mode.
if opts.Debug {
viper.Set("debug", true)
}
// Append current date/time and working directory.
pb.WithSection("", fmt.Sprintf(
"Current date and time: %s\nCurrent working directory: %s",
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
))
// Override viper settings with options.
if opts.Model != "" {
viper.Set("model", opts.Model)
}
if opts.SystemPrompt != "" {
viper.Set("system-prompt", opts.SystemPrompt)
}
if opts.MaxSteps > 0 {
viper.Set("max-steps", opts.MaxSteps)
}
viper.Set("stream", opts.Streaming)
viper.Set("system-prompt", pb.Build())
// Resolve working directory for context/skill discovery.
cwd = opts.SessionDir
if cwd == "" {
cwd, _ = os.Getwd()
}
// Load context files (AGENTS.md) from the project root.
contextFiles = loadContextFiles(cwd)
// Load skills — either from explicit paths or via auto-discovery.
var err error
loadedSkills, err = loadSkills(opts)
if err != nil {
return fmt.Errorf("failed to load skills: %w", err)
}
// Always compose the system prompt with runtime context: base prompt +
// AGENTS.md context + skills metadata + date/cwd.
{
basePrompt := viper.GetString("system-prompt")
pb := skills.NewPromptBuilder(basePrompt)
// Inject AGENTS.md content as project context.
for _, cf := range contextFiles {
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
}
// Inject skills metadata (name + description + location).
if len(loadedSkills) > 0 {
pb.WithSkills(loadedSkills)
}
// Append current date/time and working directory.
pb.WithSection("", fmt.Sprintf(
"Current date and time: %s\nCurrent working directory: %s",
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
))
viper.Set("system-prompt", pb.Build())
}
// Snapshot all viper-derived values now, while the lock is held.
// BuildProviderConfig is fast (pure reads), so we do it here.
var pcErr error
providerConfig, _, pcErr = kitsetup.BuildProviderConfig()
if pcErr != nil {
return fmt.Errorf("failed to build provider config: %w", pcErr)
}
modelString = viper.GetString("model")
debug = viper.GetBool("debug")
noExtensions = viper.GetBool("no-extensions")
maxSteps = viper.GetInt("max-steps")
streaming = viper.GetBool("stream")
return nil
}(); err != nil {
return nil, err
}
// ---- viperInitMu released — heavy I/O below runs concurrently ----
// Load MCP configuration. Use pre-loaded config if provided via CLI options.
var mcpConfig *config.Config
if opts.CLI != nil {
if opts.CLI != nil && opts.CLI.MCPConfig != nil {
mcpConfig = opts.CLI.MCPConfig
}
if mcpConfig == nil {
var err error
mcpConfig, err = config.LoadAndValidateConfig()
if err != nil {
return nil, fmt.Errorf("failed to load MCP config: %w", err)
@@ -595,13 +712,39 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
beforeCompact := newHookRegistry[BeforeCompactHook, BeforeCompactResult]()
// Build agent setup options, pulling CLI-specific fields when available.
// Pass the pre-built ProviderConfig and scalar viper snapshots so
// SetupAgent doesn't need to re-read viper (which would require the lock).
setupOpts := kitsetup.AgentSetupOptions{
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
MCPConfig: mcpConfig,
Quiet: opts.Quiet,
CoreTools: opts.Tools,
DisableCoreTools: opts.DisableCoreTools,
ExtraTools: opts.ExtraTools,
ToolWrapper: hookToolWrapper(beforeToolCall, afterToolResult),
ProviderConfig: providerConfig,
Debug: debug,
NoExtensions: noExtensions,
MaxSteps: maxSteps,
StreamingEnabled: streaming,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
}
// Set up OAuth handler for remote MCP servers.
// The SDK MCPAuthHandler interface is structurally identical to
// tools.MCPAuthHandler, so any implementation satisfies both.
if opts.MCPAuthHandler != nil {
setupOpts.AuthHandler = opts.MCPAuthHandler
} else {
// Create a default handler that opens the system browser.
defaultHandler, authErr := NewDefaultMCPAuthHandler()
if authErr != nil {
// Non-fatal: OAuth just won't be available for remote servers.
log.Printf("WARN Failed to create OAuth handler; remote MCP servers requiring auth will fail: %v", authErr)
} else {
setupOpts.AuthHandler = defaultHandler
}
}
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
@@ -614,17 +757,26 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
return nil, err
}
// Initialize tree session.
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
// Initialize session manager.
var sessionManager SessionManager
if opts.SessionManager != nil {
// Use custom session manager provided by user.
sessionManager = opts.SessionManager
} else {
// DEFAULT: Use built-in TreeManager (existing behavior).
treeSession, err := InitTreeSession(opts)
if err != nil {
_ = agentResult.Agent.Close()
return nil, fmt.Errorf("failed to initialize session: %w", err)
}
// Wrap TreeManager in adapter to satisfy SessionManager interface.
sessionManager = NewTreeManagerAdapter(treeSession)
}
k := &Kit{
agent: agentResult.Agent,
treeSession: treeSession,
modelString: viper.GetString("model"),
session: sessionManager,
modelString: modelString,
events: newEventBus(),
autoCompact: opts.AutoCompact,
compactionOpts: opts.CompactionOptions,
@@ -632,6 +784,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
skills: loadedSkills,
extRunner: agentResult.ExtRunner,
bufferedLogger: agentResult.BufferedLogger,
authHandler: setupOpts.AuthHandler,
opts: opts,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
@@ -904,6 +1058,16 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
if timeout == 0 {
timeout = 5 * time.Minute
}
// Pre-flight check: if the incoming context is already dead, don't
// waste time attempting init. This catches the case where the parent
// generation loop's context was cancelled (e.g. user ESC, step cancel)
// between when the LLM requested the subagent tool and when this code
// runs. We replace it with a fresh context carrying only the timeout,
// since the subagent should be independently bounded.
if ctx.Err() != nil {
ctx = context.Background()
}
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
@@ -920,6 +1084,17 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
}
}
// Early validation: check model format and provider before doing any
// expensive work (MCP init, system prompt composition, etc.). This
// gives the calling agent immediate feedback it can act on — e.g.
// correcting a typo — instead of waiting for a full Kit.New() cycle
// that silently falls back to the parent model.
if model != m.modelString {
if err := models.GetGlobalRegistry().ValidateModelString(model); err != nil {
return nil, fmt.Errorf("invalid subagent model %q: %w", model, err)
}
}
// Default system prompt.
systemPrompt := cfg.SystemPrompt
if systemPrompt == "" {
@@ -932,9 +1107,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
tools = SubagentTools()
}
// Create child Kit instance. If the requested model fails (bad name,
// unsupported provider, etc.), fall back to the parent's model so the
// agent gets a useful error message instead of a hard failure.
// Create child Kit instance.
childOpts := &Options{
Model: model,
SystemPrompt: systemPrompt,
@@ -943,20 +1116,8 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
Quiet: true,
}
child, err := New(ctx, childOpts)
if err != nil && model != m.modelString {
// Model-specific failure — retry with parent's model.
childOpts.Model = m.modelString
child, err = New(ctx, childOpts)
if err != nil {
return nil, fmt.Errorf("failed to create subagent: %w", err)
}
// Prepend a note so the agent knows which model is actually running.
cfg.Prompt = fmt.Sprintf(
"[Note: requested model %q was not available, using %s instead.]\n\n%s",
model, m.modelString, cfg.Prompt,
)
} else if err != nil {
return nil, fmt.Errorf("failed to create subagent: %w", err)
if err != nil {
return &SubagentResult{Elapsed: time.Since(start)}, fmt.Errorf("failed to create subagent: %w", err)
}
defer func() { _ = child.Close() }()
@@ -970,7 +1131,7 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
elapsed := time.Since(start)
if err != nil {
return nil, err
return &SubagentResult{Elapsed: elapsed}, err
}
subResult := &SubagentResult{
@@ -996,14 +1157,14 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.GenerateWithLoopResult, error) {
// Create a per-turn steer channel and attach it to the context so the
// agent's PrepareStep can inject steering messages between steps.
steerCh := make(chan string, 16)
steerCh := make(chan agent.SteerMessage, 16)
m.steerMu.Lock()
m.steerCh = steerCh
m.steerMu.Unlock()
defer func() {
// Drain any unconsumed steer messages before nilling the channel.
// These are stored in leftoverSteer so DrainSteer() can return them.
var leftover []string
var leftover []agent.SteerMessage
for {
select {
case msg := <-steerCh:
@@ -1093,12 +1254,52 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
func(content string) {
m.events.emit(ToolCallContentEvent{Content: content})
},
func(chunk string) {
m.events.emit(MessageUpdateEvent{Chunk: chunk})
},
// <think> tag filtering: models like Qwen/DeepSeek wrap reasoning inside
// <think>...</think> tags in the regular text stream. We intercept those
// spans here and re-route them as ReasoningDeltaEvent/ReasoningCompleteEvent
// so callers always receive clean, tag-free text and structured reasoning.
func() func(chunk string) {
const (
thinkOpen = "<think>"
thinkClose = "</think>"
)
var inThinkTag bool
return func(chunk string) {
remaining := chunk
for remaining != "" {
if inThinkTag {
i := strings.Index(remaining, thinkClose)
if i == -1 {
m.events.emit(ReasoningDeltaEvent{Delta: remaining})
return
}
if i > 0 {
m.events.emit(ReasoningDeltaEvent{Delta: remaining[:i]})
}
inThinkTag = false
m.events.emit(ReasoningCompleteEvent{})
remaining = remaining[i+len(thinkClose):]
} else {
i := strings.Index(remaining, thinkOpen)
if i == -1 {
m.events.emit(MessageUpdateEvent{Chunk: remaining})
return
}
if i > 0 {
m.events.emit(MessageUpdateEvent{Chunk: remaining[:i]})
}
inThinkTag = true
remaining = remaining[i+len(thinkOpen):]
}
}
}
}(),
func(delta string) {
m.events.emit(ReasoningDeltaEvent{Delta: delta})
},
func() {
m.events.emit(ReasoningCompleteEvent{})
},
func(toolCallID, toolName, chunk string, isStderr bool) {
// Emit tool output chunk event for streaming bash output
m.events.emit(ToolOutputEvent{
@@ -1111,11 +1312,8 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
// Emit step usage event for real-time cost tracking
if viper.GetBool("debug") {
charmlog.Debug("Kit.generate emitting StepUsageEvent",
"input", inputTokens,
"output", outputTokens,
"cacheRead", cacheReadTokens,
"cacheCreate", cacheCreationTokens,
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
)
}
m.events.emit(StepUsageEvent{
@@ -1182,9 +1380,9 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
}
}
// Persist pre-generation messages to tree session.
// Persist pre-generation messages to session.
for _, msg := range preMessages {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
// Auto-compact if enabled and conversation is near the context limit.
@@ -1192,8 +1390,8 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
_, _ = m.compactInternal(ctx, m.compactionOpts, "", true) // best-effort, automatic
}
// Build context from the tree so only the current branch is sent.
messages := m.treeSession.GetLLMMessages()
// Build context from the session so only the current branch is sent.
messages, _, _ := m.session.BuildContext()
// Run ContextPrepare hooks — extensions can filter, reorder, or inject messages.
if hookResult := m.contextPrepare.run(ContextPrepareHook{Messages: messages}); hookResult != nil && hookResult.Messages != nil {
@@ -1216,7 +1414,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
// (pending) message or tool call is discarded.
if result != nil && len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
}
m.events.emit(TurnEndEvent{Error: err})
@@ -1232,7 +1430,7 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
// GetContextStats() see up-to-date token counts.
if len(result.ConversationMessages) > sentCount {
for _, msg := range result.ConversationMessages[sentCount:] {
_, _ = m.treeSession.AppendLLMMessage(msg)
_, _ = m.session.AppendMessage(msg)
}
}
@@ -1314,7 +1512,7 @@ func (m *Kit) Steer(ctx context.Context, instruction string) (string, error) {
// Returns an error if there are no previous messages in the session.
func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
// Verify there is conversation history to follow up on.
if len(m.treeSession.GetLLMMessages()) == 0 {
if len(m.session.GetMessages()) == 0 {
return "", fmt.Errorf("cannot follow up: no previous messages")
}
@@ -1344,6 +1542,13 @@ func (m *Kit) FollowUp(ctx context.Context, text string) (string, error) {
// This is the preferred way to redirect an agent mid-turn without cancelling
// in-progress tool execution.
func (m *Kit) InjectSteer(message string) {
m.InjectSteerWithFiles(message, nil)
}
// InjectSteerWithFiles sends a steering message with optional file attachments
// (e.g. pasted images) into the currently active agent turn. Behaves like
// InjectSteer but includes file parts in the injected user message.
func (m *Kit) InjectSteerWithFiles(message string, files []LLMFilePart) {
m.steerMu.Lock()
ch := m.steerCh
m.steerMu.Unlock()
@@ -1351,7 +1556,7 @@ func (m *Kit) InjectSteer(message string) {
return
}
select {
case ch <- message:
case ch <- agent.SteerMessage{Text: message, Files: files}:
default:
// Channel full — extremely unlikely with buffer of 16, but don't block.
}
@@ -1369,7 +1574,7 @@ func (m *Kit) IsGenerating() bool {
// a turn completes so the app layer can process any steer messages that
// arrived after the last PrepareStep fired (e.g. during a text-only response
// with no tool calls, or after the agent finished its last step).
func (m *Kit) DrainSteer() []string {
func (m *Kit) DrainSteer() []agent.SteerMessage {
m.steerMu.Lock()
defer m.steerMu.Unlock()
@@ -1382,7 +1587,7 @@ func (m *Kit) DrainSteer() []string {
// If a turn is still active, drain from the live channel.
if m.steerCh != nil {
var msgs []string
var msgs []agent.SteerMessage
for {
select {
case msg := <-m.steerCh:
@@ -1463,10 +1668,12 @@ func (m *Kit) PromptResultWithMessages(ctx context.Context, messages []string) (
return m.runTurn(ctx, promptLabel, messages[len(messages)-1], preMessages)
}
// ClearSession resets the tree session's leaf pointer to the root, starting
// ClearSession resets the session's leaf pointer to the root, starting
// a fresh conversation branch.
func (m *Kit) ClearSession() {
m.treeSession.ResetLeaf()
if m.session != nil {
_ = m.session.Branch("")
}
}
// GetModelString returns the current model string identifier (e.g.,
@@ -1535,8 +1742,14 @@ func (m *Kit) Close() error {
if m.extRunner != nil && m.extRunner.HasHandlers(extensions.SessionShutdown) {
_, _ = m.extRunner.Emit(extensions.SessionShutdownEvent{})
}
if m.treeSession != nil {
_ = m.treeSession.Close()
if m.session != nil {
_ = m.session.Close()
}
// Release the OAuth callback port if we own the handler.
if closer, ok := m.authHandler.(interface{ Close() error }); ok {
_ = closer.Close()
}
return m.agent.Close()
}
// Conversion helpers are defined in adapter.go.
+265
View File
@@ -0,0 +1,265 @@
package kit
import (
"context"
"fmt"
"net"
"net/http"
"os/exec"
"runtime"
"sync"
"time"
)
// MCPAuthHandler handles OAuth authorization for MCP servers.
// Implementations control the user experience — opening a browser, showing a
// prompt, displaying a URL, etc.
//
// The default implementation ([DefaultMCPAuthHandler]) opens the system browser
// and starts a local HTTP callback server to receive the authorization code.
type MCPAuthHandler interface {
// RedirectURI returns the OAuth redirect URI that the callback server
// will listen on. This is called during MCP transport setup — before any
// OAuth errors occur — so the redirect URI can be registered with the
// authorization server.
RedirectURI() string
// HandleAuth is called when an MCP server requires OAuth authorization.
// It receives the server name and an authorization URL that the user must
// visit. The handler must:
// 1. Direct the user to authURL (e.g. open browser, display URL)
// 2. Listen for the OAuth callback on the redirect URI
// 3. Return the full callback URL (with code and state query params)
//
// Return an error to abort the connection to this MCP server.
// The context controls the overall timeout; implementations should
// respect ctx.Done().
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
}
// DefaultMCPAuthHandler opens the system browser and starts a local HTTP
// callback server to receive the OAuth authorization code. It eagerly reserves
// a TCP port on construction so [RedirectURI] is stable for the lifetime of
// the handler.
//
// Create instances with [NewDefaultMCPAuthHandler] (random port) or
// [NewDefaultMCPAuthHandlerWithPort] (explicit port).
type DefaultMCPAuthHandler struct {
listener net.Listener
port int
mu sync.Mutex // guards listener lifecycle
}
// NewDefaultMCPAuthHandler creates a handler that listens on a random
// available port on localhost. The port is reserved immediately so
// [RedirectURI] returns a stable value. Call [DefaultMCPAuthHandler.Close]
// when the handler is no longer needed to release the port.
func NewDefaultMCPAuthHandler() (*DefaultMCPAuthHandler, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to listen for OAuth callback: %w", err)
}
port := listener.Addr().(*net.TCPAddr).Port
return &DefaultMCPAuthHandler{listener: listener, port: port}, nil
}
// NewDefaultMCPAuthHandlerWithPort creates a handler that listens on the
// specified port on localhost. The port is reserved immediately. Pass 0 to
// let the OS pick a free port (equivalent to [NewDefaultMCPAuthHandler]).
// Call [DefaultMCPAuthHandler.Close] when the handler is no longer needed.
func NewDefaultMCPAuthHandlerWithPort(port int) (*DefaultMCPAuthHandler, error) {
addr := fmt.Sprintf("localhost:%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, fmt.Errorf("failed to listen on %s for OAuth callback: %w", addr, err)
}
actualPort := listener.Addr().(*net.TCPAddr).Port
return &DefaultMCPAuthHandler{listener: listener, port: actualPort}, nil
}
// RedirectURI returns the OAuth redirect URI pointing to the local callback
// server. This value is stable for the lifetime of the handler.
func (h *DefaultMCPAuthHandler) RedirectURI() string {
return fmt.Sprintf("http://localhost:%d/oauth/callback", h.port)
}
// Port returns the TCP port the callback server is bound to.
func (h *DefaultMCPAuthHandler) Port() int {
return h.port
}
// HandleAuth opens the system browser to authURL and waits for the OAuth
// callback on the local server. It returns the full callback URL including
// query parameters (code, state, etc.).
//
// If the context has no deadline, a default 2-minute timeout is applied.
// The callback server is started for each HandleAuth call and shut down
// before returning.
func (h *DefaultMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
h.mu.Lock()
listener := h.listener
h.mu.Unlock()
if listener == nil {
return "", fmt.Errorf("OAuth callback handler is closed")
}
// Apply default timeout if the context has no deadline.
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 2*time.Minute)
defer cancel()
}
// Channel receives the full callback URL from the HTTP handler.
callbackCh := make(chan string, 1)
mux := http.NewServeMux()
mux.HandleFunc("/oauth/callback", func(w http.ResponseWriter, r *http.Request) {
// Reconstruct the full callback URL as the caller expects it.
fullURL := fmt.Sprintf("http://localhost:%d%s", h.port, r.RequestURI)
// Send the callback URL to the waiting goroutine (non-blocking).
select {
case callbackCh <- fullURL:
default:
}
// Respond with a friendly HTML page so the user knows they can
// close the browser tab.
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = fmt.Fprint(w, oauthSuccessHTML)
})
server := &http.Server{
Handler: mux,
}
// Start serving on the pre-reserved listener. We need to create a new
// listener on the same port because http.Server.Serve takes ownership
// and closes the listener when done. The original listener is kept open
// to reserve the port; we create a second listener via SO_REUSEADDR
// semantics (Go's default on most platforms) or, more reliably, we
// temporarily release and re-acquire.
//
// Strategy: use the held listener directly for Serve. After Serve
// returns (due to Shutdown), re-acquire the listener to keep the port
// reserved for future HandleAuth calls.
h.mu.Lock()
serveListener := h.listener
h.listener = nil // Serve will close it
h.mu.Unlock()
if serveListener == nil {
return "", fmt.Errorf("OAuth callback handler is closed")
}
// Start the HTTP server in a background goroutine.
serverErrCh := make(chan error, 1)
go func() {
err := server.Serve(serveListener)
if err != nil && err != http.ErrServerClosed {
serverErrCh <- err
}
close(serverErrCh)
}()
// Re-acquire the listener after Serve completes (deferred).
defer func() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
_ = server.Shutdown(shutdownCtx)
// Re-reserve the port for future HandleAuth calls.
h.mu.Lock()
defer h.mu.Unlock()
if h.listener == nil {
newListener, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", h.port))
if err == nil {
h.listener = newListener
}
// If re-listen fails, the handler degrades gracefully — the
// next HandleAuth call will return an error.
}
}()
// Open the system browser.
if err := openBrowser(authURL); err != nil {
// Browser open is best-effort; the user can still navigate manually.
_ = err
}
// Wait for the callback, a server error, or context cancellation.
select {
case url := <-callbackCh:
return url, nil
case err := <-serverErrCh:
return "", fmt.Errorf("OAuth callback server error for %q: %w", serverName, err)
case <-ctx.Done():
return "", fmt.Errorf("OAuth authorization timed out for %q: %w", serverName, ctx.Err())
}
}
// Close releases the reserved port and shuts down the handler. After Close,
// HandleAuth will return an error. Close is safe to call multiple times.
func (h *DefaultMCPAuthHandler) Close() error {
h.mu.Lock()
defer h.mu.Unlock()
if h.listener != nil {
err := h.listener.Close()
h.listener = nil
return err
}
return nil
}
// openBrowser opens the default system browser to the given URL. This is a
// best-effort operation — errors are returned but callers typically ignore
// them since the user can navigate manually.
func openBrowser(url string) error {
switch runtime.GOOS {
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
return exec.Command("open", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}
// oauthSuccessHTML is the HTML page returned to the browser after a
// successful OAuth callback.
const oauthSuccessHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Authorization Successful</title>
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: #f8f9fa;
color: #333;
}
.container {
text-align: center;
padding: 2rem;
}
h1 { color: #22863a; }
p { color: #586069; margin-top: 0.5rem; }
</style>
</head>
<body>
<div class="container">
<h1>&#10003; Authorization Successful</h1>
<p>You can close this tab and return to the terminal.</p>
</div>
</body>
</html>`
+68
View File
@@ -0,0 +1,68 @@
package kit
import (
"context"
"fmt"
"io"
"os"
)
// CLIMCPAuthHandler wraps a [DefaultMCPAuthHandler] and prints status messages
// to a writer (typically stderr) so the user knows what's happening during
// OAuth authorization. This is the handler used by the CLI/TUI binary.
//
// For TUI integration, set NotifyFunc to route messages through the TUI's
// event system instead of (or in addition to) the writer.
type CLIMCPAuthHandler struct {
inner *DefaultMCPAuthHandler
w io.Writer
// NotifyFunc, when set, is called with status messages instead of writing
// to the writer. This allows the TUI to display system messages in the
// chat stream. If nil, messages are written to w.
NotifyFunc func(serverName, message string)
}
// NewCLIMCPAuthHandler creates a CLI auth handler that prints status messages
// to stderr and delegates the actual OAuth flow to a [DefaultMCPAuthHandler].
func NewCLIMCPAuthHandler() (*CLIMCPAuthHandler, error) {
inner, err := NewDefaultMCPAuthHandler()
if err != nil {
return nil, err
}
return &CLIMCPAuthHandler{inner: inner, w: os.Stderr}, nil
}
// RedirectURI returns the OAuth redirect URI from the inner handler.
func (h *CLIMCPAuthHandler) RedirectURI() string {
return h.inner.RedirectURI()
}
// HandleAuth prints status messages and delegates to the inner handler.
func (h *CLIMCPAuthHandler) HandleAuth(ctx context.Context, serverName string, authURL string) (string, error) {
h.notify(serverName, fmt.Sprintf("🔐 MCP server %q requires authentication. Opening browser...", serverName))
h.notify(serverName, fmt.Sprintf(" If the browser doesn't open, visit:\n %s", authURL))
callbackURL, err := h.inner.HandleAuth(ctx, serverName, authURL)
if err != nil {
h.notify(serverName, fmt.Sprintf("✗ Authentication failed for %q: %v", serverName, err))
return "", err
}
h.notify(serverName, fmt.Sprintf("✓ Authenticated with %q", serverName))
return callbackURL, nil
}
// Close releases the inner handler's resources.
func (h *CLIMCPAuthHandler) Close() error {
return h.inner.Close()
}
// notify sends a message through NotifyFunc if set, otherwise writes to w.
func (h *CLIMCPAuthHandler) notify(serverName, message string) {
if h.NotifyFunc != nil {
h.NotifyFunc(serverName, message)
return
}
_, _ = fmt.Fprintln(h.w, message)
}
+132
View File
@@ -0,0 +1,132 @@
package kit
import (
"time"
)
// SessionManager defines the contract for conversation storage backends.
// Implementations can use files (default), databases, cloud storage, etc.
type SessionManager interface {
// AppendMessage adds a message to the current branch and returns its entry ID.
// The entry ID is used for tree navigation and must be unique within the session.
AppendMessage(msg LLMMessage) (entryID string, err error)
// GetMessages returns all messages on the current branch (from root to leaf),
// including any compaction summaries at the appropriate positions.
GetMessages() []LLMMessage
// BuildContext returns the message history to send to the LLM, applying
// compaction rules and branch summaries as needed.
// Returns: messages, currentProvider, currentModelID
BuildContext() (messages []LLMMessage, provider string, modelID string)
// Branch moves the leaf pointer to the given entry ID, creating a branch point.
// Subsequent AppendMessage calls extend from this new position.
// entryID can be empty to reset to root (new conversation branch).
Branch(entryID string) error
// GetCurrentBranch returns the path from root to current leaf as entry metadata.
// Used for UI display and navigation.
GetCurrentBranch() []BranchEntry
// GetChildren returns direct child entry IDs for a given parent entry.
// Used to display branch points in the conversation tree.
GetChildren(parentID string) []string
// GetEntry returns a specific entry by ID, or nil if not found.
GetEntry(entryID string) *BranchEntry
// GetSessionID returns the unique session identifier (UUID).
GetSessionID() string
// GetSessionName returns the user-defined display name, or empty.
GetSessionName() string
// SetSessionName sets a display name for the session.
SetSessionName(name string) error
// GetCreatedAt returns when the session was created.
GetCreatedAt() time.Time
// IsPersisted returns true if this session writes to durable storage.
IsPersisted() bool
// AppendCompaction adds a compaction entry that summarizes older messages.
// firstKeptEntryID is the ID of the first message to preserve in context.
// readFiles and modifiedFiles track file changes for the compaction summary.
AppendCompaction(summary string, firstKeptEntryID string,
tokensBefore, tokensAfter int, messagesRemoved int, readFiles, modifiedFiles []string) (string, error)
// GetLastCompaction returns the most recent compaction entry on the current
// branch, or nil if none exists.
GetLastCompaction() *CompactionEntry
// AppendExtensionData stores custom extension data in the session tree.
// Extensions use this to persist state across restarts.
AppendExtensionData(extType, data string) (string, error)
// GetExtensionData returns all extension data entries of the given type
// on the current branch. If extType is empty, returns all extension data.
GetExtensionData(extType string) []ExtensionDataEntry
// AppendModelChange records a provider/model switch in the session.
AppendModelChange(provider, modelID string) (string, error)
// GetContextEntryIDs returns the entry IDs corresponding to the messages
// returned by BuildContext, in the same order. Used by compaction to
// determine which entries to summarize.
GetContextEntryIDs() []string
// Close releases resources (database connections, file handles, etc.).
Close() error
}
// BranchEntry represents a single node in the conversation tree.
// This is a SDK-friendly struct (not the internal entry types).
type BranchEntry struct {
ID string
ParentID string
Type EntryType // "message", "branch_summary", "model_change", "compaction", "extension_data"
Role string // for messages: "user", "assistant", "system", "tool"
Content string // text content or summary
Model string // model used (for messages and model_change)
Provider string // provider used
Timestamp time.Time
Children []string // child entry IDs (for tree display)
// RawParts contains the full typed content parts for structured access.
// Only populated for message entries.
RawParts []ContentPart
}
// EntryType identifies the kind of entry in the session tree.
type EntryType string
const (
EntryTypeMessage EntryType = "message"
EntryTypeBranchSummary EntryType = "branch_summary"
EntryTypeModelChange EntryType = "model_change"
EntryTypeCompaction EntryType = "compaction"
EntryTypeExtensionData EntryType = "extension_data"
)
// CompactionEntry represents a context compaction/summarization event.
type CompactionEntry struct {
ID string
Summary string
FirstKeptEntryID string
TokensBefore int
TokensAfter int
MessagesRemoved int
ReadFiles []string
ModifiedFiles []string
Timestamp time.Time
}
// ExtensionDataEntry represents custom extension data stored in the session.
type ExtensionDataEntry struct {
ID string
ExtType string
Data string
Timestamp time.Time
}
+111 -80
View File
@@ -8,7 +8,6 @@ import (
"time"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/session"
)
@@ -47,49 +46,73 @@ func OpenTreeSession(path string) (*TreeManager, error) {
// --- Instance methods on Kit ---
// GetSessionManager returns the session manager, or nil if not configured.
func (m *Kit) GetSessionManager() SessionManager {
return m.session
}
// GetTreeSession returns the tree session manager, or nil if not configured.
// Deprecated: Use GetSessionManager instead.
func (m *Kit) GetTreeSession() *TreeManager {
return m.treeSession
// Try to unwrap the adapter if using default implementation
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner
}
return nil
}
// SetSessionManager replaces the session manager on a Kit instance.
func (m *Kit) SetSessionManager(sm SessionManager) {
m.session = sm
}
// SetTreeSession replaces the tree session on a Kit instance. This is used by
// the CLI when it handles session creation externally (e.g. --resume with a
// TUI picker) and needs to inject the result into a Kit-like workflow.
// Deprecated: Use SetSessionManager instead.
func (m *Kit) SetTreeSession(ts *TreeManager) {
m.treeSession = ts
m.session = NewTreeManagerAdapter(ts)
}
// GetSessionPath returns the file path of the active tree session, or empty
// for in-memory sessions or when no tree session is configured.
// GetSessionPath returns the file path of the active session, or empty
// for in-memory sessions or when no file-based session is configured.
func (m *Kit) GetSessionPath() string {
if m.treeSession != nil {
return m.treeSession.GetFilePath()
// Only file-based sessions have a path
// Try to get it from the underlying TreeManager if using default adapter
if m.session == nil {
return ""
}
// Check if it's the default adapter
if adapter, ok := m.session.(*treeManagerAdapter); ok {
return adapter.inner.GetFilePath()
}
return ""
}
// GetSessionID returns the UUID of the active tree session, or empty when no
// tree session is configured.
// GetSessionID returns the UUID of the active session, or empty when no
// session is configured.
func (m *Kit) GetSessionID() string {
if m.treeSession != nil {
return m.treeSession.GetSessionID()
if m.session == nil {
return ""
}
return ""
return m.session.GetSessionID()
}
// Branch moves the tree session's leaf pointer to the given entry ID, creating
// Branch moves the session's leaf pointer to the given entry ID, creating
// a branch point. Subsequent Prompt() calls will extend from the new position.
func (m *Kit) Branch(entryID string) error {
return m.treeSession.Branch(entryID)
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.session.Branch(entryID)
}
// SetSessionName sets a user-defined display name for the active tree session.
// SetSessionName sets a user-defined display name for the active session.
func (m *Kit) SetSessionName(name string) error {
if m.treeSession == nil {
return fmt.Errorf("session naming requires a tree session")
if m.session == nil {
return fmt.Errorf("session naming requires a session")
}
_, err := m.treeSession.AppendSessionInfo(name)
return err
return m.session.SetSessionName(name)
}
// ---------------------------------------------------------------------------
@@ -97,27 +120,27 @@ func (m *Kit) SetSessionName(name string) error {
// ---------------------------------------------------------------------------
// GetTreeNode returns a node by ID with full metadata and children.
// Returns nil if entry not found or no tree session.
// Returns nil if entry not found or no session.
func (m *Kit) GetTreeNode(entryID string) *TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
entry := m.treeSession.GetEntry(entryID)
entry := m.session.GetEntry(entryID)
if entry == nil {
return nil
}
return m.entryToTreeNode(entry)
return m.branchEntryToTreeNode(entry)
}
// GetCurrentBranch returns the path from root to current leaf as TreeNodes.
func (m *Kit) GetCurrentBranch() []TreeNode {
if m.treeSession == nil {
if m.session == nil {
return nil
}
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var nodes []TreeNode
for _, entry := range branch {
node := m.entryToTreeNode(entry)
node := m.branchEntryToTreeNode(&entry)
if node != nil {
nodes = append(nodes, *node)
}
@@ -127,34 +150,34 @@ func (m *Kit) GetCurrentBranch() []TreeNode {
// GetChildren returns direct child IDs of an entry.
func (m *Kit) GetChildren(parentID string) []string {
if m.treeSession == nil {
if m.session == nil {
return nil
}
return m.treeSession.GetChildren(parentID)
return m.session.GetChildren(parentID)
}
// NavigateTo branches/forks the session to the specified entry ID.
// Returns an error if the session is unavailable or the entry ID is not found.
func (m *Kit) NavigateTo(entryID string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
return m.treeSession.Branch(entryID)
return m.session.Branch(entryID)
}
// SummarizeBranch uses the LLM to summarize the conversation between two
// entry IDs. Returns the summary text, or an error if the range is invalid,
// the session is unavailable, or the LLM call fails.
func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
if m.treeSession == nil {
return "", fmt.Errorf("no tree session available")
if m.session == nil {
return "", fmt.Errorf("no session available")
}
// Get the branch and find the range
branch := m.treeSession.GetBranch("")
branch := m.session.GetCurrentBranch()
var startIdx, endIdx = -1, -1
for i, entry := range branch {
id := m.treeSession.EntryID(entry)
id := entry.ID
if id == fromID {
startIdx = i
}
@@ -170,7 +193,7 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// Build text to summarize
var content strings.Builder
for i := startIdx; i <= endIdx; i++ {
node := m.entryToTreeNode(branch[i])
node := m.branchEntryToTreeNode(&branch[i])
if node != nil && node.Content != "" {
fmt.Fprintf(&content, "[%s] %s\n\n", node.Role, node.Content)
}
@@ -195,73 +218,81 @@ func (m *Kit) SummarizeBranch(fromID, toID string) (string, error) {
// CollapseBranch replaces a branch range with a summary entry.
// Returns an error if the session is unavailable or the operation fails.
func (m *Kit) CollapseBranch(fromID, toID, summary string) error {
if m.treeSession == nil {
return fmt.Errorf("no tree session available")
if m.session == nil {
return fmt.Errorf("no session available")
}
_, err := m.treeSession.AppendBranchSummary(fromID, summary)
return err
// Note: This operation is not directly supported by SessionManager interface
// as it requires AppendBranchSummary which is TreeManager-specific.
// For custom SessionManagers, this would need to be implemented differently.
// For now, we try to use the underlying TreeManager if available.
if adapter, ok := m.session.(*treeManagerAdapter); ok {
_, err := adapter.inner.AppendBranchSummary(fromID, summary)
return err
}
return fmt.Errorf("CollapseBranch not supported by custom session manager")
}
// entryToTreeNode converts a session entry to a TreeNode.
func (m *Kit) entryToTreeNode(entry any) *TreeNode {
switch e := entry.(type) {
case *session.MessageEntry:
msg, err := e.ToMessage()
if err != nil {
return nil
}
// branchEntryToTreeNode converts a BranchEntry to a TreeNode.
func (m *Kit) branchEntryToTreeNode(entry *BranchEntry) *TreeNode {
if entry == nil {
return nil
}
switch entry.Type {
case EntryTypeMessage:
// Build content from RawParts
var content strings.Builder
for _, p := range msg.Parts {
for _, p := range entry.RawParts {
switch pt := p.(type) {
case message.TextContent:
case TextContent:
content.WriteString(pt.Text)
case message.ReasoningContent:
case ReasoningContent:
content.WriteString(pt.Thinking)
case message.ToolCall:
case ToolCall:
fmt.Fprintf(&content, "[tool_call: %s]", pt.Name)
case message.ToolResult:
case ToolResult:
fmt.Fprintf(&content, "[tool_result: %s]", pt.Content)
}
}
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "message",
Role: string(msg.Role),
Role: entry.Role,
Content: content.String(),
Model: msg.Model,
Provider: msg.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.BranchSummaryEntry:
case EntryTypeBranchSummary:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "branch_summary",
Content: e.Summary,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ModelChangeEntry:
case EntryTypeModelChange:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "model_change",
Content: fmt.Sprintf("Model changed to %s/%s", e.Provider, e.ModelID),
Model: e.Provider + "/" + e.ModelID,
Provider: e.Provider,
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Model: entry.Model,
Provider: entry.Provider,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
case *session.ExtensionDataEntry:
case EntryTypeExtensionData:
return &TreeNode{
ID: e.ID,
ParentID: e.ParentID,
ID: entry.ID,
ParentID: entry.ParentID,
Type: "extension_data",
Content: fmt.Sprintf("Extension data: %s", e.ExtType),
Timestamp: e.Timestamp.Format(time.RFC3339),
Children: m.treeSession.GetChildren(e.ID),
Content: entry.Content,
Timestamp: entry.Timestamp.Format(time.RFC3339),
Children: m.session.GetChildren(entry.ID),
}
default:
return nil
+13
View File
@@ -1,6 +1,7 @@
package kit
import (
"fmt"
"os"
"github.com/mark3labs/kit/internal/extensions"
@@ -136,3 +137,15 @@ func (m *Kit) ClearSkillCache() {
defer m.skillCache.mu.Unlock()
m.skillCache.skills = nil
}
// ReloadSkills re-discovers skills from disk, replacing the current set.
// This is called by file watchers when skill files change.
func (m *Kit) ReloadSkills() error {
newSkills, err := loadSkills(m.opts)
if err != nil {
return fmt.Errorf("reloading skills: %w", err)
}
m.skills = newSkills
m.ClearSkillCache()
return nil
}
+119
View File
@@ -1,6 +1,8 @@
package kit
import (
"context"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/core"
@@ -16,6 +18,123 @@ type ToolOption = core.ToolOption
// If empty, os.Getwd() is used at execution time.
var WithWorkDir = core.WithWorkDir
// --- Custom tool creation ---
// ToolOutput is the return value from custom tool handlers created with
// [NewTool] or [NewParallelTool]. It provides a dependency-free way to
// return results without importing the underlying LLM framework.
type ToolOutput struct {
// Content is the text content returned to the LLM.
Content string
// IsError, when true, signals to the LLM that the tool call failed.
IsError bool
// Data contains optional binary data (images, audio, etc.).
Data []byte
// MediaType is the MIME type for binary Data (e.g. "image/png").
MediaType string
// Metadata is optional opaque metadata attached to the response.
// It is not sent to the LLM but may be consumed by hooks or the UI.
Metadata any
}
// TextResult creates a successful text [ToolOutput].
func TextResult(content string) ToolOutput {
return ToolOutput{Content: content}
}
// ErrorResult creates an error [ToolOutput]. The LLM will see the content
// as a tool error, allowing it to retry or adjust its approach.
func ErrorResult(content string) ToolOutput {
return ToolOutput{Content: content, IsError: true}
}
// toolCallIDKey is the context key for the tool call ID.
type toolCallIDKey struct{}
// ToolCallIDFromContext extracts the tool call ID from the context.
// The call ID is set automatically by [NewTool] and [NewParallelTool]
// before invoking the handler. Returns an empty string if no ID is present.
func ToolCallIDFromContext(ctx context.Context) string {
s, _ := ctx.Value(toolCallIDKey{}).(string)
return s
}
// NewTool creates a custom [Tool] with automatic JSON schema generation from
// the TInput struct type. The handler receives a typed input (deserialized
// from the LLM's JSON arguments) and returns a [ToolResult].
//
// Struct tags on TInput control the generated schema:
//
// json:"name" → parameter name
// description:"..." → parameter description shown to the LLM
// enum:"a,b,c" → restrict valid values
// omitempty → marks the parameter as optional
//
// The tool call ID is injected into the context and can be retrieved with
// [ToolCallIDFromContext].
//
// Example:
//
// type WeatherInput struct {
// City string `json:"city" description:"City name"`
// }
//
// tool := kit.NewTool("get_weather", "Get weather for a city",
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
// return kit.TextResult("72°F, sunny in " + input.City), nil
// },
// )
func NewTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// NewParallelTool is like [NewTool] but marks the tool as safe for concurrent
// execution alongside other tools. Use this when the tool has no side effects
// or when concurrent calls are safe.
func NewParallelTool[TInput any](name, description string, fn func(ctx context.Context, input TInput) (ToolOutput, error)) Tool {
return fantasy.NewParallelAgentTool(name, description,
func(ctx context.Context, input TInput, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
ctx = context.WithValue(ctx, toolCallIDKey{}, call.ID)
result, err := fn(ctx, input)
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), nil
}
resp := fantasy.ToolResponse{
Content: result.Content,
IsError: result.IsError,
Data: result.Data,
MediaType: result.MediaType,
}
if result.Metadata != nil {
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
}
return resp, nil
},
)
}
// --- Individual tool constructors ---
// NewReadTool creates a file-reading tool.
+119
View File
@@ -0,0 +1,119 @@
package kit_test
import (
"context"
"testing"
kit "github.com/mark3labs/kit/pkg/kit"
)
// TestNewTool_BasicTextResult verifies that NewTool creates a working tool
// that returns text content via ToolOutput.
func TestNewTool_BasicTextResult(t *testing.T) {
type Input struct {
Name string `json:"name"`
}
tool := kit.NewTool("greet", "Greet someone",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("hello " + input.Name), nil
},
)
info := tool.Info()
if info.Name != "greet" {
t.Errorf("Info().Name = %q, want %q", info.Name, "greet")
}
if info.Description != "Greet someone" {
t.Errorf("Info().Description = %q, want %q", info.Description, "Greet someone")
}
if info.Parallel {
t.Error("NewTool should not mark tool as parallel")
}
}
// TestNewParallelTool_MarkedParallel verifies that NewParallelTool marks the
// tool as safe for concurrent execution.
func TestNewParallelTool_MarkedParallel(t *testing.T) {
type Input struct {
Query string `json:"query"`
}
tool := kit.NewParallelTool("search", "Search for things",
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
return kit.TextResult("found: " + input.Query), nil
},
)
info := tool.Info()
if info.Name != "search" {
t.Errorf("Info().Name = %q, want %q", info.Name, "search")
}
if !info.Parallel {
t.Error("NewParallelTool should mark tool as parallel")
}
}
// TestTextResult verifies the TextResult convenience constructor.
func TestTextResult(t *testing.T) {
r := kit.TextResult("ok")
if r.Content != "ok" {
t.Errorf("Content = %q, want %q", r.Content, "ok")
}
if r.IsError {
t.Error("TextResult should not set IsError")
}
}
// TestErrorResult verifies the ErrorResult convenience constructor.
func TestErrorResult(t *testing.T) {
r := kit.ErrorResult("bad input")
if r.Content != "bad input" {
t.Errorf("Content = %q, want %q", r.Content, "bad input")
}
if !r.IsError {
t.Error("ErrorResult should set IsError")
}
}
// TestToolCallIDFromContext verifies round-trip context injection.
func TestToolCallIDFromContext(t *testing.T) {
// Empty context returns empty string.
if id := kit.ToolCallIDFromContext(context.Background()); id != "" {
t.Errorf("expected empty string from bare context, got %q", id)
}
}
// TestToolOutput_Metadata verifies that metadata can be set on ToolOutput.
func TestToolOutput_Metadata(t *testing.T) {
r := kit.ToolOutput{
Content: "data",
Metadata: map[string]string{"key": "value"},
}
if r.Metadata == nil {
t.Error("expected non-nil Metadata")
}
m, ok := r.Metadata.(map[string]string)
if !ok {
t.Fatalf("expected map[string]string, got %T", r.Metadata)
}
if m["key"] != "value" {
t.Errorf("Metadata[key] = %q, want %q", m["key"], "value")
}
}
// TestToolOutput_BinaryData verifies that binary data fields work correctly.
func TestToolOutput_BinaryData(t *testing.T) {
data := []byte{0x89, 0x50, 0x4E, 0x47}
r := kit.ToolOutput{
Content: "image result",
Data: data,
MediaType: "image/png",
}
if len(r.Data) != 4 {
t.Errorf("Data len = %d, want 4", len(r.Data))
}
if r.MediaType != "image/png" {
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
}
}

Some files were not shown because too many files have changed in this diff Show More