Compare commits

...

72 Commits

Author SHA1 Message Date
Ed Zynda 4ef57eec4e docs(session): correct DefaultSessionDir convention comment
- Stale comment showed ~/.kit/sessions/--<cwd-path>--/ which does not
  match the actual encoding (no leading/trailing dashes)
- Update to reflect the real format and point to encodeCwdForDir for
  full rules
2026-05-05 14:54:20 +03:00
Ed Zynda cbd828e190 fix(session): strip illegal characters from windows session dir (#18)
- Encode cwd via new encodeCwdForDir helper that handles both `/`
  and `\` separators and strips characters illegal in Windows
  directory names (`: < > " | ? *`)
- Fixes session creation on Windows where the drive-letter colon
  produced names like `C:--test` and caused mkdir to fail
- Add regression tests covering Unix paths, Windows drive roots,
  secondary drives, mixed separators, and other illegal chars

Fixes #18
2026-05-05 14:46:36 +03:00
Ed Zynda d304805106 Merge pull request #22 from mark3labs/feat/21-mcp-tasks-mvp
feat(mcp): add MCP Tasks support at the SDK level (#21)
2026-05-04 19:30:15 +03:00
Ed Zynda 6e36053856 fix(mcp): validate tasksMode and inherit task options in Subagent (#21)
Address two review findings on the MCP Tasks PR.

- Config.Validate() now rejects unknown tasksMode values with a clear
  error naming the server and bad value. Without this a typo (e.g.
  "alwasy") was silently downgraded to "auto" by the runtime parser.
- Kit.Subagent() now propagates the parent's six MCP task options
  (mode map, timeout, TTL, poll interval, max poll interval, progress
  callback) onto the child via a new inheritMCPTaskOptions helper.
  Without this, child subagents always saw default polling and no
  progress feedback regardless of parent configuration.

The propagation logic lives in a helper so the test exercises the real
code path instead of duplicating it; future task fields only need to be
added in one place.
2026-05-04 17:06:11 +03:00
Ed Zynda 92eaaf6a59 docs(mcp): document MCP Tasks support (#21)
- README: add tasksMode YAML example and MCP Tasks subsection with
  SDK opt-in snippet
- pkg/kit/README: add MCP Tasks subsection covering MCPTaskMode,
  progress callbacks, and List/Get/Cancel methods
- www/configuration: document the tasksMode server field plus a
  per-mode behaviour table
- www/sdk/options: extend the Compaction & MCP table with the six
  new Options fields and add a top-level MCP Tasks section
- www/sdk/overview: add a brief MCP Tasks section between MCP
  prompts/resources and Context & compaction

All examples verified against the public symbols in pkg/kit/mcp_tasks.go;
docs site builds cleanly via npx tome build.
2026-05-04 17:01:47 +03:00
Ed Zynda e6084b7bd0 feat(mcp): add MCP Tasks support at the SDK level (#21)
Implement Phase 1 of the MCP Tasks spec so long-running tools/call
requests can run asynchronously, survive proxy timeouts, and be
cancelled mid-flight.

- connection pool now advertises mcp.NewTasksCapability() during
  initialize and captures the InitializeResult so callers can detect
  per-server task support
- new MCPServerConfig.TasksMode (auto|never|always, default auto)
  parsed from both new and legacy mcp.json shapes
- ExecuteTool augments tools/call with TaskParams when policy and
  capability allow, polls tasks/get / tasks/result until terminal,
  and best-effort tasks/cancel on context cancellation
- new MCPToolManager methods: SetTaskConfig, ListServerTasks,
  GetServerTask, CancelServerTask
- public SDK surface in pkg/kit: MCPTask, MCPTaskStatus, MCPTaskMode,
  MCPTaskProgress, MCPTaskProgressHandler, plus Options fields
  (MCPTaskMode, MCPTaskTimeout, MCPTaskTTL, MCPTaskPollInterval,
  MCPTaskMaxPollInterval, MCPTaskProgress) and Kit.{List,Get,Cancel}
  MCPTask methods
- works around two upstream mcp-go v0.51.0 parser bugs
  (ParseCallToolResult rejects task responses; ParseTaskResultResult
  looks for content under a non-existent nested key) by decoding the
  wire shape directly via the transport
- defaults to MCPTaskModeAuto so servers that don't advertise task
  support behave exactly as before

Fixes #21
2026-05-04 16:51:09 +03:00
Ed Zynda 34d5abff9c build(deps): update dependencies and implement new acp.Agent methods
- Bump fantasy v0.21.0 -> v0.23.0, mcp-go v0.49.0 -> v0.51.0,
  acp-go-sdk v0.12.0 -> v0.12.2, chroma v2.23.1 -> v2.24.1,
  fsnotify v1.9.0 -> v1.10.1, ultraviolet, AWS SDK, Google API
- Implement CloseSession and ResumeSession on acpserver.Agent to
  satisfy the expanded acp.Agent interface in acp-go-sdk v0.12.2
- Add sessionRegistry.remove helper to support session close
2026-05-04 16:23:12 +03:00
Ed Zynda fc0ddd5f4f update 2026-05-04 15:51:00 +03:00
Ed Zynda 7aa6160c75 updates 2026-05-04 12:10:46 +03:00
Ed Zynda e830bf87ca refactor(models): remove responses API model registration hack
Fantasy v0.21.0 natively includes gpt-5.5 and other newer models in
its responsesModelIDs/responsesReasoningModelIDs lists, making our
workaround unnecessary.

- Delete responses_models.go (go:linkname hack + RegisterResponsesModels)
- Delete responses_models_test.go
- Replace isResponsesAPIModel/isResponsesReasoningModel heuristics with
  direct openai.IsResponsesModel/openai.IsResponsesReasoningModel calls
- Remove RegisterResponsesModels calls from registry init/reload
- Remove hack documentation from AGENTS.md
- Update all deps (fantasy v0.21.0, smithy-go, ultraviolet, etc.)
2026-04-27 09:42:52 +03:00
Ed Zynda 3881d1c28f fix(models): auto-register new OpenAI models for Responses API routing
Fantasy's hardcoded responsesModelIDs list gates whether a model uses
the Responses API or Chat Completions code path. When a new model
(e.g. gpt-5.5) is added via `kit update-models` but fantasy hasn't
been updated yet, the type mismatch between *ResponsesProviderOptions
and *ProviderOptions causes a crash.

- Add isResponsesAPIModel()/isResponsesReasoningModel() helpers that
  supplement fantasy's checks with prefix-based heuristics for modern
  OpenAI model families (gpt-4.1+, gpt-5+, o-series, codex, chatgpt)
- Add RegisterResponsesModels() using go:linkname to append missing
  model IDs from our database into fantasy's internal slices at init
  time and after ReloadGlobalRegistry()
- Replace all direct openai.IsResponsesModel/IsResponsesReasoningModel
  calls in providers.go with the new helpers
- Merge embedded + cached model databases instead of cache-only fallback
- Bump fantasy v0.19.0 -> v0.20.0 to match existing import usage
- Document the technique and model-family update process in AGENTS.md
2026-04-24 15:13:38 +03:00
Ed Zynda 53f6682bd0 refactor(core): remove redundant single-edit mode from edit tool
- Remove top-level old_text/new_text params from edit tool schema
- Make edits array the sole interface; single edits pass 1-item array
- Simplify normalizeEditInput, removing dual-mode branching logic
- Update UI renderer to only read from edits array
- Remove old_text/new_text from bodyKeys in message summarizer
- Update web session HTML to iterate edits array
- Convert all single-edit tests to use Edits array
- Replace mixed-mode test with empty-array validation test
2026-04-23 16:33:55 +03:00
Ed Zynda 996b15c9b9 fix(extensions): return nil error for blocked/disabled tools so LLM sees the reason
Tool blocking via OnToolCall and SetActiveTools returned both a
ToolResponse (IsError=true) and a Go error. Fantasy treats a non-nil
Go error from tool.Run() as a critical failure, aborting the agent
loop without delivering the tool result to the LLM. The model never
saw the block reason and would retry or hallucinate.

- Return nil error for blocked tools (OnToolCall Block=true)
- Return nil error for disabled tools (SetActiveTools)
- Return nil error for extension tool execution failures
- Update tests to assert nil error (IsError response conveys the error)

Fixes #20
2026-04-23 13:13:28 +03:00
Ed Zynda aeb704367c feat(app): update token counts and context fill after every step
- Set context tokens per-step in recordStepUsage instead of waiting
  for turn completion; each step re-sends the full conversation so
  the reported usage monotonically increases
- Add UsageUpdatedEvent to trigger a TUI re-render after each step
  so the status bar reflects updated tokens, cost, and context %
  even during gaps between streaming chunks
- Update test to expect per-step context token updates
2026-04-23 12:56:00 +03:00
Ed Zynda d2e23295b6 perf(ui): cache item heights in ScrollList to eliminate redundant renders
- Add heightCache map to ScrollList, keyed by item ID, avoiding
  repeated Render() calls purely to count lines
- Rewrite GotoBottom() to walk backwards from the end in O(visible)
  instead of two full O(N) forward passes over all items
- Replace all height-only Render() calls in clampOffset(), AtBottom(),
  ScrollBy(), and ScrollPercent() with cached itemHeight() lookups
- Invalidate cache on width changes (SetWidth) and item mutations
  (AppendChunk, AppendStdout/Stderr via InvalidateItemHeight)
- Refresh cache entries in View() from authoritative renders
2026-04-23 12:03:44 +03:00
Ed Zynda e5a13e2e12 feat(sdk): add missing LLM type aliases and remove fantasy dependency leakage
- Add LLMToolResultOutputContentMedia alias (closes gap in tool result types)
- Add LLMToolResultContentType enum and constants (Text, Error, Media)
- Add LLMToolInfo, LLMProviderOptions, LLMProviderMetadata, LLMPrompt aliases
- Replace all fantasy.* references in hooks.go and hooks_test.go with
  SDK-owned aliases, removing the charm.land/fantasy import from both
- Fix gofmt alignment in internal/extensions/symbols.go
- Update SDK skill doc with complete LLM type reference
2026-04-22 21:05:04 +03:00
Ed Zynda 558fb5214f feat(sdk): expose remaining Fantasy lifecycle callbacks as events and hooks
Closes #19.

SDK events (pkg/kit):
- Add 10 new event types: StepStart, StepFinish, TextStart, TextEnd,
  ReasoningStart, Warnings, Source, StreamFinish, Error, Retry
- Add typed convenience subscribers for all 31 event types (20 previously
  required raw Subscribe + type assertion)
- Add OnPrepareStep hook for intercepting/replacing messages between
  steps within a multi-step turn (composes with existing steering)
- Rename OnStreaming to OnMessageUpdate (deprecated alias kept)

Agent internals (internal/agent):
- Add GenerateCallbacks struct replacing 16 positional callback params
- Add GenerateWithCallbacks method; deprecate GenerateWithLoopAndStreaming
- Wire all Fantasy stream callbacks: OnStepStart, OnTextStart/End,
  OnReasoningStart, OnWarnings, OnSource, OnStreamFinish, OnError,
  OnRetry, OnStepFinish (unified step event)
- Compose PrepareStep with steering channel + consumer hook

Extension system (internal/extensions):
- Add 8 new extension events: StepStart, StepFinish, ReasoningStart,
  Warnings, Source, Error, Retry, PrepareStep
- Bridge SDK events to extension runner with Yaegi-safe types (string
  errors, plain int64 token fields, ContextMessage for PrepareStep)

Docs: update README, SDK skill, www/sdk/callbacks, www/sdk/overview
2026-04-22 20:25:06 +03:00
Ed Zynda 61408ed490 fix(sdk): infer ToolResponse.Type for binary data in NewTool/NewParallelTool
- Infer Type="image" for image/* MIME types and Type="media" for all
  other binary content so the downstream framework creates a media
  content block instead of silently discarding Data bytes (#17)
- Extract shared toolOutputToResponse() helper to eliminate duplication
- Add ImageResult() and MediaResult() convenience constructors
- Add LLMToolCall and LLMToolResponse type aliases so SDK consumers
  can call Tool.Run() without importing the underlying framework
- Add 6 regression tests covering image, media, and text responses

Closes #17
2026-04-22 16:58:07 +03:00
Ed Zynda 3cfb6437f9 perf(session,ui): reduce syscalls, allocations, and subprocess spam
- Buffer session JSONL writes with bufio.Writer, flush at sync points;
  ForkToNewSession and AddLLMMessages now batch N entries into ~1 syscall
- Cache lipgloss styles in style.CachedStyles, lazily built and
  invalidated on SetTheme; eliminates ~15 NewStyle() calls per frame in
  hot render paths (reasoning blocks, spinner, tool headers, margins)
- Cache git ls-files results for @file suggestions with 3s TTL; typing
  @filename no longer spawns 3 subprocesses per keystroke
- Use strings.Builder for StreamingMessageItem.content; eliminates O(n²)
  string copying during LLM response streaming
2026-04-22 16:48:17 +03:00
Ed Zynda d33ad4028b fix(kit): enable streaming for subagent child instances
- Set Streaming: true in subagent childOpts to prevent
  viper.Set("stream", false) from polluting global state
- Without this, concurrent subagents and the parent could read
  stale stream=false from viper, causing provider-level issues
  (e.g. Anthropic non-streaming timeouts with extended thinking)
2026-04-22 13:06:37 +03:00
Ed Zynda 307dcd1734 cleanup 2026-04-22 11:56:06 +03:00
Ed Zynda 81240b075e chore: update all deps and fix acp-go-sdk v0.12.0 breaking changes
- Update all Go dependencies (bubbletea v2.0.6, fantasy v0.19.0,
  acp-go-sdk v0.12.0, mcp-go v0.49.0, and transitive deps)
- Replace SetSessionModel with SetSessionConfigOption to match new
  acp-go-sdk Agent interface (union type with ValueId/Boolean variants)
- Add ListSessions stub returning empty list (new required method)
- Refresh embedded_models.json from models.dev/api.json
- Update ACP smoke test: add initialize handshake, session/list,
  session/set_config_option, session/cancel, and fix update parsing
2026-04-22 11:55:40 +03:00
Ed Zynda 9a662d440c fix(ui): reduce TUI visual noise and improve layout
- remove "You" label and icon from user messages, use borderless content block
- remove input title bar ("Enter your prompt...") and hint line
- increase textarea from 3 to 4 rows with top/bottom margin
- hide input hints permanently for a cleaner UI
- match separator colors (use theme.Border for both startup and input dividers)
- make startup separator full terminal width instead of hardcoded 80
- add /help for help hint and pipe separators to status bar
- add printCustomMessage/RenderCustomMessage for custom alert labels
- render /help output as markdown with "Help" alert label
- add Ctrl+V (paste image) to help message keys section
- fix reasoning text wrapping using ANSI-aware lipgloss.Style.Width
- export HighlightFileTokens for cross-package use
2026-04-22 11:41:09 +03:00
Ed Zynda 4ba9d6fab3 feat(events): mirror Fantasy tool input streaming callbacks as Kit events
- Add ToolCallStartEvent, ToolCallDeltaEvent, ToolCallEndEvent to SDK
- Wire Fantasy OnToolInputStart/Delta/End through agent to EventBus
- Add typed convenience subscribers: OnToolCallStart/Delta/End on Kit
- Bridge new events to TUI via ToolCallInputStart/Delta/End app events
- Extend extension system with OnToolCallInputStart/Delta/End handlers
- Add extension event types, API methods, loader wiring, Yaegi symbols
- Update docs: README, SDK skill, extensions skill, www/sdk, www/extensions

Closes #16
2026-04-21 23:28:13 +03:00
Ed Zynda aec0e7cc01 docs: document noOAuth MCP server config field
- Add noOAuth to MCP server fields table in www/pages/configuration.md
- Add pubmed example with noOAuth in README and www config docs
2026-04-21 22:44:27 +03:00
Ed Zynda bac04636bf feat(config): add noOAuth flag to skip OAuth on public MCP servers
- Add NoOAuth field to MCPServerConfig with JSON/YAML support
- Guard OAuth error handling and transport setup with the new flag
- Prevents failed dynamic client registration on servers like PubMed
  that do not support OAuth
2026-04-21 22:24:10 +03:00
Ed Zynda 5f851fd08e fix(ui): require double ctrl+c to quit, matching double-esc pattern
- First ctrl+c clears input and arms quit flag with 3s timeout
- Second ctrl+c within timeout window actually quits
- Show '⚠ Press Ctrl+C again to quit' warning after first press
- Empty input no longer quits immediately on single ctrl+c
- Prompt/overlay states: ctrl+c cancels dialog, re-dispatches to
  main handler for double-press tracking instead of quitting
- Update placeholder, help text, and tests to match new behavior
2026-04-21 22:05:13 +03:00
Ed Zynda f8371836d8 fix(cmd): fix character encoding in OAuth success page
Add charset=utf-8 to Content-Type header and use HTML entity
&#10003; instead of raw Unicode checkmark to prevent garbled
text display in browsers.

Fixes #9
2026-04-21 21:19:51 +03:00
Ed Zynda 74f00244be fix(ui): wrap reasoning blocks to terminal width to prevent clipping
- wrap thinking text in StreamComponent and render.ReasoningBlock
- plumb width through renderer and streaming item paths
- keeps style consistent with user/assistant blocks and avoids cut-off lines
2026-04-21 20:42:53 +03:00
Ed Zynda b5d7fd4f3e update docs 2026-04-21 20:33:32 +03:00
Ed Zynda 5857d40978 cleanup 2026-04-21 20:27:32 +03:00
Ed Zynda 3ff701054a fix(models): add gpt-5.4 reasoning level support with auto-adjustment
Adds 'none' thinking level to support OpenAI gpt-5.4 models which use
'reasoning_effort: none' instead of 'minimal'. Includes validation and
auto-adjustment when switching models with incompatible levels.

- Add ThinkingNone constant mapping to ReasoningEffortNone
- Add IsValidThinkingLevelForModel() with gpt-5.4 detection
- Add SuggestThinkingLevelFallback() for level migration
- Auto-adjust thinking level on model switch with user notification
- Update all docs to include 'none' in valid levels

Fixes #11
2026-04-21 20:19:00 +03:00
Ed Zynda c1dee3ceba feat(cmd): add --set-default flag and improve auth error messages
Add --set-default flag to 'kit auth login' to automatically set the
provider's default model after successful authentication. When no Anthropic
credentials exist but OpenAI credentials are detected, error messages
now suggest using OpenAI with the correct --model flag.

Fixes #9
2026-04-21 19:52:06 +03:00
Ed Zynda 2d9783a44d fix(ui): make ctrl+c clear input before quitting
Change Ctrl+C behavior to match other terminal AI tools (claude, codex, pi):
- First Ctrl+C clears the current input when text is present
- Second Ctrl+C (within 3 seconds) quits the application
- Ctrl+C on empty input quits immediately
- 3-second auto-reset timer clears the 'pressed once' state
- Flag also resets after message submission

Updates placeholder text and help message to reflect new behavior.

Fixes #13
2026-04-21 19:32:48 +03:00
Ed Zynda 88dd216e15 fix(session): prevent circular parent references in tree session
Add defensive validation to detect and prevent cycles in the session tree
parent chain that could occur after compaction or file corruption.

- Add tree_validation.go with cycle detection and parent chain validation
- Validate parent chain before appending messages (AppendMessage)
- Validate firstKeptEntryID exists in AppendCompaction
- Add depth limit and cycle detection to buildTreeNode to prevent infinite recursion
- Log diagnostics on session open to detect existing cycles
- Add tests for cycle detection and graceful handling
2026-04-21 16:24:38 +03:00
Ed Zynda 9e5806ade8 fix(subagent): remove biased model example from tool schema
- Remove vendor-specific model example that could bias LLM selection
- Add minimum recommended timeout guidance to subagent schema
2026-04-21 11:28:32 +03:00
Ed Zynda 50f586ec8f chore(models): update embedded model database from models.dev
Update internal/models/embedded_models.json with the latest snapshot
from https://models.dev/api.json.

- Providers: 111 → 115 (+4)
- Models: 4,191 → 4,259 (+68)
2026-04-21 10:38:23 +03:00
Ed Zynda 8a8e684dff docs(sdk): document MCPAuthHandler and OAuth opt-in behavior
Reflect the refactor that made MCPAuthHandler an explicit, opt-in
dependency for remote MCP OAuth. Four surfaces updated:

- README.md: new 'MCP OAuth (remote MCP servers)' subsection under the
  Go SDK section, outlining the three consumer patterns (nil / CLI /
  custom) and linking to the full options docs.
- pkg/kit/README.md: type cheat-sheet now lists MCPAuthHandler,
  DefaultMCPAuthHandler, and CLIMCPAuthHandler alongside the existing
  MCPTokenStore entries.
- skills/kit-sdk/SKILL.md: Options example annotated with nil-disables-
  OAuth semantics; new 'MCP OAuth Authorization' section precedes the
  existing token-storage section; re-exported types list expanded.
- www/pages/sdk/options.md: Options fields table gains MCPAuthHandler
  row; new top-level 'MCP OAuth Authorization' section with consumer
  matrix, CLI/custom/fully-custom code samples, and a warning callout
  about the OnAuthURL nil-hang footgun.
2026-04-17 15:30:10 +03:00
Ed Zynda 7ef99ac60f refactor(sdk): remove UX policy from MCP OAuth handler
Strip user-facing I/O out of the SDK's OAuth surface so library, daemon,
and web-app embedders are not surprised by port binds or browser opens.

- DefaultMCPAuthHandler no longer calls openBrowser; it exposes an
  OnAuthURL(serverName, authURL) hook and performs no presentation I/O.
- kit.New no longer auto-constructs a default handler when
  Options.MCPAuthHandler is nil. OAuth is opt-in; remote MCP servers
  requiring authorization fail with a clear error if no handler is set.
- CLIMCPAuthHandler owns the CLI policy (browser open + stderr prints)
  by wiring an OnAuthURL closure on the inner DefaultMCPAuthHandler.
- openBrowser is now unexported and colocated with its sole caller; no
  new exported helper is added to the SDK surface.

BREAKING CHANGE: SDK consumers relying on implicit OAuth with a nil
MCPAuthHandler must now pass kit.NewCLIMCPAuthHandler() (or a custom
implementation) explicitly. The kit CLI is unaffected — cmd/root.go
already constructs the handler explicitly.
2026-04-17 15:26:35 +03:00
Ed Zynda a67f514560 chore(models): refresh embedded models.dev database
- update internal/models/embedded_models.json from https://models.dev/api.json
- 110 → 111 providers, 4172 → 4191 models
2026-04-17 12:19:21 +03:00
Ed Zynda b6bb35cb71 Merge pull request #7 from mark3labs/feat/sdk-options-overrides
feat(sdk): expose generation and provider params on Options
2026-04-17 12:15:47 +03:00
Ed Zynda 4e82fac442 fix(fileutil): decouple TestDetectMediaType from system MIME db
TestDetectMediaType/.go fails on CI images (Ubuntu mime-support) where
/etc/mime.types registers '.go → text/x-go', because mime.TypeByExtension
reads those files at init. The test intended to exercise the 'unknown
extension falls through to text/plain' branch but used a real extension,
making the assertion environment-dependent.

Replace '.go' with '.kitsyntheticext', an invented extension that no
system MIME database registers. The fallback path is now exercised
deterministically on any host.
2026-04-17 12:13:28 +03:00
Ed Zynda 5ec2217b0f docs(sdk): document global viper state leakage in New and Options
The SDK applies Options by calling viper.Set on viper's process-global
store, which means two Kits constructed in the same process are not
isolated from each other: the second New overwrites the first's keys,
and downstream readers (SetModel, GetThinkingLevel, BuildProviderConfig)
observe the most recent value.

- Add a 'Global viper state warning' block to the Options godoc
  explaining the leak, the zero-value-does-not-clear gotcha, and
  pointing at viper.Reset() as the migration workaround.
- Add a matching warning to the New godoc so consumers discover the
  constraint from either entry point.
- Detach the viperInitMu godoc (previously lodged inside New's comment
  block) and clarify that the mutex only guards the construction
  window, not instance isolation.
- Add a TODO noting the proper fix: refactor to a per-call viper.New()
  instance so each Kit owns its own config store.
2026-04-17 12:09:13 +03:00
Ed Zynda 8a851723ba style(sdk): gofmt trailing newlines in kit_test.go 2026-04-17 12:07:54 +03:00
Ed Zynda 53b628c5f8 fix(sdk): map hyphenated config keys to KIT_* env vars
- InitConfig now installs a viper env key replacer so keys like
  "max-tokens" bind to KIT_MAX_TOKENS under AutomaticEnv; previously
  hyphenated keys silently missed their documented env overrides.
- Simplify TestNewPreservesIsSetSemantics: with SkipConfig: true no env
  bindings are registered, so the os.Getenv guard and upper() helper
  were dead weight. Remove both and drop the unused helper.
2026-04-17 12:07:29 +03:00
Ed Zynda e1c94cb362 fix(sdk): align SDK max-tokens floor with CLI default (4096 → 8192)
The SDK last-resort MaxTokens floor is applied in kit.New() when
Options.MaxTokens, KIT_MAX_TOKENS, .kit.yml, and per-model defaults
are all unset. It was 4096 (inherited from the old setSDKDefaults
viper default) while the CLI --max-tokens cobra default is 8192.

Bump the floor to 8192 so SDK and CLI callers start from the same
base value before rightSizeMaxTokens runs, then update README,
skills/kit-sdk/SKILL.md, and www/pages/{configuration,sdk/options}.md
to match.
2026-04-17 11:59:49 +03:00
Ed Zynda ecf95b52e1 fix(sdk): preserve IsSet semantics for generation param overrides
Previously setSDKDefaults() registered viper.SetDefault for max-tokens,
temperature, top-p, top-k, frequency/presence-penalty, and thinking-level.
viper.SetDefault makes IsSet() return true, which silently suppressed
per-model defaults (ApplyModelSettings) and automatic right-sizing
(rightSizeMaxTokens) for every SDK-created Kit — and for CLI runs too,
since cmd/root.go routes through kit.New. Effective max-tokens for
claude-sonnet-4-5 was pinned at 4096 instead of 32768.

- Drop SetDefault for all IsSet-sensitive keys; keep only model,
  system-prompt, stream, num-gpu-layers, main-gpu.
- Apply a 4096 max-tokens floor directly on the *models.ProviderConfig
  struct in kit.New() when nothing else resolved a value. Keeps
  viper.IsSet("max-tokens") == false so rightSizeMaxTokens and
  per-model maxTokens overrides still fire.
- Update Options.MaxTokens / ThinkingLevel godoc to describe the real
  precedence chain.
- Strengthen tests: add Temperature subtest; add
  TestNewPreservesIsSetSemantics regression covering all seven keys;
  split TestNewWithProviderOptions into three subtests including
  Options-beats-viper-state and ProviderURL propagation; add
  resetViper helper so subtests don't bleed state.
- Document the new SDK fields (MaxTokens, ThinkingLevel, Temperature,
  TopP, TopK, FrequencyPenalty, PresencePenalty, ProviderAPIKey,
  ProviderURL, TLSSkipVerify) in README, skills/kit-sdk, and the www
  configuration / sdk/options / sdk/overview pages, including a
  dedicated precedence table.
2026-04-17 11:50:45 +03:00
Ed Zynda 0641c92acc feat(sdk): expose generation and provider params on Options
Adds programmatic overrides on kit.Options for the model/provider knobs
that were previously only reachable through viper.Set() — letting SDK
consumers (web apps, services, embedded agents) configure kit fully
in-code without polluting global viper state or shipping .kit.yml.

Generation parameters:
  - MaxTokens         int      (max output tokens per response)
  - ThinkingLevel     string   (off/low/medium/high)
  - Temperature       *float32
  - TopP              *float32
  - TopK              *int32
  - FrequencyPenalty  *float32
  - PresencePenalty   *float32

Sampling params use pointer types so explicit 0 is distinguishable from
unset; nil leaves provider/per-model defaults in place.

Provider configuration:
  - ProviderAPIKey    string
  - ProviderURL       string
  - TLSSkipVerify     bool

Implementation just pushes Options values into viper inside New(),
so all existing downstream code (BuildProviderConfig, SetModel,
modelSettings lookups, runtime model switching) picks them up
uniformly without any new code paths. Tests added for MaxTokens,
ThinkingLevel, and ProviderAPIKey.
2026-04-17 11:24:00 +03:00
Ed Zynda 3bb20f5283 feat(models): surface and prevent silent max-tokens truncation
- Raise --max-tokens default from 4096 to 8192.
- Auto-raise MaxTokens toward the model's catalog Limit.Output (capped at
  32768) when the user hasn't set --max-tokens explicitly and no per-model
  modelSettings override applied. Prevents silent 4k/8k truncation on
  models that support 32k-262k output.
- Surface FinishReasonLength at turn end: the app now subscribes to
  TurnEndEvent and renders a system-message banner explaining the current
  cap, the model's known ceiling, and how to raise it. Previously the TUI
  swallowed 'length' stops, producing 'ghost' truncations.
- Export FinishReason* constants on pkg/kit (Stop, Length, ToolCalls,
  ContentFilter, Error, Other, Unknown) and fix stale comments that used
  Anthropic-style strings.
- Add Kit.MaxTokens() and Kit.MaxOutputLimit() SDK accessors, backed by
  Agent.GetMaxTokens() which correctly returns 0 for providers that
  suppress the param (e.g. Codex OAuth).
- Tests: rightSizeMaxTokens covers 7 paths (cap, raise, preserve,
  explicit flag, nil info, zero limit); handleTurnEnd covers length/
  non-length/nil-sendFn and the fallback message formatter.
- Docs: update configuration.md, cli/flags.md, and kit-extensions skill
  to reflect the new default and behavior.
2026-04-16 23:12:10 +03:00
Ed Zynda 633fa38b2b fix(ui): regenerate spinner frames on theme change
- UpdateTheme() only refreshed typography styles, leaving spinner
  frames rendered with the old theme's colors
- Now calls knightRiderFrames() to rebuild frames with the new
  theme's Primary, Muted, VeryMuted, and MutedBorder colors
2026-04-16 12:32:49 +03:00
Ed Zynda f905cee48c fix(ui): dynamically size slash command name column in popup
- Replace hardcoded nameWidth of 15 with dynamic calculation based on
  the longest command name in the filtered list
- Prevents truncation of longer names like /feature-request and
  /release-tagger that were cut off with ellipsis
- Cap name column to leave at least 20 chars for descriptions
- Add 1 char gap between name and description columns
2026-04-16 12:27:56 +03:00
Ed Zynda 182c10ea1a refactor(ui): improve keybinding ergonomics for terminal multiplexers
- Move thinking toggle from ctrl+t to leader chord (ctrl+x t) to avoid
  conflicts with tmux/zellij tab mode and terminal new-tab shortcuts
- Change scrollback jump from alt+home/alt+end to ctrl+home/ctrl+end
  for better compatibility across SSH and older tmux versions
- Remove ctrl+d as submit alias (enter suffices); avoids EOF convention
  confusion and accidental shell disconnects
- Remove ctrl+a from tree selector filter shortcuts to avoid conflict
  with the common tmux prefix remap (ctrl+o cycle still reaches all
  filter modes)
2026-04-16 12:21:37 +03:00
Ed Zynda fcaa52bf1c fix(extensions): serialize handler calls per-extension to prevent data races
- Add per-extension reentrant mutex to Runner that serializes handler
  invocations from concurrent goroutines (e.g. parallel subagent events)
  while allowing re-entrant calls (handler → EmitCustomEvent → handler)
- Fix subagent-monitor slice aliasing bug: submonEntries[:0] reuses the
  backing array, corrupting entries during in-place filtering
- Pass parent's loaded MCPConfig to child subagents in Kit.Subagent(),
  eliminating concurrent viper map access during parallel kit.New() calls
- Add Options.MCPConfig field so SDK consumers can also skip viper reads
- Add tests for concurrent emit, cross-extension concurrency, and
  re-entrant EmitCustomEvent
2026-04-16 12:11:10 +03:00
Ed Zynda 7e6455732c docs: update documentation for sudo password prompt feature
- README.md: mention interactive sudo password prompt in features
- skills/kit-sdk/SKILL.md: add PasswordPromptEvent to event types table
- www/pages/index.md: update features list with sudo prompt
- www/pages/development.md: update project structure description
- www/pages/sdk/callbacks.md: add complete event types table
2026-04-15 18:06:11 +03:00
Ed Zynda 71301a9035 feat: add interactive sudo password prompt for bash tool
Add core TUI support for handling sudo password prompts when executing
bash commands that require elevated privileges.

- Detect sudo commands and check if credentials are cached (sudo -n)
- Show modal password prompt with masked input (• characters) when needed
- Pipe password via stdin using sudo -S -p '' (no password in command string)
- Password flows through context callbacks, never stored in session history
- Add PasswordPromptHandler to agent and SDK event system
- Add password prompt overlay to TUI with 🔐 icon and hidden input
- Include tests for sudo command detection and rewriting

The password is never persisted to disk - it only exists in memory
during execution and is piped directly to sudo via stdin.
2026-04-15 17:33:03 +03:00
Ed Zynda 0974d37ab2 feat(sdk): support mcp-go in-process transport for MCP servers
- Add InProcessServer field to MCPServerConfig (json:"-", never serialized)
- Add "inprocess" transport type to config, validation, and connection pool
- Add createInProcessClient() using mcp-go client.NewInProcessClient()
- Add Kit.AddInProcessMCPServer() convenience method
- Add Options.InProcessMCPServers for init-time registration
- Export MCPServer type alias (= server.MCPServer) in pkg/kit/types.go
- Add 8 tests covering config, pool, tool manager, and edge cases
- Update SDK README, kit-sdk skill, and www docs
2026-04-15 16:29:07 +03:00
Ed Zynda 398e825df8 docs: update docs for recent features and API additions
- Add smart @ attachments (MIME detection, @mcp:server:uri syntax)
- Add MCP Prompts and Resources SDK APIs to skill and www docs
- Add $+ required variadic placeholder for prompt templates
- Add Ctrl+X e (external editor) and Ctrl+X s (steer) keyboard shortcuts
- Fix stale Ctrl+S references, now Ctrl+X s for mid-turn steering
- Add --frequency-penalty and --presence-penalty CLI flags
- Add per-model settings (modelSettings) to configuration docs
- Add NoExtensions, NoSkills, NoContextFiles, SessionManager,
  MCPTokenStoreFactory to SDK options docs
- Add bridge_demo.go to extension examples
- Add dynamic MCP servers, subagents to SDK overview
2026-04-15 16:02:49 +03:00
Ed Zynda 3c51c20be7 feat(mcp): handle embedded resources in prompt messages
- Extract all MCP content types in prompt expansion: ImageContent,
  AudioContent, EmbeddedResource (text and blob), and ResourceLink
- Add MCPFilePart type to carry decoded binary attachments through
  the tools → SDK → bridge → UI layers
- Inline text resources as fenced code blocks with URI annotation
- Decode image/audio/blob content from base64 into LLMFilePart
  attachments submitted via RunWithFiles
- Render ResourceLink as text annotation for the LLM
- Show attachment badges on user messages (e.g. '1 image(s) attached')
  matching the existing clipboard paste UI pattern
- Log warnings on base64 decode failures instead of silently dropping
2026-04-15 15:23:01 +03:00
Ed Zynda 25410af440 feat: add smart @ attachments with MIME detection and MCP resource support
Phase 1: Smart @ for local files
- ProcessFileAttachments now returns FileAttachmentResult with separate
  ProcessedText and FileParts fields instead of a plain string
- Binary files (images, audio, video, PDFs, etc.) detected via MIME type
  are extracted as multimodal FileParts instead of XML-wrapped text garbage
- detectMediaType() uses extension-based lookup then content sniffing
- isBinaryMediaType() classifies image/*, audio/*, video/*, and specific
  application types as binary
- @mcp:server:uri token format for referencing MCP resources in text
- All 4 submission paths (TUI submit, TUI steer, MCP prompt, CLI) updated
- App.RunOnceWithFiles/RunOnceResultWithFiles/RunOnceWithDisplayAndFiles
  added for non-interactive multimodal submission

Phase 2: MCP resources in @ autocomplete
- MCPToolManager gains loadServerResources(), GetResources(), ReadResource(),
  SubscribeResource(), UnsubscribeResource(), RefreshServerResources()
- MCPResource and MCPResourceContent types for resource metadata/content
- FileSuggestion extended with IsMCPResource, MCPServerName, MCPResourceURI
- InputComponent.SetMCPResourceProvider() wires resource suggestions into
  the @ popup alongside local files
- @ popup merges local file suggestions with MCP resource suggestions,
  sorted by fuzzy match score
- MCP resources display 'mcp:servername' in the popup description
- Selecting an MCP resource inserts @mcp:server:uri format
- ProcessFileAttachments resolves @mcp: tokens via MCPResourceReader callback
- Text resources are XML-wrapped as <resource>; binary resources become
  FileParts for multimodal submission
- Agent, Kit SDK, and cmd/root.go wired end-to-end

Phase 3: Resource subscriptions (foundation)
- SubscribeResource/UnsubscribeResource on MCPToolManager
- onResourcesChanged callback for live refresh (wired but not yet
  triggering UI refresh automatically)
- RefreshServerResources for manual resource list refresh
2026-04-15 13:01:36 +03:00
Ed Zynda 26c9f009f9 refactor: remove fantasy dependency name leaks from SDK surface
- Rename ExtensionToolsAsFantasy -> ExtensionToolsAsLLMTools
- Rename convertKitMessagesToFantasy -> convertToLLMMessages
- Delete GetFantasyProviders, ToFantasyMessages, FromFantasyMessage
- Replace direct fantasy type usage with kit.LLM* aliases in app tests
- Scrub fantasy references from godoc comments across pkg/kit and internal
2026-04-15 12:24:52 +03:00
Ed Zynda e068487ff7 style(ui): fix gofmt alignment in MCPPromptInfo struct 2026-04-15 11:50:33 +03:00
Ed Zynda 0ffb0ba788 refactor(tools): remove fantasy dependency from internal/tools
- Replace fantasy.AgentTool with plain MCPTool struct in MCPToolManager
- Move fantasy adapter from internal/tools to internal/agent as mcpAgentTool
- Add MCPToolManager.ExecuteTool() for framework-agnostic tool execution
- Remove dead fantasy.LanguageModel field from MCPConnectionPool
- Remove MCPToolManager.SetModel() (was only feeding the dead field)

internal/tools is now a pure MCP client library with no LLM framework
dependency. The fantasy-to-MCP bridging is confined to the agent layer
where it belongs.
2026-04-15 11:27:47 +03:00
Ed Zynda 65c6e9f797 refactor(models): decouple TUI progress from SDK dependency tree
- Remove direct internal/ui/progress import from internal/models/providers.go
- Add ProgressReaderFunc callback to ProviderConfig for dependency inversion
- Wire Bubble Tea progress reader via CLIOptions in cmd/root.go
- Add NewProgressReadCloser convenience wrapper in progress package
- SDK consumers (pkg/kit) no longer transitively pull bubbletea, lipgloss
  v2, or bubbles
- Update embedded_models.json from models.dev (110 providers, 4172 models)
2026-04-14 17:17:01 +03:00
Ed Zynda 68d798d2f4 feat(prompts): add $+ required variadic, skip code in placeholders
- Add internal/fences package for detecting markdown code regions
  (fenced blocks and inline code spans) with ReplaceOutside/StripCode
- SubstituteArgs, HasArgPlaceholders, RequiredArgs now skip $
  placeholders inside ``` fences and `inline` code spans
- ProcessFileAttachments skips @file tokens inside code regions
- Add $+ placeholder: expands like $@ but requires at least 1 argument
- Add RequiredArgs() method; expandPromptTemplate validates arg count
  and re-populates input on failure instead of submitting
- Update feature-request, file-issue, new-prompt to use $+
2026-04-14 13:22:10 +03:00
Ed Zynda eefd5565f8 feat(ui): populate input instead of auto-submitting prompts with args
- Add HasArgPlaceholders() method to PromptTemplate to detect , $@,
  $ARGUMENTS, etc. placeholders in template content
- Add HasArgs field to SlashCommand struct
- Set HasArgs when registering prompt templates as slash commands
- In fuzzy finder Enter handler, populate input with command + trailing
  space when HasArgs is true, letting the user type arguments naturally
- Fix potential index bug by capturing selectedCmd before resetting index
2026-04-14 12:46:12 +03:00
Ed Zynda 9d1b8a102e feat(ui): open external $EDITOR via ctrl+x e chord
- Add ctrl+x e leader key chord to open $VISUAL/$EDITOR in a temp file
  pre-populated with the current input text
- On save & quit, replace the input textarea with the edited content
- On error exit (e.g. :cq in vim), silently preserve original input
- Use charmbracelet/x/editor for editor command construction
- Use tea.ExecProcess to suspend/resume the TUI around the editor
- Update input hint text at all width breakpoints to show the shortcut
- Add ctrl+x e to /help output

Closes #5
2026-04-14 12:39:29 +03:00
Ed Zynda f57e045c69 feat(ui): highlight @file tokens in user messages with accent color
- Add highlightFileTokens() to style @file references with theme.Accent + bold
- Apply highlighting in UserBlock() after line wrapping, before herald rendering
- Support both unquoted (@path/to/file) and quoted (@"path with spaces") tokens
- Add tests for highlightFileTokens and UserBlock integration

Closes #6
2026-04-14 12:28:04 +03:00
Ed Zynda eb5da28a15 chore(deps): update all dependencies
- bump go directive to 1.26.2 (required by fantasy v0.17.2)
- fantasy v0.17.1 → v0.17.2
- bubbletea/v2 v2.0.2 → v2.0.5
- lipgloss/v2 v2.0.2 → v2.0.3
- mcp-go v0.47.1 → v0.48.0
- x/ansi v0.11.6 → v0.11.7
- x/term v0.41.0 → v0.42.0
- genai v1.52.1 → v1.54.0
- ultraviolet, x/crypto, x/net, x/text, x/exp, and other indirects
- allow MODEL env override in acp_smoke_test.py
2026-04-14 12:16:22 +03:00
Ed Zynda cd8e2a7654 feat(extensions): expand inline bash in editor for interactive mode
- Add editor interceptor via OnSessionStart so !{...} expansions
  appear in the user message block on screen
- Restrict OnInput handler to non-interactive sources (CLI, script,
  queue) to avoid double expansion
- Extract expand() helper and hoist regex to package level
- Update doc comments and examples
2026-04-14 11:56:41 +03:00
Ed Zynda 64da1caf41 docs(readme): add clickable links to extension examples
- Convert 31 extension example entries from plain code spans to
  Markdown links pointing to their source files
- Link the go-edit-lint.go and tool-logger_test.go references
2026-04-14 11:39:43 +03:00
Ed Zynda 7eaeafff8c fix(mcp): propagate OAuth config for runtime-added servers
- Store authHandler and tokenStoreFactory on Agent struct so
  AddMCPServer() can propagate them to new MCPToolManagers (#3)
- Add OAuthClientID, OAuthClientSecret, OAuthScopes fields to
  MCPServerConfig for servers without dynamic registration (#4)
- Pass OAuth fields from server config to transport OAuthConfig
  in both SSE and Streamable HTTP client creation paths
- Add GetAuthHandler() accessor to MCPToolManager
- Add tests for auth handler propagation and OAuth config parsing

Closes #3, closes #4
2026-04-11 15:24:47 +03:00
Ed Zynda 8ed8d23c73 docs(sdk): update kit-sdk skill with recent API additions
- Add NoSkills, NoExtensions, NoContextFiles options
- Add MCPTokenStoreFactory option and MCP OAuth types
- Add dynamic MCP server management (AddMCPServer, RemoveMCPServer,
  ListMCPServers, MCPServerStatus)
- Add per-model system prompts and generation parameters sections
- Add MCPToolsReady() to tool querying section
- Expand LLMUsage fields to include CacheCreationTokens/CacheReadTokens
- Update FinalUsage and ShouldCompact docs for cache-aware token counting
- Add MCP OAuth types to re-exported types reference
2026-04-11 12:09:51 +03:00
128 changed files with 13262 additions and 1430 deletions
+47
View File
@@ -0,0 +1,47 @@
---
description: Open a GitHub PR for the current branch using the repo's PR template
---
Open a GitHub pull request for the current branch, filling out the repository's PR template with a description grounded in the actual commits and diff.
## Steps
1. **Verify the branch is pushed**:
- `git status -sb` and `git log @{u}..HEAD --oneline 2>/dev/null` — if there is no upstream or unpushed commits, run `git push -u origin "$(git branch --show-current)"` first
- If the working tree is dirty, stop and tell the user to commit first (suggest `/commit-push`)
2. **Gather context**:
- `git log origin/main..HEAD --oneline` — list of commits going into the PR
- `git diff origin/main...HEAD --stat` then `git diff origin/main...HEAD` — read the actual changes
- Identify the linked issue (from commit messages, branch name, or extra user input: $@) — capture as `Fixes #N` if applicable
3. **Locate the PR template**:
- Check `.github/pull_request_template.md`, `.github/PULL_REQUEST_TEMPLATE.md`, or `docs/pull_request_template.md`
- If none exists, use a minimal `## Description` / `## Type of Change` / `## Checklist` structure
4. **Draft the PR body** by filling out the template:
- **Description**: 13 short paragraphs explaining *what* changed and *why*, grounded in the diff. Include a brief before/after example for new APIs when useful.
- **Fixes #N**: only if there is a real linked issue
- **Type of Change**: tick the single most accurate box with `[x]` (leave others as `[ ]`)
- **Checklist**: tick items that are genuinely true (style, self-review, tests added, docs updated)
- **Additional Information**: bullet list of added / modified files and any backward-compatibility notes
- Remove template sections explicitly marked "remove if not applicable" (e.g. MCP Spec Compliance) when they don't apply
5. **Write the body to a temp file**: `/tmp/pr-body-<branch-or-issue>.md` — never inline a long body via `--body`, always use `--body-file`
6. **Choose the title**: prefer the subject of the primary commit if it already follows Conventional Commits; otherwise craft one in the same style (`<type>(<scope>): <imperative summary>`, ≤72 chars)
7. **Create the PR**:
```
gh pr create \
--title "<title>" \
--body-file /tmp/pr-body-<...>.md \
--base main \
--head "$(git branch --show-current)"
```
Use the repo's actual default branch if it isn't `main` (`gh repo view --json defaultBranchRef -q .defaultBranchRef.name`)
8. **Report the PR URL** returned by `gh` and stop
## Guidelines
- Read the diff and commit messages — do **not** invent features that aren't in the code
- One PR per logical change; if the branch contains unrelated commits, surface that and ask before continuing
- Keep the description focused on reviewer-relevant information (what / why), not a replay of the diff
- Only check checklist boxes that are actually satisfied; leave the rest unchecked rather than lying
- If `gh` is not authenticated (`gh auth status` fails), stop and tell the user
$@
+1 -1
View File
@@ -16,7 +16,7 @@ This prompt uses the `feature_request` GitHub template which requires:
## Steps
1. **Understand the request** from `$@`
1. **Understand the request** from the user input: $@
- What capability is missing?
- What would the ideal behavior look like?
+1 -1
View File
@@ -16,7 +16,7 @@ This repository has structured issue templates. You MUST use the appropriate tem
## Steps
1. **Determine the issue type** from `$@`:
1. **Determine the issue type** from the user input: $@
- Bug → use `--template bug_report`
- Feature → use `--template feature_request`
- Documentation → use `--template documentation`
+61
View File
@@ -0,0 +1,61 @@
---
description: Implement the fix/feature/docs change requested by a GitHub issue
---
Resolve GitHub issue #$1 by reading it, classifying it, and producing the appropriate code or doc change. **Stop once the working tree contains the change** — committing, pushing, and opening a PR are handled by `/commit-push` and `/create-pr`.
## Steps
1. **Fetch the issue**:
- Run: gh issue view $1 --json number,title,body,labels,state,author,comments
- If the issue is closed, stop and ask the user whether to proceed
- Read the **entire** thread including comments — the latest comment often refines the ask
2. **Classify the issue** from labels, title prefix, and body content:
- `bug` / `fix:` → reproduce, then fix
- `enhancement` / `feature` / `feat:` → design, then implement
- `documentation` / `docs:` → locate and update docs
- `question` / `discussion` → answer in a comment, do **not** write code
- Anything else → ask the user how to proceed
3. **Create a working branch** off the default branch:
- `git checkout main && git pull --ff-only`
- Branch name: <type>/$1-<slug> (e.g. `fix/42-borderColor-ignored`, `feat/57-keyboard-clear`, `docs/63-widget-lifecycle`)
4. **Do the work** based on type:
### Bug (`bug` label / `fix:` title)
- Reproduce the failure first (write a failing test if feasible) — if you cannot reproduce, comment on the issue asking for clarification and stop
- Locate the root cause; do not patch symptoms
- Add or extend a regression test that fails before and passes after the fix
- Run `go test ./... -race` and `golangci-lint run`
### Feature (`enhancement` / `feature` label / `feat:` title)
- Re-read the motivation and proposed implementation in the issue body
- For large, ambiguous, or breaking changes, sketch the design in a comment on the issue and wait for sign-off before writing code
- Implement behind sensible defaults; add godoc on every exported symbol
- Add unit tests covering the new behaviour and edge cases
- Update `README.md` / `docs/` if the public surface changed
- Run `go test ./... -race` and `golangci-lint run`
### Documentation (`documentation` label / `docs:` title)
- Open the file/URL referenced in the issue's "Documentation Location"
- Apply the suggested improvement; verify code samples compile (`go build ./...`)
- No tests required, but run `golangci-lint run` if Go files were touched
5. **Report**:
- Branch name (`git branch --show-current`)
- Summary of files changed (`git status -s`) and the diff highlights
- Test/lint results (pass/fail with key output)
- Suggest the next step explicitly:
- `/commit-push` to commit with a Conventional Commit subject (the message should reference `(#$1)` and include `Fixes #$1` so merge auto-closes)
- then `/create-pr $1` to open the pull request
## Guidelines
- This prompt **stops at a clean working tree with the change applied** — do not run `git commit`, `git push`, or `gh pr create`
- If the issue is unclear, post a clarifying comment on the issue and stop; do not guess
- Keep the change scoped to the issue; surface unrelated cleanups separately
- For breaking changes or architecture shifts, propose the design on the issue first and wait for maintainer sign-off
- If the issue is a duplicate or already fixed on `main`, comment with the reference and stop
- Do not close the issue manually — the eventual PR's `Fixes #$1` handles that on merge
+47 -10
View File
@@ -16,28 +16,64 @@ It becomes a `/slug` slash command in the kit input box — typed as `/filename`
description: One-line description shown in autocomplete
---
Body text of the prompt. Use $@ for all user-supplied arguments,
$1 $2 etc. for positional arguments.
Body text of the prompt. Reference user-supplied arguments
with positional placeholders (see "Argument placeholders" below).
```
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
- **Arguments**: `$@` expands to everything the user typed after the slash command name;
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
- **Required args**: kit infers required positional args from the highest `$N` it finds *outside* backtick/tilde code fences — a stray `$2` in active prose means kit will refuse to run without 2 arguments
## Argument placeholders
kit performs shell-style substitution before sending the prompt to the model:
- `$1`, `$2`, … — positional arguments (1-indexed)
- `${1}`, `${2}`, … — same, brace form (use when followed by digits/letters: `${1}_suffix`)
- `$@` — all arguments joined by spaces (zero or more, optional)
- `$+` — all arguments, **at least one required**
- `$ARGUMENTS` / `${ARGUMENTS}` — alias for `$@`
- `${@:N}` — args from the Nth onwards (1-indexed, bash-style)
- `${@:N:L}``L` args starting from the Nth
### ⚠️ Critical: code fences and inline code preserve placeholders verbatim
Anything inside triple-backtick fences, `~~~` fences, or single-backtick `inline` code spans is **left untouched** so example code samples don't get corrupted. That means:
- An inline-coded `gh issue view $1` stays literal `$1` in the model's input ❌
- The same command without backticks: gh issue view $1 → expands to `gh issue view 42`
**Rule of thumb:** if you want a placeholder to substitute, keep it outside backticks and fences. If you want a literal `$1` in the output (e.g. teaching the user shell syntax), put it inside backticks.
### Workarounds for "I want it to look like code AND substitute"
1. **Drop the backticks** around just the placeholder portion — the rest can still read as a command line in prose
2. **Use a 4-space-indented code block** instead of a triple-backtick fence — kit only skips backtick/tilde fences, so indentation-style code blocks still get substitution:
git push -u origin "$(git branch --show-current)"
gh pr create --title "fix: ... (#$1)" --base main
3. **Bind once, reference loosely**: put `Issue: $1` at the top in prose, then leave the backticked examples literal — the model will substitute mentally
## Steps
1. **Understand the workflow** the user described in `$@` — ask a clarifying question if the intent is ambiguous
1. **Understand the workflow** the user described in $@ — ask a clarifying question if the intent is ambiguous
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
3. **Write the description**: one sentence, imperative, fits in autocomplete
4. **Draft the body**:
- Open with a single sentence stating the goal
4. **Decide on arguments**:
- No args needed → omit placeholders entirely
- One required value (issue number, PR url, file path) → use `$1`
- Free-form trailing context → end with a single `$@` line
- Multiple distinct values → use `$1`, `$2`, … and document each at the top
5. **Draft the body**:
- Open with a single sentence stating the goal, weaving in `$1`/`$@` where the value belongs
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
- Be specific: name commands, flags, and file paths where relevant
- End with `$@` on its own line if the user might want to pass context or a hint; omit if the prompt is self-contained
5. **Write the file** to `.kit/prompts/<slug>.md`
6. **Confirm** by showing the final file content and the slash command that activates it
- **Audit every backtick and code fence**: any `$N` or `$@` inside them will not expand — was that intentional? If not, apply one of the workarounds above
6. **Write the file** to `.kit/prompts/<slug>.md`
7. **Verify substitution** by mentally (or actually) replacing `$1`/`$@` with a sample value and confirming every reference resolves — and that the prompt's *own* example snippets don't accidentally bump the required-arg count (wrap illustrative `$N` examples in triple-backtick fences, not 4-space indentation, so `RequiredArgs()` ignores them)
8. **Confirm** by showing the final file content and the slash command that activates it (e.g. `/code-review 42`)
## Guidelines
@@ -45,3 +81,4 @@ $1 $2 etc. for positional arguments.
- Prefer concrete steps over vague instructions
- A prompt that does one thing well beats one that tries to cover every edge case
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
- When in doubt about substitution behaviour, write the file and run `/<slug> testvalue` once to confirm — wrong placement of backticks is the #1 failure mode
+52
View File
@@ -0,0 +1,52 @@
---
description: Audit and update project documentation (README and docs site) for a recent change
---
Review recent code changes, identify all documentation surfaces that should
mention them, and update each one — grounded in the actual diff, not guesses.
## Steps
1. **Identify the change**:
- If the user input ($@) names a commit / PR / branch / topic, use that as the focus
- Otherwise inspect `git log origin/main..HEAD --oneline` and `git diff origin/main...HEAD --stat` to discover what shipped on the current branch
- Read the actual diff (`git diff origin/main...HEAD`) — never document features that aren't in the code
2. **Inventory the doc surfaces**:
- `README.md` at the repo root
- Any docs site (commonly `www/`, `docs/`, `site/`) — list its pages and identify the one(s) most thematically related to the change
- Inline godoc / API reference comments on the new exported symbols
- `CHANGELOG.md` if the project keeps one
- Any `examples/` directory entries that demonstrate the affected area
3. **Audit each surface** with `grep`:
- Search for the names of related existing APIs (e.g. if you added `IterTools`, grep for `ListTools`) to find every page that already discusses the area
- Decide for each hit: does it need a cross-reference, a side-by-side comparison, or to stay untouched?
4. **Decide where new content lives**:
- Prefer extending an existing page over creating a new one
- For a docs site, place new sections near related content (check the page's `## Heading` outline first)
- Skip surfaces that genuinely don't apply (e.g. a server-focused README for a client-only change) and say so explicitly
5. **Draft the updates**:
- Lead with a one-sentence statement of what's new and why
- Show concrete code examples copied from real signatures — verify against the source files
- Include a comparison / "when to use which" table when adding an alternative to an existing API
- Note backwards-compatibility behaviour if relevant
6. **Verify the docs build** before committing:
- For vocs / docusaurus / mkdocs sites, run the local build command (e.g. `npx vocs build`, `mkdocs build`) and fix any MDX/markdown errors
- For godoc, run `go vet ./...` and `go doc <pkg> <Symbol>` to sanity-check rendering
7. **Report**:
- List every file changed and every file deliberately left alone (with a one-line reason)
- Suggest the next step (typically `/commit-push`) — do not auto-commit unless asked
## Guidelines
- Read the diff before writing anything — invented API names erode trust faster than missing docs
- One change per doc commit; keep doc updates separate from code changes when possible
- Match the existing voice and formatting of each surface (headings, code-fence languages, table styles)
- Prefer linking between pages over duplicating content
$@
-8
View File
@@ -1,8 +0,0 @@
{
"$schema": "https://opencode.ai/config.json",
"permission": {
"external_directory": {
"~/go/**": "deny"
}
}
}
-80
View File
@@ -1,80 +0,0 @@
# Autoscroll Fix - Final Summary
## Root Cause
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
### The Problem
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
2. The `Height()` method returns `0` when `cachedRender == ""`
3. `GotoBottom()` was calling:
```go
itemHeight := item.Height() // Returns 0 for reasoning
if itemHeight == 0 {
item.Render(s.width) // Renders but doesn't cache (reasoning)
itemHeight = item.Height() // Still returns 0!
}
```
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
## The Solution
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
```go
// OLD: item.Height() which checks cached render
itemHeight := item.Height()
if itemHeight == 0 {
item.Render(s.width)
itemHeight = item.Height() // Still might be 0!
}
// NEW: Calculate from rendered string directly
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
```
This works for **all** items regardless of whether they cache their render or not.
## Files Changed
### `internal/ui/scrolllist.go`
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
### `internal/ui/model.go`
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
## Testing Results
✅ **Test prompt**: "explore this repo"
**Before fix**:
- Autoscroll stopped after reasoning block completed
- Viewport stuck showing end of reasoning ("Thought for 203ms")
- Assistant response streamed off-screen below
**After fix**:
- Autoscroll works throughout reasoning block
- Autoscroll continues during reasoning → assistant transition
- Viewport stays at bottom showing latest assistant content
- Final position shows end of response (build commands section)
## Behavior Verified
1. ✅ Streaming text auto-scrolls to bottom
2. ✅ Works across reasoning → assistant transition
3. ✅ Manual scroll up (PgUp) disables autoscroll
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
5. ✅ Accurate positioning with no offset errors
## Performance Note
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
- `Render()` is already optimized with caching for non-reasoning items
- `GotoBottom()` is only called during content updates (not every frame)
- Reasoning blocks need to render anyway for live duration updates
- This matches iteratr's approach of ensuring items are rendered before height calculations
+181 -44
View File
@@ -18,7 +18,8 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
## Features
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, subagent - no MCP overhead
- **Built-in Core Tools**: bash (with interactive sudo password prompt), read, write, edit, grep, find, ls, subagent - no MCP overhead
- **Smart @ Attachments**: Binary files auto-detected via MIME type, MCP resources via `@mcp:server:uri`
- **MCP Integration**: Connect external MCP servers for expanded capabilities
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
@@ -28,7 +29,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
- **Session Management**: Tree-based conversation history with branching support
- **Non-Interactive Mode**: Script-friendly positional args with JSON output
- **ACP Server**: Run Kit as an [Agent Client Protocol](https://agentclientprotocol.com) agent over stdio
- **Go SDK**: Embed Kit in your own applications
- **Go SDK**: Embed Kit in your own applications with full agent lifecycle events (30+ event types) and behavior-modifying hooks
## Installation
@@ -125,8 +126,13 @@ model: anthropic/claude-sonnet-latest
max-tokens: 4096
temperature: 0.7
stream: true
thinking-level: off # off, none, minimal, low, medium, high
```
All of the above keys can also be set programmatically via the SDK
(`kit.Options.MaxTokens`, `Options.Temperature`, `Options.ThinkingLevel`, etc.)
without touching config files — see [SDK options](#with-options).
### Environment Variables
```bash
@@ -151,6 +157,16 @@ mcpServers:
search:
type: remote
url: "https://mcp.example.com/search"
pubmed:
type: remote
url: "https://pubmed.mcp.example.com"
noOAuth: true # skip OAuth for public servers that don't require auth
builds:
type: remote
url: "https://builds.mcp.example.com"
tasksMode: always # async task execution — see MCP Tasks below
```
## CLI Reference
@@ -186,12 +202,14 @@ mcpServers:
--no-prompt-templates Disable prompt template loading
# Generation parameters
--max-tokens Maximum tokens in response (default: 4096)
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
--temperature Randomness 0.0-1.0 (default: 0.7)
--top-p Nucleus sampling 0.0-1.0 (default: 0.95)
--top-k Limit top K tokens (default: 40)
--stop-sequences Custom stop sequences (comma-separated)
--thinking-level Extended thinking level: off, minimal, low, medium, high (default: off)
--frequency-penalty Penalize frequent tokens 0.0-2.0 (default: 0.0)
--presence-penalty Penalize present tokens 0.0-2.0 (default: 0.0)
--thinking-level Extended thinking level: off, none, minimal, low, medium, high (default: off)
# System
--config Config file path (default: ~/.kit.yml)
@@ -203,9 +221,10 @@ mcpServers:
```bash
# Authentication (for OAuth-enabled providers)
kit auth login [provider] # Start OAuth flow (e.g., anthropic)
kit auth logout [provider] # Remove credentials for provider
kit auth status # Check authentication status
kit auth login [provider] # Start OAuth flow (e.g., anthropic)
kit auth login [provider] --set-default # Set provider's default model as system default
kit auth logout [provider] # Remove credentials for provider
kit auth status # Check authentication status
# Model database
kit models [provider] # List available models (optionally filter by provider)
@@ -287,7 +306,7 @@ kit -e examples/extensions/minimal.go
### Extension Capabilities
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
**Custom Components**:
- **Tools**: Add new tools the LLM can invoke
@@ -317,39 +336,40 @@ kit -e examples/extensions/minimal.go
See the `examples/extensions/` directory:
- `minimal.go` - Clean UI with custom footer
- `auto-commit.go` - Auto-commit on shutdown
- `bookmark.go` - Bookmark conversations
- `branded-output.go` - Branded output rendering
- `compact-notify.go` - Notification on compaction
- `confirm-destructive.go` - Confirm destructive operations
- `context-inject.go` - Inject context into conversations
- `conversation-manager.go` - **NEW** Tree navigation, branch summarization, and fresh context loops
- `custom-editor-demo.go` - Vim-like modal editor
- `dev-reload.go` - Development live-reload
- `header-footer-demo.go` - Custom headers and footers
- `inline-bash.go` - Inline bash execution
- `interactive-shell.go` - Interactive shell integration
- `kit-kit.go` - Kit-in-Kit (sub-agent spawning)
- `lsp-diagnostics.go` - LSP diagnostic integration
- `notify.go` - Desktop notifications
- `overlay-demo.go` - Modal dialogs
- `permission-gate.go` - Permission gating for tools
- `pirate.go` - Pirate-themed personality
- `plan-mode.go` - Read-only planning mode
- `project-rules.go` - Project-specific rules
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
- `prompt-templates.go` - **NEW** Frontmatter-driven templates with model switching and skill injection
- `protected-paths.go` - Path protection for sensitive files
- `subagent-widget.go` - Multi-agent orchestration with status widget
- `subagent-test.go` - Subagent testing utilities
- `summarize.go` - Conversation summarization
- `tool-logger.go` - Log all tool calls
- `neon-theme.go` - Custom theme registration and switching
- `tool-renderer-demo.go` - Custom tool call rendering
- `widget-status.go` - Persistent status widgets
- [`minimal.go`](examples/extensions/minimal.go) - Clean UI with custom footer
- [`auto-commit.go`](examples/extensions/auto-commit.go) - Auto-commit on shutdown
- [`bookmark.go`](examples/extensions/bookmark.go) - Bookmark conversations
- [`branded-output.go`](examples/extensions/branded-output.go) - Branded output rendering
- [`bridge-demo.go`](examples/extensions/bridge_demo.go) - Bridged SDK API demo (tree navigation, skills, templates, model resolution)
- [`compact-notify.go`](examples/extensions/compact-notify.go) - Notification on compaction
- [`confirm-destructive.go`](examples/extensions/confirm-destructive.go) - Confirm destructive operations
- [`context-inject.go`](examples/extensions/context-inject.go) - Inject context into conversations
- [`conversation-manager.go`](examples/extensions/conversation-manager.go) - **NEW** Tree navigation, branch summarization, and fresh context loops
- [`custom-editor-demo.go`](examples/extensions/custom-editor-demo.go) - Vim-like modal editor
- [`dev-reload.go`](examples/extensions/dev-reload.go) - Development live-reload
- [`header-footer-demo.go`](examples/extensions/header-footer-demo.go) - Custom headers and footers
- [`inline-bash.go`](examples/extensions/inline-bash.go) - Inline bash execution
- [`interactive-shell.go`](examples/extensions/interactive-shell.go) - Interactive shell integration
- [`kit-kit.go`](examples/extensions/kit-kit.go) - Kit-in-Kit (sub-agent spawning)
- [`lsp-diagnostics.go`](examples/extensions/lsp-diagnostics.go) - LSP diagnostic integration
- [`notify.go`](examples/extensions/notify.go) - Desktop notifications
- [`overlay-demo.go`](examples/extensions/overlay-demo.go) - Modal dialogs
- [`permission-gate.go`](examples/extensions/permission-gate.go) - Permission gating for tools
- [`pirate.go`](examples/extensions/pirate.go) - Pirate-themed personality
- [`plan-mode.go`](examples/extensions/plan-mode.go) - Read-only planning mode
- [`project-rules.go`](examples/extensions/project-rules.go) - Project-specific rules
- [`prompt-demo.go`](examples/extensions/prompt-demo.go) - Interactive prompts (select/confirm/input)
- [`prompt-templates.go`](examples/extensions/prompt-templates.go) - **NEW** Frontmatter-driven templates with model switching and skill injection
- [`protected-paths.go`](examples/extensions/protected-paths.go) - Path protection for sensitive files
- [`subagent-widget.go`](examples/extensions/subagent-widget.go) - Multi-agent orchestration with status widget
- [`subagent-test.go`](examples/extensions/subagent-test.go) - Subagent testing utilities
- [`summarize.go`](examples/extensions/summarize.go) - Conversation summarization
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
Also see `.kit/extensions/go-edit-lint.go` (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
### Loading Extensions
@@ -406,7 +426,7 @@ func TestMyExtension(t *testing.T) {
- `AssertPrinted()`, `AssertPrintedContains()` — Verify output
- `AssertToolRegistered()`, `AssertCommandRegistered()` — Verify registration
See `examples/extensions/tool-logger_test.go` for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
See [`examples/extensions/tool-logger_test.go`](examples/extensions/tool-logger_test.go) for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
### Prompt Templates
@@ -428,10 +448,13 @@ Focus on $1 specifically.
**Argument placeholders:**
- `$1`, `$2`, etc. — Individual arguments
- `$@` or `$ARGUMENTS` — All arguments
- `$@` or `$ARGUMENTS` — All arguments (zero or more)
- `$+` — All arguments (one or more required; error if none given)
- `${@:2}` — Arguments from position 2 onwards
- `${@:1:3}` — 3 arguments starting at position 1
Placeholders inside fenced code blocks (```) and inline code spans are ignored.
Disable templates with `--no-prompt-templates` or load a specific template with `--prompt-template <name>`.
## Session Management
@@ -480,6 +503,15 @@ During an interactive session, use these slash commands:
| `/fork` | Fork to new session from an earlier message |
| `/new` | Start a fresh session |
### Keyboard Shortcuts
| Shortcut | Description |
|----------|-------------|
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
| `` / `` | Navigate prompt history |
## Go SDK
Embed Kit in your Go applications:
@@ -525,6 +557,20 @@ host, err := kit.New(ctx, &kit.Options{
Streaming: true,
Quiet: true,
// Generation parameters (override env/config/per-model defaults)
MaxTokens: 16384, // 0 = auto-resolve (env → config → per-model → 8192 floor)
ThinkingLevel: "medium", // "off", "none", "minimal", "low", "medium", "high"
Temperature: ptr(float32(0.2)), // pointer so 0.0 != unset; nil = provider default
TopP: nil, // nil = leave provider/per-model default
TopK: nil,
FrequencyPenalty: nil,
PresencePenalty: nil,
// Provider configuration (override env/config without reaching into viper)
ProviderAPIKey: "sk-...", // "" = use config / provider env var
ProviderURL: "https://proxy.internal/v1", // "" = provider default
TLSSkipVerify: false, // only takes effect when true
// Session options
SessionPath: "./session.jsonl", // Open specific session
Continue: true, // Resume most recent session
@@ -545,6 +591,76 @@ host, err := kit.New(ctx, &kit.Options{
})
```
**Generation & provider fields** (added in v0.55+) let SDK consumers configure
Kit entirely in-code without `viper.Set()` workarounds or shipping a `.kit.yml`.
Precedence is `Options` > `KIT_*` env vars > `.kit.yml` > per-model defaults
(`modelSettings` / `customModels`) > provider-level defaults. Sampling params
are pointer types so explicit `0.0` is distinguishable from "leave alone"; a
non-zero `MaxTokens` suppresses automatic right-sizing the same way `--max-tokens`
does on the CLI.
### MCP OAuth (remote MCP servers)
When a remote MCP server returns 401, Kit runs the full OAuth flow (dynamic
client registration → PKCE → token exchange → persistence) but delegates the
user-facing step — showing the authorization URL and receiving the callback —
to an `MCPAuthHandler` that you pass explicitly via `Options.MCPAuthHandler`.
If nil, OAuth is disabled and the authorization-required error surfaces to the
caller; the SDK never auto-opens a browser or binds a localhost port.
```go
// CLI/TUI apps: opens the system browser + prints status to stderr.
authHandler, _ := kit.NewCLIMCPAuthHandler()
defer authHandler.Close()
host, _ := kit.New(ctx, &kit.Options{
MCPAuthHandler: authHandler,
})
// Custom UX: reuse the SDK's port + callback server, supply your own
// presentation via OnAuthURL (TUI modal, QR code, web redirect, etc.).
// h, _ := kit.NewDefaultMCPAuthHandler()
// h.OnAuthURL = func(server, authURL string) { myUI.Show(server, authURL) }
//
// Full control (web apps, daemons): implement kit.MCPAuthHandler yourself —
// no localhost binding, no side effects.
```
Tokens are persisted to `$XDG_CONFIG_HOME/.kit/mcp_tokens.json` by default; swap
in a custom `MCPTokenStoreFactory` for encrypted, DB-backed, or in-memory
storage. See the [SDK options docs](/sdk/options#mcp-oauth-authorization) for
the full matrix.
### MCP Tasks (long-running tools)
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
during `initialize`, so cooperating MCP servers can respond to `tools/call`
with a `taskId` instead of blocking the connection. Kit then polls
`tasks/get` / `tasks/result` until the task reaches a terminal state, and
best-effort `tasks/cancel`s on context cancellation.
Defaults are safe — a server that doesn't advertise task capability runs
synchronously, exactly as before. Opt in per server via `tasksMode` in
`.kit.yml` (`auto` | `never` | `always`) or programmatically through the SDK:
```go
host, _ := kit.New(ctx, &kit.Options{
MCPTaskMode: map[string]kit.MCPTaskMode{
"build-server": kit.MCPTaskModeAlways,
},
MCPTaskTimeout: 15 * time.Minute,
MCPTaskProgress: func(p kit.MCPTaskProgress) {
log.Printf("%s: %s", p.TaskID, p.Status)
},
})
tasks, _ := host.ListMCPTasks(ctx, "build-server")
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
```
See the [configuration docs](/configuration#mcp-tasks-long-running-tools) and
[SDK options → MCP Tasks](/sdk/options#mcp-tasks) for the full surface.
### Custom Tools
Create custom tools with automatic schema generation — no external dependencies needed:
@@ -565,7 +681,28 @@ host, _ := kit.New(ctx, &kit.Options{
})
```
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
Use `kit.NewParallelTool` for tools safe to run concurrently. Binary data (images, audio, etc.) in `ToolOutput.Data` is automatically forwarded to the LLM when `MediaType` is set. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
#### Return Helpers
| Helper | Description |
| --- | --- |
| `kit.TextResult(content)` | Successful text result |
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
| `kit.ImageResult(content, data, mediaType)` | Image result with binary data (e.g. `"image/png"`) |
| `kit.MediaResult(content, data, mediaType)` | Non-image media result (e.g. `"audio/mpeg"`) |
#### ToolOutput Fields
```go
kit.ToolOutput{
Content: "result text", // text returned to the LLM
IsError: false, // true = LLM sees this as an error
Data: pngBytes, // optional binary data (images, audio)
MediaType: "image/png", // MIME type for binary Data
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
}
```
### With Callbacks
@@ -582,7 +719,7 @@ unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
})
defer unsub2()
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
unsub3 := host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
print(e.Chunk)
})
defer unsub3()
+64 -4
View File
@@ -11,6 +11,7 @@ import (
"charm.land/huh/v2"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/ui"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/spf13/cobra"
)
@@ -54,9 +55,13 @@ Available providers:
- anthropic: Anthropic Claude API (OAuth)
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
Example:
Flags:
--set-default Set this provider's default model as the system default
Examples:
kit auth login anthropic
kit auth login openai`,
kit auth login openai
kit auth login openai --set-default`,
Args: cobra.ExactArgs(1),
RunE: runAuthLogin,
}
@@ -99,10 +104,43 @@ Example:
RunE: runAuthStatus,
}
var (
loginSetDefault bool
)
// defaultModels maps providers to their recommended default models.
// These are used when --set-default flag is passed to auth login.
var defaultModels = map[string]string{
"anthropic": "anthropic/claude-sonnet-4-5-20250929",
"openai": "openai/gpt-5.4",
}
// setDefaultModelIfRequested sets the default model for the given provider
// if the --set-default flag was provided.
func setDefaultModelIfRequested(provider string) error {
if !loginSetDefault {
return nil
}
model, ok := defaultModels[provider]
if !ok {
return fmt.Errorf("no default model configured for provider: %s", provider)
}
if err := ui.SaveModelPreference(model); err != nil {
return fmt.Errorf("failed to save model preference: %w", err)
}
fmt.Printf("\n✓ Set default model to: %s\n", model)
return nil
}
func init() {
authCmd.AddCommand(authLoginCmd)
authCmd.AddCommand(authLogoutCmd)
authCmd.AddCommand(authStatusCmd)
authLoginCmd.Flags().BoolVar(&loginSetDefault, "set-default", false, "Set this provider's default model as the system default after login")
}
func runAuthLogin(cmd *cobra.Command, args []string) error {
@@ -288,6 +326,17 @@ func loginAnthropic() error {
fmt.Println("\n🎉 Your OAuth credentials will now be used for Anthropic API calls.")
fmt.Println("💡 You can check your authentication status with: kit auth status")
// Set default model if requested
if err := setDefaultModelIfRequested("anthropic"); err != nil {
return err
}
// Remind users how to set this as default if they didn't use --set-default
if !loginSetDefault {
fmt.Println("\n💡 To set Anthropic as your default model, run:")
fmt.Println(" kit auth login anthropic --set-default")
}
return nil
}
@@ -454,6 +503,17 @@ func loginOpenAI() error {
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
fmt.Println("💡 You can check your authentication status with: kit auth status")
// Set default model if requested
if err := setDefaultModelIfRequested("openai"); err != nil {
return err
}
// Remind users how to set this as default if they didn't use --set-default
if !loginSetDefault {
fmt.Println("\n💡 To set OpenAI as your default model, run:")
fmt.Println(" kit auth login openai --set-default")
}
return nil
}
@@ -504,13 +564,13 @@ func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
}
// Return success page
w.Header().Set("Content-Type", "text/html")
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
<html>
<head><title>Authentication Successful</title></head>
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
<h1> Authentication Successful</h1>
<h1>&#10003; Authentication Successful</h1>
<p>You can close this window and return to the terminal.</p>
</body>
</html>`)
+140 -30
View File
@@ -19,6 +19,7 @@ import (
"github.com/mark3labs/kit/internal/prompts"
"github.com/mark3labs/kit/internal/ui"
"github.com/mark3labs/kit/internal/ui/commands"
"github.com/mark3labs/kit/internal/ui/progress"
"github.com/mark3labs/kit/internal/watcher"
kit "github.com/mark3labs/kit/pkg/kit"
"github.com/spf13/cobra"
@@ -33,13 +34,18 @@ var (
providerURL string
providerAPIKey string
debugMode bool
positionalPrompt string // set by processPositionalArgs from CLI positional args
quietFlag bool
jsonFlag bool
noExitFlag bool
maxSteps int
streamFlag bool // Enable streaming output
autoCompactFlag bool // Enable auto-compaction near context limit
positionalPrompt string // set by processPositionalArgs from CLI positional args
positionalFiles []ui.FilePart // binary @file parts from processPositionalArgs
// MCP resource callbacks, set in runNormalMode, consumed by runInteractiveModeBubbleTea.
mcpGetResources func() []ui.FileSuggestion
mcpResourceReader ui.MCPResourceReader
quietFlag bool
jsonFlag bool
noExitFlag bool
maxSteps int
streamFlag bool // Enable streaming output
autoCompactFlag bool // Enable auto-compaction near context limit
// Session management
sessionPath string
@@ -291,14 +297,14 @@ func init() {
flags.BoolVar(&noPromptTemplates, "no-prompt-templates", false, "disable prompt template discovery")
// Model generation parameters
flags.IntVar(&maxTokens, "max-tokens", 4096, "maximum number of tokens in the response")
flags.IntVar(&maxTokens, "max-tokens", 8192, "maximum number of output tokens per response (auto-raised up to 32768 for models with higher known output limits; see internal/models/embedded_models.json)")
flags.Float32Var(&temperature, "temperature", 0.7, "controls randomness in responses (0.0-1.0)")
flags.Float32Var(&topP, "top-p", 0.95, "controls diversity via nucleus sampling (0.0-1.0)")
flags.Int32Var(&topK, "top-k", 40, "controls diversity by limiting top K tokens to sample from")
flags.Float32Var(&frequencyPenalty, "frequency-penalty", 0.0, "penalizes tokens based on frequency of appearance (0.0-2.0)")
flags.Float32Var(&presencePenalty, "presence-penalty", 0.0, "penalizes tokens based on whether they have appeared (0.0-2.0)")
flags.StringSliceVar(&stopSequences, "stop-sequences", nil, "custom stop sequences (comma-separated)")
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, minimal, low, medium, high")
flags.StringVar(&thinkingLevel, "thinking-level", "off", "extended thinking level: off, none, minimal, low, medium, high")
// Ollama-specific parameters
flags.Int32Var(&numGPU, "num-gpu-layers", -1, "number of model layers to offload to GPU for Ollama models (-1 for auto-detect)")
@@ -338,12 +344,14 @@ func init() {
}
// processPositionalArgs separates positional CLI arguments into @file
// attachments and prompt text. File content is read and prepended to
// positionalPrompt so the agent receives it. Positional args are the primary
// way to run non-interactive mode:
// attachments and prompt text. Text file content is read and prepended to
// positionalPrompt; binary files (images, audio) are stored in positionalFiles
// for multimodal submission. Positional args are the primary way to run
// non-interactive mode:
//
// kit "Explain this codebase"
// kit @code.ts @test.ts "Review these files"
// kit @screenshot.png "What's in this image?"
func processPositionalArgs(args []string) {
cwd, err := os.Getwd()
if err != nil {
@@ -362,14 +370,17 @@ func processPositionalArgs(args []string) {
}
// Build file content prefix from @file arguments.
// Text files are XML-wrapped inline; binary files become multimodal parts.
var fileContent strings.Builder
for _, token := range fileTokens {
expanded := ui.ProcessFileAttachments(token, cwd)
if expanded != token {
// File was resolved — add it.
fileContent.WriteString(expanded)
result := ui.ProcessFileAttachments(token, cwd)
if result.ProcessedText != token {
// Text file was resolved — add it.
fileContent.WriteString(result.ProcessedText)
fileContent.WriteString("\n\n")
}
// Collect binary file parts for multimodal submission.
positionalFiles = append(positionalFiles, result.FileParts...)
}
// Combine: positional prompt text is appended to any existing --prompt
@@ -753,10 +764,11 @@ func runNormalMode(ctx context.Context) error {
}
},
CLI: &kit.CLIOptions{
MCPConfig: mcpConfig,
ShowSpinner: true,
SpinnerFunc: spinnerFunc,
UseBufferedLogger: true,
MCPConfig: mcpConfig,
ShowSpinner: true,
SpinnerFunc: spinnerFunc,
UseBufferedLogger: true,
ProgressReaderFunc: progress.NewProgressReadCloser,
},
}
if resumeFlag {
@@ -1704,6 +1716,81 @@ func runNormalMode(ctx context.Context) error {
return kitInstance.GetMCPToolCount()
}
// Build MCP prompt provider callbacks for the TUI.
// Convert kit.MCPPrompt → ui.MCPPromptInfo for the UI layer.
convertMCPPromptsForUI := func() []ui.MCPPromptInfo {
prompts := kitInstance.ListMCPPrompts()
if len(prompts) == 0 {
return nil
}
result := make([]ui.MCPPromptInfo, len(prompts))
for i, p := range prompts {
args := make([]ui.MCPPromptArgInfo, len(p.Arguments))
for j, a := range p.Arguments {
args[j] = ui.MCPPromptArgInfo{
Name: a.Name,
Description: a.Description,
Required: a.Required,
}
}
result[i] = ui.MCPPromptInfo{
Name: p.Name,
Description: p.Description,
Arguments: args,
ServerName: p.ServerName,
}
}
return result
}
mcpPrompts := convertMCPPromptsForUI()
getMCPPrompts := func() []ui.MCPPromptInfo {
return convertMCPPromptsForUI()
}
expandMCPPrompt := func(serverName, promptName string, args map[string]string) (*ui.MCPPromptExpandResult, error) {
result, err := kitInstance.GetMCPPrompt(context.Background(), serverName, promptName, args)
if err != nil {
return nil, err
}
msgs := make([]ui.MCPPromptMessageInfo, len(result.Messages))
for i, m := range result.Messages {
msgs[i] = ui.MCPPromptMessageInfo{
Role: m.Role,
Content: m.Content,
FileParts: m.FileParts,
}
}
return &ui.MCPPromptExpandResult{Messages: msgs}, nil
}
// MCP resource callbacks for @ autocomplete and submit-time resolution.
getMCPResources := func() []ui.FileSuggestion {
resources := kitInstance.ListMCPResources()
suggestions := make([]ui.FileSuggestion, len(resources))
for i, r := range resources {
suggestions[i] = ui.FileSuggestion{
RelPath: r.Name,
IsMCPResource: true,
MCPServerName: r.ServerName,
MCPResourceURI: r.URI,
MCPMIMEType: r.MIMEType,
Score: 100, // default score, filtered later
}
}
return suggestions
}
mcpResourceReaderFn := func(serverName, uri string) (string, []byte, string, bool, error) {
content, err := kitInstance.ReadMCPResource(context.Background(), serverName, uri)
if err != nil {
return "", nil, "", false, err
}
return content.Text, content.BlobData, content.MIMEType, content.IsBlob, nil
}
// Store MCP resource callbacks at package level for consumption by
// runInteractiveModeBubbleTea and runNonInteractiveModeApp.
mcpGetResources = getMCPResources
mcpResourceReader = mcpResourceReaderFn
// Start a goroutine that waits for background MCP tool loading to
// complete and notifies the TUI so it can refresh tool names and counts.
if len(mcpConfig.MCPServers) > 0 {
@@ -1840,7 +1927,7 @@ func runNormalMode(ctx context.Context) error {
// Check if running in non-interactive mode
if positionalPrompt != "" {
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
return runNonInteractiveModeApp(ctx, appInstance, cli, positionalPrompt, quietFlag, jsonFlag, noExitFlag, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI)
}
// Quiet mode is not allowed in interactive mode
@@ -1848,7 +1935,7 @@ func runNormalMode(ctx context.Context) error {
return fmt.Errorf("--quiet requires a prompt")
}
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, parsedProvider, kitInstance.GetLoadingMessage(), serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModelForUI, emitModelChangeForUI, kitInstance.IsReasoningModel(), kitInstance.GetThinkingLevel(), setThinkingLevelForUI, switchSessionForUI, reloadExtensionsForUI, startupExtensionMessages)
}
// runNonInteractiveModeApp executes a single prompt via the app layer and exits,
@@ -1861,15 +1948,33 @@ func runNormalMode(ctx context.Context) error {
//
// When --no-exit is set, after the prompt completes the interactive BubbleTea
// TUI is started so the user can continue the conversation.
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui.CLI, prompt string, quiet, jsonOutput, noExit bool, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error) error {
// Expand @file references in the prompt before sending to the agent.
// Text files are XML-inlined; binary files are extracted as multimodal parts.
var fileParts []kit.LLMFilePart
if cwd, err := os.Getwd(); err == nil {
prompt = ui.ProcessFileAttachments(prompt, cwd)
result := ui.ProcessFileAttachments(prompt, cwd, mcpResourceReader)
prompt = result.ProcessedText
for _, fp := range result.FileParts {
fileParts = append(fileParts, kit.LLMFilePart{
Filename: fp.Filename,
Data: fp.Data,
MediaType: fp.MediaType,
})
}
}
// Also include binary files from processPositionalArgs (CLI @file args).
for _, fp := range positionalFiles {
fileParts = append(fileParts, kit.LLMFilePart{
Filename: fp.Filename,
Data: fp.Data,
MediaType: fp.MediaType,
})
}
if jsonOutput {
// JSON mode: no intermediate display, structured JSON output.
result, err := appInstance.RunOnceResult(ctx, prompt)
result, err := appInstance.RunOnceResultWithFiles(ctx, prompt, fileParts)
if err != nil {
writeJSONError(err)
return err
@@ -1881,7 +1986,7 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
fmt.Println(string(data))
} else if quiet {
// Quiet mode: no intermediate display, just print final response.
if err := appInstance.RunOnce(ctx, prompt); err != nil {
if err := appInstance.RunOnceWithFiles(ctx, prompt, fileParts); err != nil {
return err
}
} else if cli != nil {
@@ -1890,21 +1995,21 @@ func runNonInteractiveModeApp(ctx context.Context, appInstance *app.App, cli *ui
// Route events through the shared CLI event handler.
eventHandler := ui.NewCLIEventHandler(cli, modelName)
err := appInstance.RunOnceWithDisplay(ctx, prompt, eventHandler.Handle)
err := appInstance.RunOnceWithDisplayAndFiles(ctx, prompt, eventHandler.Handle, fileParts)
eventHandler.Cleanup()
if err != nil {
return err
}
} else {
// No CLI available (shouldn't happen in non-quiet mode, but be safe).
if err := appInstance.RunOnce(ctx, prompt); err != nil {
if err := appInstance.RunOnceWithFiles(ctx, prompt, fileParts); err != nil {
return err
}
}
// If --no-exit was requested, hand off to the interactive TUI.
if noExit {
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
return runInteractiveModeBubbleTea(ctx, appInstance, modelName, providerName, loadingMessage, serverNames, toolNames, mcpToolCount, extensionToolCount, usageTracker, extCommands, promptTemplates, contextPaths, skillItems, getPromptTemplates, getSkillItems, getToolNames, getMCPToolCount, mcpPrompts, getMCPPrompts, expandMCPPrompt, getWidgets, getHeader, getFooter, getToolRenderer, getEditorInterceptor, getUIVisibility, getStatusBarEntries, emitBeforeFork, emitBeforeSessionSwitch, getGlobalShortcuts, getExtensionCommands, setModel, emitModelChange, isReasoningModel, thinkingLevel, setThinkingLevel, switchSession, reloadExtensions, nil)
}
return nil
@@ -2002,7 +2107,7 @@ func writeJSONError(err error) {
// 4. Calls program.Run() which blocks until the user quits (Ctrl+C or /quit).
//
// SetupCLI is not used for interactive mode; the TUI (AppModel) handles its own rendering.
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelName, providerName, loadingMessage string, serverNames, toolNames []string, mcpToolCount, extensionToolCount int, usageTracker *ui.UsageTracker, extCommands []commands.ExtensionCommand, promptTemplates []*prompts.PromptTemplate, contextPaths []string, skillItems []ui.SkillItem, getPromptTemplates func() []*prompts.PromptTemplate, getSkillItems func() []ui.SkillItem, getToolNames func() []string, getMCPToolCount func() int, mcpPrompts []ui.MCPPromptInfo, getMCPPrompts func() []ui.MCPPromptInfo, expandMCPPrompt func(string, string, map[string]string) (*ui.MCPPromptExpandResult, error), getWidgets func(string) []ui.WidgetData, getHeader, getFooter func() *ui.WidgetData, getToolRenderer func(string) *ui.ToolRendererData, getEditorInterceptor func() *ui.EditorInterceptor, getUIVisibility func() *ui.UIVisibility, getStatusBarEntries func() []ui.StatusBarEntryData, emitBeforeFork func(string, bool, string) (bool, string), emitBeforeSessionSwitch func(string) (bool, string), getGlobalShortcuts func() map[string]func(), getExtensionCommands func() []commands.ExtensionCommand, setModel func(string) error, emitModelChange func(string, string, string), isReasoningModel bool, thinkingLevel string, setThinkingLevel func(string) error, switchSession func(string) error, reloadExtensions func() error, startupExtensionMessages []string) error {
// Redirect all log output (stdlib and charm) to a file so that log
// messages don't write to stderr and corrupt the TUI. Bubble Tea
// captures stdout for rendering; any stray stderr output from
@@ -2041,6 +2146,9 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
ExtensionCommands: extCommands,
PromptTemplates: promptTemplates,
GetPromptTemplates: getPromptTemplates,
MCPPrompts: mcpPrompts,
GetMCPPrompts: getMCPPrompts,
ExpandMCPPrompt: expandMCPPrompt,
ContextPaths: contextPaths,
SkillItems: skillItems,
GetSkillItems: getSkillItems,
@@ -2064,6 +2172,8 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
SwitchSession: switchSession,
ReloadExtensions: reloadExtensions,
ShowSessionPicker: resumeFlag,
GetMCPResources: mcpGetResources,
MCPResourceReader: mcpResourceReader,
})
program := tea.NewProgram(appModel)
+57 -19
View File
@@ -10,13 +10,21 @@ import (
"kit/ext"
)
// re matches !{...} with non-greedy content.
var re = regexp.MustCompile(`!\{([^}]+)\}`)
// Init expands inline bash expressions in user prompts before they reach the
// LLM. Text like !{git branch --show-current} is replaced with the command's
// stdout.
// LLM. Text like !{git rev-parse --abbrev-ref HEAD} is replaced with the
// command's stdout.
//
// In interactive mode the expansion happens at submit time via an editor
// interceptor, so the expanded text is also visible in the user message
// block on screen. In non-interactive mode (CLI, script, queue) the
// expansion happens via OnInput transform.
//
// Examples:
//
// "Fix the tests on !{git branch --show-current}"
// "Fix the tests on !{git rev-parse --abbrev-ref HEAD}"
// → "Fix the tests on main"
//
// "The current directory is !{pwd}"
@@ -24,29 +32,59 @@ import (
//
// Usage: kit -e examples/extensions/inline-bash.go
func Init(api ext.API) {
// Matches !{...} with non-greedy content.
re := regexp.MustCompile(`!\{([^}]+)\}`)
// ── Interactive mode: editor interceptor ──────────────────────────
// Intercept Enter / Ctrl+D so we can expand !{...} BEFORE the
// SubmitMsg is created. This ensures the expanded text appears in
// the user message block on screen as well as in the LLM prompt.
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
if !ctx.Interactive {
return
}
ctx.SetEditor(ext.EditorConfig{
HandleKey: func(key string, currentText string) ext.EditorKeyAction {
if (key == "enter" || key == "ctrl+d") && re.MatchString(currentText) {
expanded := expand(currentText)
// Clear the textarea asynchronously — calling
// SetEditorText synchronously from inside Update()
// would deadlock the BubbleTea event loop.
go ctx.SetEditorText("")
return ext.EditorKeyAction{
Type: ext.EditorKeySubmit,
SubmitText: expanded,
}
}
return ext.EditorKeyAction{Type: ext.EditorKeyPassthrough}
},
})
})
// ── Non-interactive fallback: OnInput transform ──────────────────
// For CLI, script, and queue sources the editor interceptor is not
// active, so we fall back to OnInput which still rewrites the
// prompt text sent to the LLM.
api.OnInput(func(ev ext.InputEvent, ctx ext.Context) *ext.InputResult {
if !re.MatchString(ev.Text) {
if ev.Source == "interactive" || !re.MatchString(ev.Text) {
return nil
}
expanded := re.ReplaceAllStringFunc(ev.Text, func(match string) string {
// Extract the command between !{ and }.
cmd := re.FindStringSubmatch(match)[1]
cmd = strings.TrimSpace(cmd)
out, err := exec.Command("bash", "-c", cmd).Output()
if err != nil {
return match // keep original on error
}
return strings.TrimSpace(string(out))
})
return &ext.InputResult{
Action: "transform",
Text: expanded,
Text: expand(ev.Text),
}
})
}
// expand replaces every !{cmd} in text with the command's stdout.
// On error the original !{cmd} token is preserved.
func expand(text string) string {
return re.ReplaceAllStringFunc(text, func(match string) string {
cmd := re.FindStringSubmatch(match)[1]
cmd = strings.TrimSpace(cmd)
out, err := exec.Command("bash", "-c", cmd).Output()
if err != nil {
return match // keep original on error
}
return strings.TrimSpace(string(out))
})
}
@@ -130,6 +130,58 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}
// TestSubagentMonitor_ConcurrentSubagents verifies no panics when multiple
// subagents emit events concurrently from different goroutines.
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
harness := test.New(t)
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
if err != nil {
t.Fatalf("SessionStart should not error: %v", err)
}
// Start 5 subagents concurrently
done := make(chan struct{}, 5)
for i := range 5 {
go func(idx int) {
defer func() { done <- struct{}{} }()
callID := fmt.Sprintf("concurrent-%d", idx)
task := fmt.Sprintf("concurrent task %d", idx)
_, _ = harness.Emit(extensions.SubagentStartEvent{
ToolCallID: callID,
Task: task,
})
// Emit many chunks rapidly
for j := range 20 {
_, _ = harness.Emit(extensions.SubagentChunkEvent{
ToolCallID: callID,
Task: task,
ChunkType: "text",
Content: fmt.Sprintf("agent %d chunk %d", idx, j),
})
}
_, _ = harness.Emit(extensions.SubagentEndEvent{
ToolCallID: callID,
Task: task,
Response: "done",
})
}(i)
}
// Wait for all goroutines
for range 5 {
<-done
}
// Allow any final processing
time.Sleep(200 * time.Millisecond)
}
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
// even with nil ctx functions.
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
+153
View File
@@ -0,0 +1,153 @@
//go:build ignore
// sudo-handler.go - Extension to handle sudo password prompts securely
//
// This extension intercepts bash commands containing "sudo" and:
// 1. Checks if sudo credentials are already cached (via sudo -n)
// 2. If not cached, prompts the user for their password (with masking)
// 3. Temporarily sets SUDO_PASSWORD environment variable for execution
// 4. The bash tool automatically uses sudo -S -p '' to pipe the password
//
// Usage: kit -e examples/extensions/sudo-handler.go
//
// Security notes:
// - Password is only stored in memory for the duration of the session
// - Password is never logged or displayed
// - Each session requires re-authentication (sudo -k is used)
// - The SUDO_PASSWORD env var is set only during tool execution
package main
import (
"encoding/json"
"os"
"strings"
"sync"
"kit/ext"
)
var (
// cachedPassword stores the sudo password for the session
cachedPassword string
// hasCachedPassword tracks if we have a valid cached password
hasCachedPassword bool
// mu protects cached password access
mu sync.RWMutex
)
// Init sets up the sudo handler extension
func Init(api ext.API) {
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
if tc.ToolName != "bash" {
return nil
}
// Parse the command from tool input
var input struct {
Command string `json:"command"`
}
if err := json.Unmarshal([]byte(tc.Input), &input); err != nil {
return nil
}
// Check if command contains sudo
if !containsSudo(input.Command) {
return nil
}
// Check if we already have cached credentials
mu.RLock()
password := cachedPassword
hasCached := hasCachedPassword
mu.RUnlock()
if hasCached {
// Use cached password
os.Setenv("SUDO_PASSWORD", password)
return nil
}
// No cached password - prompt user
result := ctx.PromptInput(ext.PromptInputConfig{
Message: "🔐 Sudo password required for:\n " + truncateCommand(input.Command, 60),
Placeholder: "Enter your password",
})
if result.Cancelled {
return &ext.ToolCallResult{
Block: true,
Reason: "Sudo password prompt cancelled by user",
}
}
if result.Value == "" {
return &ext.ToolCallResult{
Block: true,
Reason: "No password provided",
}
}
// Cache the password for this session
mu.Lock()
cachedPassword = result.Value
hasCachedPassword = true
mu.Unlock()
// Set environment variable for the bash tool to use
os.Setenv("SUDO_PASSWORD", result.Value)
// Show confirmation (without revealing password)
ctx.PrintInfo("Sudo password cached for this session")
return nil
})
// Clear cached password when session ends
api.OnSessionShutdown(func(event ext.SessionShutdownEvent, ctx ext.Context) {
mu.Lock()
cachedPassword = ""
hasCachedPassword = false
mu.Unlock()
os.Unsetenv("SUDO_PASSWORD")
})
}
// containsSudo checks if the command contains sudo as a command (not in a string)
func containsSudo(command string) bool {
// Simple check for sudo as a word, not inside quotes or as part of another word
lower := strings.ToLower(command)
// Check for sudo at start or after separators
patterns := []string{
"sudo ",
"sudo\t",
";sudo ",
"&& sudo ",
"|| sudo ",
"| sudo ",
"$(sudo ",
"`sudo ",
}
for _, pattern := range patterns {
if strings.Contains(lower, pattern) {
return true
}
}
// Check if command starts with sudo
if strings.HasPrefix(lower, "sudo ") {
return true
}
return false
}
// truncateCommand truncates a long command for display
func truncateCommand(cmd string, maxLen int) string {
if len(cmd) <= maxLen {
return cmd
}
return cmd[:maxLen-3] + "..."
}
+1 -1
View File
@@ -62,7 +62,7 @@ func main() {
}
})
// Subscribe to streaming chunks.
host3.OnStreaming(func(e kit.MessageUpdateEvent) {
host3.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
fmt.Print(e.Chunk)
})
+51 -49
View File
@@ -1,31 +1,32 @@
module github.com/mark3labs/kit
go 1.26.1
go 1.26.2
require (
charm.land/bubbles/v2 v2.1.0
charm.land/bubbletea/v2 v2.0.2
charm.land/fantasy v0.17.1
charm.land/bubbletea/v2 v2.0.6
charm.land/fantasy v0.23.0
charm.land/huh/v2 v2.0.3
charm.land/lipgloss/v2 v2.0.2
github.com/alecthomas/chroma/v2 v2.23.1
charm.land/lipgloss/v2 v2.0.3
github.com/alecthomas/chroma/v2 v2.24.1
github.com/atotto/clipboard v0.1.4
github.com/aymanbagabas/go-udiff v0.4.1
github.com/charmbracelet/fang v1.0.0
github.com/charmbracelet/log v1.0.0
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be
github.com/charmbracelet/x/editor v0.2.0
github.com/clipperhouse/displaywidth v0.11.0
github.com/clipperhouse/uax29/v2 v2.7.0
github.com/coder/acp-go-sdk v0.6.3
github.com/fsnotify/fsnotify v1.9.0
github.com/coder/acp-go-sdk v0.12.2
github.com/fsnotify/fsnotify v1.10.1
github.com/indaco/herald v0.13.0
github.com/indaco/herald-md v0.3.0
github.com/mark3labs/mcp-go v0.47.1
github.com/mark3labs/mcp-go v0.51.0
github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0
github.com/traefik/yaegi v0.16.1
golang.org/x/term v0.41.0
golang.org/x/term v0.42.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -34,23 +35,23 @@ require (
cloud.google.com/go/auth v0.20.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
github.com/aws/smithy-go v1.24.3 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
github.com/aws/smithy-go v1.25.1 // indirect
github.com/catppuccin/go v0.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
@@ -58,40 +59,41 @@ require (
github.com/charmbracelet/harmonica v0.2.0 // indirect
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 // indirect
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 // indirect
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 // indirect
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 // indirect
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
github.com/charmbracelet/x/json v0.2.0 // indirect
github.com/charmbracelet/x/termios v0.1.1 // indirect
github.com/charmbracelet/x/windows v0.2.2 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dlclark/regexp2 v1.12.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 // indirect
github.com/go-logfmt/logfmt v0.6.1 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/goccy/go-yaml v1.19.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/google/jsonschema-go v0.4.3 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/kaptinlin/go-i18n v0.3.1 // indirect
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
github.com/kaptinlin/jsonschema v0.7.7 // indirect
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
github.com/kaptinlin/go-i18n v0.4.7 // indirect
github.com/kaptinlin/jsonpointer v0.4.21 // indirect
github.com/kaptinlin/jsonschema v0.7.13 // indirect
github.com/kaptinlin/messageformat-go v0.6.3 // indirect
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
github.com/muesli/mango v0.2.0 // indirect
github.com/muesli/mango-cobra v1.3.0 // indirect
github.com/muesli/mango-pflag v0.2.0 // indirect
github.com/muesli/roff v0.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
github.com/sagikazarmark/locafero v0.12.0 // indirect
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
@@ -109,32 +111,32 @@ require (
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.49.0 // indirect
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/crypto v0.50.0 // indirect
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
golang.org/x/net v0.53.0 // indirect
golang.org/x/oauth2 v0.36.0 // indirect
golang.org/x/time v0.15.0 // indirect
google.golang.org/api v0.275.0 // indirect
google.golang.org/genai v1.52.1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
google.golang.org/grpc v1.80.0 // indirect
google.golang.org/api v0.277.0 // indirect
google.golang.org/genai v1.55.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect
google.golang.org/grpc v1.81.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6
github.com/charmbracelet/x/ansi v0.11.7
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.21 // indirect
github.com/mattn/go-isatty v0.0.22 // indirect
github.com/mattn/go-runewidth v0.0.23 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/pflag v1.0.10
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.35.0
golang.org/x/text v0.36.0
)
+102 -98
View File
@@ -1,13 +1,13 @@
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
charm.land/fantasy v0.23.0 h1:pocjwC5CxfEg1Bpwb0raML2d5ijo3op33Mmd6hYJyo4=
charm.land/fantasy v0.23.0/go.mod h1:4yzSsd9XmFEVjRnF1P0LTEbLTmQX6OLnPkrHaf7iruo=
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
@@ -16,8 +16,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
@@ -28,42 +28,42 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
github.com/alecthomas/chroma/v2 v2.24.1 h1:m5ffpfZbIb++k8AqFEKy9uVgY12xIQtBsQlc6DfZJQM=
github.com/alecthomas/chroma/v2 v2.24.1/go.mod h1:l+ohZ9xRXIbGe7cIW+YZgOGbvuVLjMps/FYN/CwuabI=
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
@@ -86,24 +86,26 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b h1:ASDO9RT6SNKTQN87jO2bRfxHFJq8cgeYdFzivY2gCeM=
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b/go.mod h1:Vo8TffMf0q7Uho/n8e6XpBZvOWtd3g39yX+9P5rRutA=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be h1:j7w8VP/D4lu5+/4GamMmFy8nrtadcl82/fjvDgSHwLo=
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4jxsAuEs=
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIRg8gGWwk=
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 h1:zmBor0ftFNqVFp9U59ZoEDRUCIYSGOGSIfGGkNZRufs=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 h1:rByFKh9JgQScu7oy0+TlUbC2e93woW/QNZmNXbbbw/E=
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 h1:aEppolah2k9c0LzKX2fk5ryuyQ0Lq8kCOjkvMw1b8o4=
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 h1:PMjHdSo8Vpq9psUw9BoHo9JLPMkm9Hqb+Whk64n3AQQ=
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
@@ -122,15 +124,15 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ=
github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
github.com/coder/acp-go-sdk v0.12.2 h1:fpRJ8Z5HMSr5cZ5IywzFlFZcIxZOsto+laNVu7XelFA=
github.com/coder/acp-go-sdk v0.12.2/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
@@ -144,10 +146,10 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 h1:2WmHkJINIjgXXYDGik8d3oJvFA3DAwPy00csDJ3vo+o=
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
@@ -165,16 +167,16 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/googleapis/enterprise-certificate-proxy v0.3.15 h1:xolVQTEXusUcAA5UgtyRLjelpFFHWlPQ4XfWGc7MBas=
github.com/googleapis/enterprise-certificate-proxy v0.3.15/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
@@ -185,14 +187,14 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
github.com/kaptinlin/go-i18n v0.3.1 h1:plXi3XQE1aYamFi8TU0K6actODmw2+5FSobmhTkfQ/0=
github.com/kaptinlin/go-i18n v0.3.1/go.mod h1:ZRoAHj7elWYamfbv7wev7Ajch6LOzjtBaq8nWe8HIVk=
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
github.com/kaptinlin/messageformat-go v0.4.19 h1:A5kuuZ1ybXDQ7kD1aoEWGAOemX7hLsMY0yolgSbgpRI=
github.com/kaptinlin/messageformat-go v0.4.19/go.mod h1:utSDTfiXTxl66OC5RIEuObLH7Ue3YjbA2X86SYMBYWg=
github.com/kaptinlin/go-i18n v0.4.7 h1:apjIIZHnGRyrkiX3vHj07F1BF6D0JLmV+VGSr1781Jc=
github.com/kaptinlin/go-i18n v0.4.7/go.mod h1:+i1J0pFq/9i9ESC5qRMVkKwC+mdQTABhhBExpYOlbeM=
github.com/kaptinlin/jsonpointer v0.4.21 h1:WVkwQbeerbHFcoXG7Yo/mlQhhZjWiTnagECEfwDXXa0=
github.com/kaptinlin/jsonpointer v0.4.21/go.mod h1:Mo7+DX8RlQTFqS4dnYJl0izSP4ob+Rl5xO/mGDETgaU=
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
github.com/kaptinlin/messageformat-go v0.6.3 h1:m9ZE/fCjnsk8bdkv7Qs56L/ZoHbmQqhz9mRZSAQLU5g=
github.com/kaptinlin/messageformat-go v0.6.3/go.mod h1:2KOZ/hgo/SveZ+uyi7vPUpUXieX65Mppzbc3VpGyqKs=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -201,10 +203,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mark3labs/mcp-go v0.47.1 h1:A9sJJ20mscl/ssLYHjodfaoBmq6uuhMG7pAPNYaQymQ=
github.com/mark3labs/mcp-go v0.47.1/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mark3labs/mcp-go v0.51.0 h1:e8AhEfxzcYt7XqYzwT7uzWNhnqpu3H1Tn7dEJB9Ygj8=
github.com/mark3labs/mcp-go v0.51.0/go.mod h1:Zg9cB2HdwdMMVgY0xtTzq3KvYIOJQDsaut+jWjwDaQY=
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
@@ -221,8 +223,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
@@ -236,6 +238,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
@@ -288,38 +292,38 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
google.golang.org/api v0.277.0 h1:HJfyJUiNeBBUMai7ez8u14wkp/gH/I4wpGbbO9o+cSk=
google.golang.org/api v0.277.0/go.mod h1:B9TqLBwJqVjp1mtt7WeoQwWRwvu/400y5lETOql+giQ=
google.golang.org/genai v1.55.0 h1:iLHGk4Bj/IZ/GNNZb7hYqwSJMRBvqLeu2Hb6YQ+rYGw=
google.golang.org/genai v1.55.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4 h1:2iMJZntwvmfgtse+s744JY7v7PgEdSBuFYXucvpOHNM=
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:v14kaaboYyXQ1Gsu489Q+Hg/oN4B33mWtuOhF1HCeXA=
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo=
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw=
google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+61 -8
View File
@@ -177,22 +177,75 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
return acp.SetSessionModeResponse{}, nil
}
// SetSessionModel changes the active model for a session.
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
// ListSessions returns an empty session list. Kit doesn't persist sessions
// across restarts in ACP mode, so this is effectively a no-op.
func (a *Agent) ListSessions(_ context.Context, _ acp.ListSessionsRequest) (acp.ListSessionsResponse, error) {
return acp.ListSessionsResponse{
Sessions: []acp.SessionInfo{},
}, nil
}
// CloseSession cancels any ongoing work for the session and frees its resources.
func (a *Agent) CloseSession(_ context.Context, params acp.CloseSessionRequest) (acp.CloseSessionResponse, error) {
sessionID := string(params.SessionId)
sess, ok := a.registry.get(sessionID)
if !ok {
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
return acp.CloseSessionResponse{}, nil
}
modelID := string(params.ModelId)
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
log.Debug("acp: close session", "session", sessionID)
sess.cancelPrompt()
a.registry.remove(sessionID)
return acp.CloseSessionResponse{}, nil
}
if err := sess.kit.SetModel(ctx, modelID); err != nil {
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
// ResumeSession is not supported — Kit doesn't persist sessions across
// restarts in ACP mode. Clients should use NewSession instead.
func (a *Agent) ResumeSession(_ context.Context, _ acp.ResumeSessionRequest) (acp.ResumeSessionResponse, error) {
return acp.ResumeSessionResponse{}, fmt.Errorf("resume session not supported")
}
// SetSessionConfigOption handles session configuration changes. Currently
// supports the "model" config option to change the active model for a session.
func (a *Agent) SetSessionConfigOption(ctx context.Context, params acp.SetSessionConfigOptionRequest) (acp.SetSessionConfigOptionResponse, error) {
// Extract session ID and config ID from whichever variant is present.
var sessionID string
var configID string
var value string
switch {
case params.ValueId != nil:
sessionID = string(params.ValueId.SessionId)
configID = string(params.ValueId.ConfigId)
value = string(params.ValueId.Value)
case params.Boolean != nil:
sessionID = string(params.Boolean.SessionId)
configID = string(params.Boolean.ConfigId)
// Boolean config options are not used for model selection.
log.Debug("acp: set_session_config_option (boolean)", "session", sessionID, "config", configID, "value", params.Boolean.Value)
return acp.SetSessionConfigOptionResponse{}, nil
default:
return acp.SetSessionConfigOptionResponse{}, acp.NewInvalidParams("unsupported config option variant")
}
return acp.SetSessionModelResponse{}, nil
sess, ok := a.registry.get(sessionID)
if !ok {
return acp.SetSessionConfigOptionResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
}
log.Debug("acp: set_session_config_option", "session", sessionID, "config", configID, "value", value)
// Handle known config options.
switch configID {
case "model":
if err := sess.kit.SetModel(ctx, value); err != nil {
return acp.SetSessionConfigOptionResponse{}, fmt.Errorf("set model: %w", err)
}
default:
log.Debug("acp: unknown config option", "config", configID)
}
return acp.SetSessionConfigOptionResponse{}, nil
}
// ---------------------------------------------------------------------------
+14
View File
@@ -232,6 +232,20 @@ func (r *sessionRegistry) closeAll() {
}
}
// remove closes and removes a single session by ID.
func (r *sessionRegistry) remove(sessionID string) {
r.mu.Lock()
defer r.mu.Unlock()
sess, ok := r.sessions[sessionID]
if !ok {
return
}
if sess.kit != nil {
_ = sess.kit.Close()
}
delete(r.sessions, sessionID)
}
// cancelPrompt cancels the current prompt for a session, if any.
func (s *acpSession) cancelPrompt() {
s.cancelMu.Lock()
+433 -70
View File
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"strings"
"time"
"charm.land/fantasy"
@@ -58,6 +59,11 @@ type AgentConfig struct {
// loading (successfully or with error). The callback receives the server
// name, tool count, and any error. Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// MCPTaskConfig configures task-augmented tools/call execution. The
// zero value preserves historical synchronous-only behaviour for any
// server that didn't advertise task support during initialize.
MCPTaskConfig tools.MCPTaskConfig
}
// ToolCallHandler is a function type for handling tool calls as they happen.
@@ -87,6 +93,19 @@ type ReasoningDeltaHandler func(delta string)
// Called when the last reasoning token has been processed, before text streaming starts.
type ReasoningCompleteHandler func()
// ToolCallStartHandler is a function type for handling the moment when the LLM
// begins generating tool call arguments. The tool name is known but the full
// argument JSON is still streaming.
type ToolCallStartHandler func(toolCallID, toolName string)
// ToolCallDeltaHandler is a function type for handling streamed fragments of
// tool call arguments as they arrive from the LLM.
type ToolCallDeltaHandler func(toolCallID, delta string)
// ToolCallEndHandler is a function type for handling the end of tool argument
// streaming, before the tool call is parsed and execution begins.
type ToolCallEndHandler func(toolCallID string)
// ToolOutputHandler is a function type for handling streaming tool output chunks.
// Used by tools like bash to stream output as it arrives rather than waiting
// for the command to complete. The isStderr flag indicates if the chunk
@@ -94,6 +113,12 @@ type ReasoningCompleteHandler func()
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
type ToolOutputHandler = core.ToolOutputCallback
// PasswordPromptHandler is a function type for password prompts.
// Used by the bash tool when sudo requires a password. The handler receives
// a prompt message and returns the password and whether it was cancelled.
// Note: This is an alias for core.PasswordPromptCallback.
type PasswordPromptHandler = core.PasswordPromptCallback
// StepMessagesHandler is a function type for persisting messages after each
// complete step in a multi-step agent turn. The handler receives the messages
// produced by the step (typically an assistant message with tool calls followed
@@ -107,6 +132,76 @@ type StepMessagesHandler func(stepMessages []fantasy.Message)
// tracking during long-running tool-calling conversations.
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
// StepStartHandler is called when a new LLM step begins within a turn.
type StepStartHandler func(stepNumber int)
// StepFinishHandler is called when a step completes with full context.
type StepFinishHandler func(stepNumber int, hasToolCalls bool, finishReason string, usage fantasy.Usage)
// TextStartHandler is called when the LLM begins generating text content.
type TextStartHandler func(id string)
// TextEndHandler is called when the LLM finishes generating text content.
type TextEndHandler func(id string)
// ReasoningStartHandler is called when the LLM begins reasoning/thinking.
type ReasoningStartHandler func(id string)
// WarningsHandler is called when the LLM provider returns warnings.
type WarningsHandler func(warnings []string)
// SourceHandler is called when the LLM references a source.
type SourceHandler func(sourceType, id, url, title string)
// StreamFinishHandler is called when a per-step LLM stream completes.
type StreamFinishHandler func(usage fantasy.Usage, finishReason string)
// ErrorHandler is called when an agent-level error occurs.
type ErrorHandler func(err error)
// RetryHandler is called when the LLM request is retried.
type RetryHandler func(attempt int, err error)
// PrepareStepHandler is called between steps to allow message modification.
// It receives the step number and current messages, and returns replacement
// messages (or nil to keep unchanged).
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
// GenerateCallbacks consolidates all callback functions for
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
// 16+ positional callback parameters, making it easier to add new callbacks
// without breaking existing callers (new fields default to nil).
type GenerateCallbacks struct {
OnToolCall ToolCallHandler
OnToolExecution ToolExecutionHandler
OnToolResult ToolResultHandler
OnResponse ResponseHandler
OnToolCallContent ToolCallContentHandler
OnStreamingResponse StreamingResponseHandler
OnReasoningDelta ReasoningDeltaHandler
OnReasoningComplete ReasoningCompleteHandler
OnToolOutput ToolOutputHandler
OnStepMessages StepMessagesHandler
OnStepUsage StepUsageHandler
OnPasswordPrompt PasswordPromptHandler
OnToolCallStart ToolCallStartHandler
OnToolCallDelta ToolCallDeltaHandler
OnToolCallEnd ToolCallEndHandler
// New callbacks for previously unwired Fantasy lifecycle events.
OnStepStart StepStartHandler
OnStepFinish StepFinishHandler
OnTextStart TextStartHandler
OnTextEnd TextEndHandler
OnReasoningStart ReasoningStartHandler
OnWarnings WarningsHandler
OnSource SourceHandler
OnStreamFinish StreamFinishHandler
OnError ErrorHandler
OnRetry RetryHandler
OnPrepareStep PrepareStepHandler
}
// Agent represents an AI agent with core tool integration using the LLM library.
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
// AgentTool implementations — no MCP layer, no serialization overhead.
@@ -135,6 +230,16 @@ type Agent struct {
skipMaxOutputTokens bool
modelConfig *models.ProviderConfig
// authHandler and tokenStoreFactory are stored from AgentConfig so that
// AddMCPServer() can propagate them when creating a new MCPToolManager
// at runtime (i.e. when no MCP servers were configured at init time).
authHandler tools.MCPAuthHandler
tokenStoreFactory tools.TokenStoreFactory
// mcpTaskConfig is stored from AgentConfig so AddMCPServer() can
// propagate it to a lazily-created MCPToolManager.
mcpTaskConfig tools.MCPTaskConfig
// mcpReady is closed when background MCP tool loading completes (success
// or failure). nil when no MCP servers are configured.
mcpReady chan struct{}
@@ -231,13 +336,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
providerOptions: providerResult.ProviderOptions,
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
modelConfig: agentConfig.ModelConfig,
authHandler: agentConfig.AuthHandler,
tokenStoreFactory: agentConfig.TokenStoreFactory,
mcpTaskConfig: agentConfig.MCPTaskConfig,
}
// Start MCP tool loading in the background if servers are configured.
// The mcpReady channel is closed when loading completes (success or failure).
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
toolManager := tools.NewMCPToolManager()
toolManager.SetModel(providerResult.Model)
if agentConfig.AuthHandler != nil {
toolManager.SetAuthHandler(agentConfig.AuthHandler)
}
@@ -251,6 +358,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
if agentConfig.OnMCPServerLoaded != nil {
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
}
// Apply task-augmented tool execution config (zero value = no-op).
toolManager.SetTaskConfig(agentConfig.MCPTaskConfig)
a.toolManager = toolManager
a.mcpReady = make(chan struct{})
@@ -317,7 +426,7 @@ func (a *Agent) rebuildFantasyAgent() {
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
@@ -397,13 +506,20 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
OnToolCall: onToolCall,
OnToolExecution: onToolExecution,
OnToolResult: onToolResult,
OnResponse: onResponse,
OnToolCallContent: onToolCallContent,
})
}
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
// The agent handles the tool call loop internally. We map the rich callback system
// to kit's existing callback interface for UI integration.
// The agent handles the tool call loop internally.
//
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
// struct and is easier to extend with new callbacks.
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
@@ -413,6 +529,35 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
onToolOutput ToolOutputHandler,
onStepMessages StepMessagesHandler,
onStepUsage StepUsageHandler,
onPasswordPrompt PasswordPromptHandler,
onToolCallStart ToolCallStartHandler,
onToolCallDelta ToolCallDeltaHandler,
onToolCallEnd ToolCallEndHandler,
) (*GenerateWithLoopResult, error) {
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
OnToolCall: onToolCall,
OnToolExecution: onToolExecution,
OnToolResult: onToolResult,
OnResponse: onResponse,
OnToolCallContent: onToolCallContent,
OnStreamingResponse: onStreamingResponse,
OnReasoningDelta: onReasoningDelta,
OnReasoningComplete: onReasoningComplete,
OnToolOutput: onToolOutput,
OnStepMessages: onStepMessages,
OnStepUsage: onStepUsage,
OnPasswordPrompt: onPasswordPrompt,
OnToolCallStart: onToolCallStart,
OnToolCallDelta: onToolCallDelta,
OnToolCallEnd: onToolCallEnd,
})
}
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
// The agent handles the tool call loop internally. We map the rich callback system
// to kit's existing callback interface for UI integration.
func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Message,
cb GenerateCallbacks,
) (*GenerateWithLoopResult, error) {
// Wait for background MCP tool loading to complete and rebuild the
@@ -421,8 +566,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
a.ensureMCPTools()
// Inject tool output handler into context for use by core tools (e.g., bash).
if onToolOutput != nil {
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
if cb.OnToolOutput != nil {
ctx = core.ContextWithToolOutputCallback(ctx, cb.OnToolOutput)
}
// Inject password prompt handler into context for use by bash tool.
if cb.OnPasswordPrompt != nil {
ctx = core.ContextWithPasswordPrompt(ctx, cb.OnPasswordPrompt)
}
// The agent requires the current user input as Prompt, with prior messages as history.
@@ -442,8 +592,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
// Stream is required to observe tool execution in real time. The non-streaming
// Generate path is reserved for the simple case with no callbacks at all.
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
onToolCallContent != nil || onStreamingResponse != nil || onReasoningDelta != nil
hasCallbacks := cb.OnToolCall != nil || cb.OnToolExecution != nil || cb.OnToolResult != nil ||
cb.OnToolCallContent != nil || cb.OnStreamingResponse != nil || cb.OnReasoningDelta != nil ||
cb.OnToolCallStart != nil || cb.OnToolCallDelta != nil || cb.OnToolCallEnd != nil ||
cb.OnStepStart != nil || cb.OnStepFinish != nil || cb.OnTextStart != nil ||
cb.OnTextEnd != nil || cb.OnReasoningStart != nil || cb.OnWarnings != nil ||
cb.OnSource != nil || cb.OnStreamFinish != nil || cb.OnError != nil ||
cb.OnRetry != nil || cb.OnPrepareStep != nil
if a.streamingEnabled || hasCallbacks {
// Track completed step messages so we can return partial results
@@ -452,9 +607,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// for every step that completed before the error occurred.
var completedStepMessages []fantasy.Message
// persistedCount tracks how many new messages (beyond the original
// input) were persisted incrementally via onStepMessages, so the
// input) were persisted incrementally via cb.OnStepMessages, so the
// caller can skip them during post-generation persistence.
var persistedCount int
// stepCounter tracks the current step number for StepStart/StepFinish events.
var stepCounter int
// Use the streaming agent
streamCall := fantasy.AgentStreamCall{
@@ -462,13 +619,73 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
Files: files,
Messages: history,
// Tool input streaming callbacks — fire during tool argument generation
OnToolInputStart: func(id, toolName string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnToolCallStart != nil {
cb.OnToolCallStart(id, toolName)
}
return nil
},
OnToolInputDelta: func(id, delta string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnToolCallDelta != nil {
cb.OnToolCallDelta(id, delta)
}
return nil
},
OnToolInputEnd: func(id string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnToolCallEnd != nil {
cb.OnToolCallEnd(id)
}
return nil
},
// Text start/end callbacks
OnTextStart: func(id string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnTextStart != nil {
cb.OnTextStart(id)
}
return nil
},
OnTextEnd: func(id string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnTextEnd != nil {
cb.OnTextEnd(id)
}
return nil
},
// Reasoning start callback
OnReasoningStart: func(id string, _ fantasy.ReasoningContent) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnReasoningStart != nil {
cb.OnReasoningStart(id)
}
return nil
},
// Reasoning/thinking streaming callback
OnReasoningDelta: func(id, delta string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if onReasoningDelta != nil {
onReasoningDelta(delta)
if cb.OnReasoningDelta != nil {
cb.OnReasoningDelta(delta)
}
return nil
},
@@ -478,8 +695,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
if ctx.Err() != nil {
return ctx.Err()
}
if onReasoningComplete != nil {
onReasoningComplete()
if cb.OnReasoningComplete != nil {
cb.OnReasoningComplete()
}
return nil
},
@@ -489,8 +706,64 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
if ctx.Err() != nil {
return ctx.Err()
}
if onStreamingResponse != nil {
onStreamingResponse(text)
if cb.OnStreamingResponse != nil {
cb.OnStreamingResponse(text)
}
return nil
},
// Warnings callback
OnWarnings: func(warnings []fantasy.CallWarning) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnWarnings != nil {
strs := make([]string, len(warnings))
for i, w := range warnings {
strs[i] = w.Message
}
cb.OnWarnings(strs)
}
return nil
},
// Source callback
OnSource: func(source fantasy.SourceContent) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnSource != nil {
cb.OnSource(string(source.SourceType), source.ID, source.URL, source.Title)
}
return nil
},
// Stream finish callback (per-step stream completion)
OnStreamFinish: func(usage fantasy.Usage, finishReason fantasy.FinishReason, _ fantasy.ProviderMetadata) error {
if ctx.Err() != nil {
return ctx.Err()
}
if cb.OnStreamFinish != nil {
cb.OnStreamFinish(usage, string(finishReason))
}
return nil
},
// Error callback
OnError: func(err error) {
if cb.OnError != nil {
cb.OnError(err)
}
},
// Step start callback
OnStepStart: func(stepNumber int) error {
if ctx.Err() != nil {
return ctx.Err()
}
stepCounter = stepNumber
if cb.OnStepStart != nil {
cb.OnStepStart(stepNumber)
}
return nil
},
@@ -503,13 +776,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
currentToolArgs = tc.Input
// Notify about the tool call
if onToolCall != nil {
onToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
if cb.OnToolCall != nil {
cb.OnToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
}
// Notify tool execution starting
if onToolExecution != nil {
onToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
if cb.OnToolExecution != nil {
cb.OnToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
}
return nil
@@ -521,14 +794,14 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
return ctx.Err()
}
// Notify tool execution finished
if onToolExecution != nil {
onToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
if cb.OnToolExecution != nil {
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
}
if onToolResult != nil {
if cb.OnToolResult != nil {
// Extract result text and error status
resultText, isError := extractToolResultText(tr)
onToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
}
return nil
@@ -542,8 +815,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// Persist step messages incrementally so progress is saved
// as it happens rather than only at the end of the turn.
if onStepMessages != nil && len(step.Messages) > 0 {
onStepMessages(step.Messages)
if cb.OnStepMessages != nil && len(step.Messages) > 0 {
cb.OnStepMessages(step.Messages)
persistedCount += len(step.Messages)
}
@@ -553,65 +826,88 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// Check if step has text content alongside tool calls
text := step.Content.Text()
toolCalls := step.Content.ToolCalls()
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
onToolCallContent(text)
if text != "" && len(toolCalls) > 0 && cb.OnToolCallContent != nil {
cb.OnToolCallContent(text)
}
// Emit step usage for real-time cost tracking
if onStepUsage != nil {
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
if cb.OnStepUsage != nil {
cb.OnStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
}
// Emit unified step finish event
if cb.OnStepFinish != nil {
cb.OnStepFinish(stepCounter, len(toolCalls) > 0, string(step.FinishReason), step.Usage)
}
return nil
},
}
// If a steer channel is attached to the context, wire up a
// PrepareStep function that drains the channel between steps
// and injects pending steer messages as user messages before
// the next LLM call. This enables graceful mid-turn steering
// without cancelling in-progress tool execution.
if steerCh := steerChFromContext(ctx); steerCh != nil {
onConsumed := steerConsumedFromContext(ctx)
// Always wire up PrepareStep to handle both steering and the
// OnPrepareStep hook. Steering drains its channel first, then
// OnPrepareStep hooks run against the (possibly already steered)
// messages.
steerCh := steerChFromContext(ctx)
onConsumed := steerConsumedFromContext(ctx)
hasSteering := steerCh != nil
hasPrepareStepHook := cb.OnPrepareStep != nil
if hasSteering || hasPrepareStepHook {
streamCall.PrepareStep = func(
stepCtx context.Context,
opts fantasy.PrepareStepFunctionOptions,
) (context.Context, fantasy.PrepareStepResult, error) {
// Drain all pending steer messages (non-blocking).
var steered []SteerMessage
for {
select {
case msg := <-steerCh:
steered = append(steered, msg)
default:
goto done
}
}
done:
result := fantasy.PrepareStepResult{
Model: opts.Model,
Messages: opts.Messages,
}
if len(steered) > 0 {
// Inject each steer message as a user message so the
// LLM sees the redirection on the next step.
for _, sm := range steered {
result.Messages = append(result.Messages,
fantasy.NewUserMessage(sm.Text, sm.Files...))
// Phase 1: Drain steering channel (if present).
if hasSteering {
var steered []SteerMessage
for {
select {
case msg := <-steerCh:
steered = append(steered, msg)
default:
goto done
}
}
// Notify that steer messages were consumed.
if onConsumed != nil {
onConsumed(len(steered))
done:
if len(steered) > 0 {
for _, sm := range steered {
result.Messages = append(result.Messages,
fantasy.NewUserMessage(sm.Text, sm.Files...))
}
if onConsumed != nil {
onConsumed(len(steered))
}
}
}
// Phase 2: Run OnPrepareStep hook (if registered).
if hasPrepareStepHook {
if replacement := cb.OnPrepareStep(opts.StepNumber, result.Messages); replacement != nil {
result.Messages = replacement
}
}
// Apply message-level cache control for Anthropic models.
// This avoids type conflicts with provider-level options.
result.Messages = applyCacheControlToMessages(result.Messages)
return stepCtx, result, nil
}
}
// Wire OnRetry callback if provided.
if cb.OnRetry != nil {
streamCall.OnRetry = func(err *fantasy.ProviderError, _ time.Duration) {
// Use the retry number from the error if available; Fantasy
// doesn't pass a counter directly, so we approximate with a
// counter incremented on each call.
cb.OnRetry(0, err)
}
}
result, err := a.fantasyAgent.Stream(ctx, streamCall)
if err != nil {
// On cancellation (or any error), return a partial result
@@ -637,8 +933,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// empty (e.g. reasoning-only responses) so the UI properly resets
// the stream component and avoids duplicate content on the next
// flush.
if onResponse != nil {
onResponse(result.Response.Content.Text())
if cb.OnResponse != nil {
cb.OnResponse(result.Response.Content.Text())
}
r := convertAgentResult(result, messages)
@@ -658,8 +954,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
// For non-streaming, fire the response callback so callers can reset
// streaming state (see streaming path comment above).
if onResponse != nil {
onResponse(result.Response.Content.Text())
if cb.OnResponse != nil {
cb.OnResponse(result.Response.Content.Text())
}
return convertAgentResult(result, messages), nil
@@ -800,7 +1096,7 @@ func (a *Agent) GetTools() []fantasy.AgentTool {
allTools := make([]fantasy.AgentTool, len(a.coreTools))
copy(allTools, a.coreTools)
if a.toolManager != nil {
allTools = append(allTools, a.toolManager.GetTools()...)
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
}
if len(a.extraTools) > 0 {
allTools = append(allTools, a.extraTools...)
@@ -844,7 +1140,13 @@ func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPSer
if a.toolManager == nil {
a.toolManager = tools.NewMCPToolManager()
a.toolManager.SetModel(a.model)
if a.authHandler != nil {
a.toolManager.SetAuthHandler(a.authHandler)
}
if a.tokenStoreFactory != nil {
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
}
a.toolManager.SetTaskConfig(a.mcpTaskConfig)
a.toolManager.SetOnToolsChanged(func() {
a.rebuildFantasyAgent()
})
@@ -900,6 +1202,56 @@ func (a *Agent) GetLoadedServerNames() []string {
return a.toolManager.GetLoadedServerNames()
}
// GetMCPPrompts returns all prompts discovered from connected MCP servers.
// Returns nil if no MCP servers are configured or no prompts were found.
func (a *Agent) GetMCPPrompts() []tools.MCPPrompt {
if a.toolManager == nil {
return nil
}
return a.toolManager.GetPrompts()
}
// GetMCPPrompt retrieves and expands a specific prompt from an MCP server.
// This is a lazy call — the server is contacted each time.
func (a *Agent) GetMCPPrompt(ctx context.Context, serverName, promptName string, args map[string]string) (*tools.MCPPromptResult, error) {
if a.toolManager == nil {
return nil, fmt.Errorf("no MCP servers configured")
}
return a.toolManager.GetPrompt(ctx, serverName, promptName, args)
}
// GetMCPResources returns all resources discovered from connected MCP servers.
func (a *Agent) GetMCPResources() []tools.MCPResource {
if a.toolManager == nil {
return nil
}
return a.toolManager.GetResources()
}
// ReadMCPResource reads a specific resource from an MCP server by URI.
func (a *Agent) ReadMCPResource(ctx context.Context, serverName, uri string) (*tools.MCPResourceContent, error) {
if a.toolManager == nil {
return nil, fmt.Errorf("no MCP servers configured")
}
return a.toolManager.ReadResource(ctx, serverName, uri)
}
// SubscribeMCPResource subscribes to change notifications for a resource.
func (a *Agent) SubscribeMCPResource(ctx context.Context, serverName, uri string) error {
if a.toolManager == nil {
return fmt.Errorf("no MCP servers configured")
}
return a.toolManager.SubscribeResource(ctx, serverName, uri)
}
// UnsubscribeMCPResource cancels change notifications for a resource.
func (a *Agent) UnsubscribeMCPResource(ctx context.Context, serverName, uri string) error {
if a.toolManager == nil {
return fmt.Errorf("no MCP servers configured")
}
return a.toolManager.UnsubscribeResource(ctx, serverName, uri)
}
// SetModel swaps the agent's LLM provider to a new model. The existing tools
// and configuration are preserved. When the new model's ProviderConfig carries
// a system prompt (from per-model settings), it replaces the agent's stored
@@ -919,11 +1271,6 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
_ = a.providerCloser.Close()
}
// Update model info on MCP tool manager.
if a.toolManager != nil {
a.toolManager.SetModel(providerResult.Model)
}
// Swap fields.
a.model = providerResult.Model
a.providerCloser = providerResult.Closer
@@ -956,6 +1303,22 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
return a.model
}
// GetMaxTokens returns the effective max output tokens the agent currently
// sends to the LLM provider, after per-model defaults, right-sizing, and any
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
// attached (e.g. early init) or when the provider suppresses the parameter
// (e.g. Codex OAuth), which allows callers to differentiate "default" from
// "explicitly capped".
func (a *Agent) GetMaxTokens() int {
if a.skipMaxOutputTokens {
return 0
}
if a.modelConfig == nil {
return 0
}
return a.modelConfig.MaxTokens
}
// Close closes the agent and cleans up resources.
// If MCP tools are still loading in the background, Close waits for them
// to finish before closing connections to avoid resource leaks.
+60
View File
@@ -55,6 +55,17 @@ func echoServerConfig(t *testing.T) config.MCPServerConfig {
}
}
// mockAuthHandler is a minimal MCPAuthHandler for testing that auth handler
// propagation works without requiring a real OAuth server.
type mockAuthHandler struct {
redirectURI string
}
func (h *mockAuthHandler) RedirectURI() string { return h.redirectURI }
func (h *mockAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
return "", nil
}
// newTestAgent creates a minimal Agent with a mock model and no core tools,
// suitable for testing MCP server management without an API key.
func newTestAgent() *Agent {
@@ -240,3 +251,52 @@ func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
}
}
// TestAgent_AddMCPServer_InheritsAuthHandler verifies that AddMCPServer()
// propagates the agent's authHandler and tokenStoreFactory to a newly created
// MCPToolManager (fix for issue #3).
func TestAgent_AddMCPServer_InheritsAuthHandler(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
handler := &mockAuthHandler{redirectURI: "http://localhost:9999/oauth/callback"}
model := &mockModel{}
a := &Agent{
model: model,
coreTools: nil,
extraTools: nil,
maxSteps: 10,
systemPrompt: "test",
fantasyAgent: fantasy.NewAgent(model),
authHandler: handler,
tokenStoreFactory: nil, // nil is fine; we just test authHandler propagation
}
defer func() { _ = a.Close() }()
// Initially no tool manager.
if a.GetMCPToolManager() != nil {
t.Fatal("Expected nil tool manager initially")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cfg := echoServerConfig(t)
_, err := a.AddMCPServer(ctx, "echo", cfg)
if err != nil {
t.Fatalf("AddMCPServer failed: %v", err)
}
// Tool manager should now exist and have the auth handler set.
tm := a.GetMCPToolManager()
if tm == nil {
t.Fatal("Expected tool manager to be created by AddMCPServer")
}
// Verify the auth handler was propagated by checking the field directly.
if tm.GetAuthHandler() == nil {
t.Fatal("Expected auth handler to be propagated to tool manager")
}
}
+3
View File
@@ -56,6 +56,8 @@ type AgentCreationOptions struct {
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// MCPTaskConfig configures task-augmented tools/call execution.
MCPTaskConfig tools.MCPTaskConfig
}
// CreateAgent creates an agent with optional spinner for Ollama models.
@@ -76,6 +78,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
ToolWrapper: opts.ToolWrapper,
ExtraTools: opts.ExtraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
MCPTaskConfig: opts.MCPTaskConfig,
}
var agent *Agent
+65
View File
@@ -0,0 +1,65 @@
package agent
import (
"context"
"fmt"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/tools"
)
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
// This keeps the fantasy dependency confined to the agent layer — the tools
// package is a pure MCP client library with no LLM framework dependency.
type mcpAgentTool struct {
tool tools.MCPTool
manager *tools.MCPToolManager
providerOptions fantasy.ProviderOptions
}
// Info returns the fantasy tool info including name, description, and parameter schema.
func (t *mcpAgentTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{
Name: t.tool.Name,
Description: t.tool.Description,
Parameters: t.tool.Parameters,
Required: t.tool.Required,
}
}
// Run executes the MCP tool by delegating to the MCPToolManager.
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
}
if result.IsError {
return fantasy.NewTextErrorResponse(result.Content), nil
}
return fantasy.NewTextResponse(result.Content), nil
}
// ProviderOptions returns provider-specific options for this tool.
func (t *mcpAgentTool) ProviderOptions() fantasy.ProviderOptions {
return t.providerOptions
}
// SetProviderOptions sets provider-specific options for this tool.
func (t *mcpAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.providerOptions = opts
}
// mcpToolsToAgentTools converts a slice of MCPTool to fantasy.AgentTool
// implementations that route execution through the MCPToolManager.
func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManager) []fantasy.AgentTool {
agentTools := make([]fantasy.AgentTool, len(mcpTools))
for i, t := range mcpTools {
agentTools[i] = &mcpAgentTool{
tool: t,
manager: manager,
}
}
return agentTools
}
+127 -10
View File
@@ -497,6 +497,12 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
// response text to stdout. No intermediate events are emitted. Blocks until
// the step completes or ctx is cancelled.
func (a *App) RunOnce(ctx context.Context, prompt string) error {
return a.RunOnceWithFiles(ctx, prompt, nil)
}
// RunOnceWithFiles executes a single agent step synchronously with optional
// multimodal file attachments. Prints the response to stdout and returns.
func (a *App) RunOnceWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) error {
stepCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -504,7 +510,7 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
a.cancelStep = cancel
a.mu.Unlock()
result, err := a.executeStep(stepCtx, prompt, nil, nil)
result, err := a.executeStep(stepCtx, prompt, nil, files)
if err != nil {
return err
}
@@ -519,6 +525,12 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
// full TurnResult without printing anything. This is used by --json mode to
// capture structured output for serialization.
func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult, error) {
return a.RunOnceResultWithFiles(ctx, prompt, nil)
}
// RunOnceResultWithFiles executes a single agent step synchronously with
// optional multimodal file attachments and returns the full TurnResult.
func (a *App) RunOnceResultWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) (*kit.TurnResult, error) {
stepCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -526,7 +538,7 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
a.cancelStep = cancel
a.mu.Unlock()
return a.executeStep(stepCtx, prompt, nil, nil)
return a.executeStep(stepCtx, prompt, nil, files)
}
// RunOnceWithDisplay executes a single agent step synchronously, sending
@@ -540,6 +552,12 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
//
// Blocks until the step completes or ctx is cancelled.
func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn func(tea.Msg)) error {
return a.RunOnceWithDisplayAndFiles(ctx, prompt, eventFn, nil)
}
// RunOnceWithDisplayAndFiles executes a single agent step synchronously with
// optional multimodal file attachments, sending intermediate display events.
func (a *App) RunOnceWithDisplayAndFiles(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) error {
stepCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -547,7 +565,7 @@ func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn fun
a.cancelStep = cancel
a.mu.Unlock()
result, err := a.executeStep(stepCtx, prompt, eventFn, nil)
result, err := a.executeStep(stepCtx, prompt, eventFn, files)
if err != nil {
return err
}
@@ -870,6 +888,12 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
switch ev := e.(type) {
case kit.ToolCallEvent:
sendFn(ToolCallStartedEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs})
case kit.ToolCallStartEvent:
sendFn(ToolCallInputStartEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolKind: ev.ToolKind})
case kit.ToolCallDeltaEvent:
sendFn(ToolCallInputDeltaEvent{ToolCallID: ev.ToolCallID, Delta: ev.Delta})
case kit.ToolCallEndEvent:
sendFn(ToolCallInputEndEvent{ToolCallID: ev.ToolCallID})
case kit.ToolExecutionStartEvent:
sendFn(ToolExecutionEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs, IsStarting: true})
case kit.ToolExecutionEndEvent:
@@ -899,7 +923,23 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
case kit.SteerConsumedEvent:
sendFn(SteerConsumedEvent{})
case kit.StepUsageEvent:
a.recordStepUsage(ev, stepUsageSeen)
a.recordStepUsage(ev, stepUsageSeen, sendFn)
case kit.PasswordPromptEvent:
// Convert SDK PasswordPromptEvent to app PasswordPromptEvent
// The TUI will handle this and send the response back
responseCh := make(chan PasswordPromptResponse, 1)
sendFn(PasswordPromptEvent{
Prompt: ev.Prompt,
ResponseCh: responseCh,
})
// Wait for TUI response and forward to SDK
resp := <-responseCh
ev.ResponseCh <- kit.PasswordPromptResponse{
Password: resp.Password,
Cancelled: resp.Cancelled,
}
case kit.TurnEndEvent:
a.handleTurnEnd(ev, sendFn)
}
}))
@@ -910,6 +950,64 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
}
}
// handleTurnEnd inspects a turn's final StopReason and surfaces actionable
// feedback to the user when the turn ended in a state they can act on.
//
// Today the only surfaced case is FinishReasonLength — the model hit its
// configured max_output_tokens budget and the reply was truncated. Without
// this banner the TUI used to swallow the truncation silently, leading to
// "ghost" cut-offs with no indication of why.
//
// Separated from subscribeSDKEvents so tests can exercise it directly via a
// stubbed sendFn without standing up a full Kit.
func (a *App) handleTurnEnd(ev kit.TurnEndEvent, sendFn func(tea.Msg)) {
if sendFn == nil {
return
}
if ev.StopReason != kit.FinishReasonLength {
return
}
sendFn(ExtensionPrintEvent{
Level: "info",
Text: a.formatMaxTokensTruncatedMessage(),
})
}
// formatMaxTokensTruncatedMessage builds the user-facing explanation for a
// truncated turn. It reports the active max_output_tokens budget and, when
// known, the model's catalog output ceiling so the user can judge how much
// headroom is available.
func (a *App) formatMaxTokensTruncatedMessage() string {
k := a.opts.Kit
if k == nil {
// Extremely early / test-stub case: still emit a useful generic hint.
return "⚠ Response truncated: the model hit the configured max_output_tokens limit. " +
"Raise it with --max-tokens N, KIT_MAX_TOKENS=N, or per-model " +
"modelSettings[provider/model].maxTokens in config."
}
current := k.MaxTokens()
ceiling := k.MaxOutputLimit()
model := k.GetModelString()
msg := "⚠ Response truncated: "
if model != "" {
msg += fmt.Sprintf("%s hit the configured max_output_tokens limit", model)
} else {
msg += "the model hit the configured max_output_tokens limit"
}
if current > 0 {
msg += fmt.Sprintf(" (%d)", current)
}
msg += "."
if ceiling > 0 && current > 0 && ceiling > current {
msg += fmt.Sprintf(" This model supports up to %d output tokens.", ceiling)
}
msg += "\n\nRaise it with --max-tokens N, KIT_MAX_TOKENS=N, " +
"or per-model modelSettings[provider/model].maxTokens in your config. " +
"Re-run the last prompt after raising it to get the full response."
return msg
}
// QuitFromExtension triggers a graceful shutdown. In interactive mode it
// sends a tea.QuitMsg to the program so the TUI exits cleanly. In
// non-interactive mode it cancels the root context, stopping any in-flight
@@ -1143,7 +1241,16 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
// recordStepUsage applies token/cost usage reported for a completed step.
// Step usage events arrive even when a turn is later cancelled, so this keeps
// the usage widget accurate on all stop paths.
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
//
// Both session totals (cost, token counts) and the context window fill level
// are updated here so the status bar reflects progress after every LLM call,
// not just at the end of the full turn. Context fill monotonically increases
// across steps because each step re-sends the entire conversation plus any
// new tool results, so the numbers only go up.
//
// sendFn is called with a UsageUpdatedEvent to trigger a TUI re-render so
// the updated values are visible immediately.
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool, sendFn func(tea.Msg)) {
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
if a.opts.Debug {
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
@@ -1164,11 +1271,21 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
int(ev.CacheReadTokens),
int(ev.CacheWriteTokens),
)
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
// at turn completion via updateUsageFromTurnResult, which sums all token
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
// Per-step context tokens would cause the display to jump around during
// multi-step tool calls.
// Update context window fill from this step's usage. Each step sends
// the full conversation to the LLM, so the reported token counts
// represent the actual context utilization at that point.
contextFill := int(ev.InputTokens) + int(ev.CacheReadTokens) + int(ev.CacheWriteTokens) + int(ev.OutputTokens)
if contextFill > 0 {
if a.opts.Debug {
log.Printf("[DEBUG] recordStepUsage: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheWrite=%d + Output=%d)",
contextFill, ev.InputTokens, ev.CacheReadTokens, ev.CacheWriteTokens, ev.OutputTokens)
}
a.opts.UsageTracker.SetContextTokens(contextFill)
}
// Notify the TUI so it re-renders the status bar with updated values.
if sendFn != nil {
sendFn(UsageUpdatedEvent{})
}
}
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
+104 -7
View File
@@ -3,10 +3,12 @@ package app
import (
"context"
"errors"
"strings"
"sync"
"testing"
"time"
tea "charm.land/bubbletea/v2"
kit "github.com/mark3labs/kit/pkg/kit"
)
@@ -532,9 +534,9 @@ func TestQueueLength_reflects(t *testing.T) {
}
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
// recorded immediately for cost tracking. Context tokens are NOT updated here
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
// tool calls.
// recorded immediately for cost tracking. Context tokens are also updated so
// the status bar reflects context fill after every LLM call in a multi-step
// turn, not just at the end.
func TestRecordStepUsage_updatesTracker(t *testing.T) {
usage := &usageUpdaterStub{}
app := New(Options{UsageTracker: usage}, nil)
@@ -545,7 +547,7 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
OutputTokens: 45,
CacheReadTokens: 5,
CacheWriteTokens: 2,
}, nil)
}, nil, nil)
usage.mu.Lock()
defer usage.mu.Unlock()
@@ -557,9 +559,13 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
}
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
if usage.contextCalls != 0 {
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
// Context tokens should now be updated per-step (Input + CacheRead + CacheWrite + Output).
if usage.contextCalls != 1 {
t.Fatalf("expected 1 context token update from recordStepUsage, got %d", usage.contextCalls)
}
expectedContext := 120 + 45 + 5 + 2
if usage.lastContextTokens != expectedContext {
t.Fatalf("expected context tokens %d, got %d", expectedContext, usage.lastContextTokens)
}
}
@@ -666,3 +672,94 @@ func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T)
expected, usage.contextCalls, usage.lastContextTokens)
}
}
// TestHandleTurnEnd_LengthEmitsWarning verifies that when the SDK reports a
// FinishReasonLength (max_output_tokens hit), the app surfaces a user-visible
// ExtensionPrintEvent with Level="info" so the TUI can render a banner
// instead of silently showing a truncated reply.
func TestHandleTurnEnd_LengthEmitsWarning(t *testing.T) {
app := New(Options{}, nil)
defer app.Close()
var mu sync.Mutex
var received []tea.Msg
sendFn := func(m tea.Msg) {
mu.Lock()
defer mu.Unlock()
received = append(received, m)
}
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, sendFn)
mu.Lock()
defer mu.Unlock()
if len(received) != 1 {
t.Fatalf("expected 1 event on length stop, got %d", len(received))
}
ev, ok := received[0].(ExtensionPrintEvent)
if !ok {
t.Fatalf("expected ExtensionPrintEvent, got %T", received[0])
}
if ev.Level != "info" {
t.Errorf("expected Level=info, got %q", ev.Level)
}
if ev.Text == "" {
t.Error("expected non-empty warning text")
}
if !strings.Contains(ev.Text, "max_output_tokens") {
t.Errorf("warning text should mention max_output_tokens, got: %s", ev.Text)
}
}
// TestHandleTurnEnd_NonLengthIgnored verifies that ordinary stop reasons
// (stop, tool-calls, error, unknown, "") do not produce a warning banner.
func TestHandleTurnEnd_NonLengthIgnored(t *testing.T) {
app := New(Options{}, nil)
defer app.Close()
reasons := []string{
kit.FinishReasonStop,
kit.FinishReasonToolCalls,
kit.FinishReasonError,
kit.FinishReasonContentFilter,
kit.FinishReasonOther,
kit.FinishReasonUnknown,
"",
}
for _, r := range reasons {
var called bool
app.handleTurnEnd(kit.TurnEndEvent{StopReason: r}, func(m tea.Msg) {
called = true
})
if called {
t.Errorf("stop reason %q unexpectedly emitted a warning", r)
}
}
}
// TestHandleTurnEnd_NilSendFn guards against panics when no TUI listener is
// attached (e.g. early init or headless teardown).
func TestHandleTurnEnd_NilSendFn(t *testing.T) {
app := New(Options{}, nil)
defer app.Close()
// Should not panic with a nil sendFn.
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, nil)
}
// TestFormatMaxTokensTruncatedMessage_NoKit verifies the fallback message
// when Options.Kit is nil (test/stub path).
func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
app := New(Options{}, nil)
defer app.Close()
msg := app.formatMaxTokensTruncatedMessage()
if msg == "" {
t.Fatal("expected non-empty fallback message")
}
for _, needle := range []string{"max_output_tokens", "--max-tokens", "KIT_MAX_TOKENS", "modelSettings"} {
if !strings.Contains(msg, needle) {
t.Errorf("fallback message missing %q:\n%s", needle, msg)
}
}
}
+54
View File
@@ -32,6 +32,36 @@ type ToolCallStartedEvent struct {
ToolArgs string
}
// ToolCallInputStartEvent is sent when the LLM begins generating tool call
// arguments. The tool name is known but the full argument JSON is still being
// streamed. UIs can use this to show a "running" indicator immediately instead
// of waiting for the full argument JSON to finish streaming.
type ToolCallInputStartEvent struct {
// ToolCallID is the stable identifier for correlating tool lifecycle events.
ToolCallID string
// ToolName is the name of the tool being called.
ToolName string
// ToolKind classifies the tool: "execute", "edit", "read", "search", "agent".
ToolKind string
}
// ToolCallInputDeltaEvent is sent for each streamed fragment of tool call
// arguments as they arrive from the LLM. Useful for live-previewing content
// or showing a progress indicator with byte count.
type ToolCallInputDeltaEvent struct {
// ToolCallID is the stable identifier for correlating tool lifecycle events.
ToolCallID string
// Delta is a JSON fragment of tool call arguments.
Delta string
}
// ToolCallInputEndEvent is sent when tool argument streaming is complete,
// before the tool call is parsed and execution begins.
type ToolCallInputEndEvent struct {
// ToolCallID is the stable identifier for correlating tool lifecycle events.
ToolCallID string
}
// ToolExecutionEvent is sent when a tool starts or finishes executing.
// The IsStarting flag distinguishes between the start and end of execution.
type ToolExecutionEvent struct {
@@ -79,6 +109,24 @@ type ToolCallContentEvent struct {
Content string
}
// PasswordPromptEvent is sent when a sudo command needs a password.
// The TUI should display a password prompt overlay and send the result back.
type PasswordPromptEvent struct {
// Prompt is the message to display to the user.
Prompt string
// ResponseCh receives the password from the TUI.
// The TUI must send exactly one value.
ResponseCh chan<- PasswordPromptResponse
}
// PasswordPromptResponse carries the user's password input.
type PasswordPromptResponse struct {
// Password is the entered password.
Password string
// Cancelled is true if the user cancelled the prompt.
Cancelled bool
}
// ResponseCompleteEvent is sent when the LLM produces a final (non-streaming) response.
// In streaming mode, this may be empty if all content was delivered via StreamChunkEvents.
type ResponseCompleteEvent struct {
@@ -162,6 +210,12 @@ type ModelChangedEvent struct {
ModelName string
}
// UsageUpdatedEvent is sent after each completed LLM step to notify the TUI
// that token counts and costs have changed. The UsageTracker is updated
// in-place before this event is sent; the TUI just needs to re-render to
// reflect the new values in the status bar.
type UsageUpdatedEvent struct{}
// WidgetUpdateEvent is sent when an extension adds, updates, or removes a
// widget via ctx.SetWidget or ctx.RemoveWidget. The TUI re-reads widget state
// from its WidgetProvider on the next render cycle.
+3 -6
View File
@@ -3,24 +3,21 @@ package app
import (
"testing"
"charm.land/fantasy"
kit "github.com/mark3labs/kit/pkg/kit"
)
// makeTextMsg builds a minimal kit.LLMMessage using fantasy.NewUserMessage
// or constructing with the given role.
// makeTextMsg builds a minimal kit.LLMMessage with the given role and text.
func makeTextMsg(role, text string) kit.LLMMessage {
return kit.LLMMessage{
Role: kit.LLMMessageRole(role),
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
Content: []kit.LLMMessagePart{kit.LLMTextPart{Text: text}},
}
}
// textOf extracts the plain text from an LLMMessage for assertions.
func textOf(msg kit.LLMMessage) string {
for _, part := range msg.Content {
if tp, ok := part.(fantasy.TextPart); ok {
if tp, ok := part.(kit.LLMTextPart); ok {
return tp.Text
}
}
+8
View File
@@ -471,5 +471,13 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
return envKey, "ANTHROPIC_API_KEY environment variable", nil
}
// Check if OpenAI credentials exist to provide a helpful suggestion
if cm != nil {
hasOpenAI, _ := cm.HasOpenAICredentials()
if hasOpenAI {
return "", "", fmt.Errorf("no Anthropic API key found. Use 'kit auth login anthropic', set ANTHROPIC_API_KEY environment variable, or use --provider-api-key flag\n\nNote: OpenAI credentials were detected. To use OpenAI, run with --model openai/gpt-5.4 or set it as default:\n kit auth login openai --set-default")
}
}
return "", "", fmt.Errorf("no Anthropic API key found. Use 'kit auth login anthropic', set ANTHROPIC_API_KEY environment variable, or use --provider-api-key flag")
}
+81 -8
View File
@@ -22,6 +22,45 @@ type MCPServerConfig struct {
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
// OAuth configuration for remote servers that don't support dynamic
// client registration (e.g. GitHub). When OAuthClientID is set, it is
// passed directly to the transport's OAuthConfig instead of relying on
// dynamic registration.
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
// NoOAuth disables OAuth transport configuration for this server, even
// when the connection pool has an auth handler. Use this for public MCP
// servers (e.g. PubMed) that don't require authentication. Without this
// flag, the pool would attach OAuth transport to every remote server,
// causing proactive dynamic-client-registration attempts that fail on
// servers that don't support it.
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
// TasksMode controls when this server's tools/call requests are augmented
// with MCP task metadata (turning a synchronous call into an asynchronous,
// pollable job — see https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks).
//
// Valid values:
// - "" or "auto": (default) augment requests with task metadata only
// when the server advertises tasks/toolCalls capability during initialize.
// - "never": never augment — every tool call is synchronous, regardless
// of server capability.
// - "always": always augment, even when the server didn't advertise
// task support. The server may still respond synchronously; this just
// opts in unconditionally on the client side.
//
// In all modes, when the server returns a CreateTaskResult the client polls
// tasks/get / tasks/result until the task reaches a terminal state.
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
// InProcessServer holds a live *server.MCPServer for in-process transport.
// When set (and Type is "inprocess"), the connection pool creates an
// in-process client instead of spawning a subprocess or making HTTP calls.
// This field is never serialized — it is only used programmatically via the SDK.
InProcessServer any `json:"-" yaml:"-"`
// Legacy fields for backward compatibility
Transport string `json:"transport,omitempty"`
Args []string `json:"args,omitempty"`
@@ -35,13 +74,18 @@ type MCPServerConfig struct {
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
// First try to unmarshal as the new format
type newFormat struct {
Type string `json:"type"`
Command []string `json:"command,omitempty"`
Environment map[string]string `json:"environment,omitempty"`
URL string `json:"url,omitempty"`
Headers []string `json:"headers,omitempty"`
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
Type string `json:"type"`
Command []string `json:"command,omitempty"`
Environment map[string]string `json:"environment,omitempty"`
URL string `json:"url,omitempty"`
Headers []string `json:"headers,omitempty"`
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
}
// Also try legacy format
@@ -54,6 +98,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
Headers []string `json:"headers,omitempty"`
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
}
// Try new format first
@@ -66,6 +111,11 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
s.Headers = newConfig.Headers
s.AllowedTools = newConfig.AllowedTools
s.ExcludedTools = newConfig.ExcludedTools
s.OAuthClientID = newConfig.OAuthClientID
s.OAuthClientSecret = newConfig.OAuthClientSecret
s.OAuthScopes = newConfig.OAuthScopes
s.NoOAuth = newConfig.NoOAuth
s.TasksMode = newConfig.TasksMode
return nil
}
@@ -86,6 +136,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
s.Headers = legacyConfig.Headers
s.AllowedTools = legacyConfig.AllowedTools
s.ExcludedTools = legacyConfig.ExcludedTools
s.TasksMode = legacyConfig.TasksMode
// Infer type from legacy format for better compatibility
// Only set Type when it doesn't change existing transport behavior
@@ -263,11 +314,18 @@ func (s *MCPServerConfig) GetTransportType() string {
return "stdio"
case "remote":
return "streamable"
case "inprocess":
return "inprocess"
default:
return s.Type
}
}
// Programmatic in-process server detection.
if s.InProcessServer != nil {
return "inprocess"
}
// Backward compatibility: infer transport type
if len(s.Command) > 0 {
return "stdio"
@@ -287,6 +345,17 @@ func (c *Config) Validate() error {
return fmt.Errorf("server %s: allowedTools and excludedTools are mutually exclusive", serverName)
}
// Reject unknown tasksMode values up front so a typo (e.g. "alwasy")
// fails loud here instead of being silently downgraded to "auto" by
// the runtime parser. Comparison is case-insensitive to match
// tools.ParseTaskMode.
switch strings.ToLower(strings.TrimSpace(serverConfig.TasksMode)) {
case "", "auto", "never", "always":
// ok
default:
return fmt.Errorf("server %s: invalid tasksMode %q (expected one of: auto, never, always)", serverName, serverConfig.TasksMode)
}
transport := serverConfig.GetTransportType()
switch transport {
case "stdio":
@@ -298,8 +367,12 @@ func (c *Config) Validate() error {
if serverConfig.URL == "" {
return fmt.Errorf("server %s: url is required for %s transport", serverName, transport)
}
case "inprocess":
if serverConfig.InProcessServer == nil {
return fmt.Errorf("server %s: InProcessServer is required for inprocess transport", serverName)
}
default:
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable", serverName, transport)
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable, inprocess", serverName, transport)
}
}
return nil
+174
View File
@@ -6,6 +6,8 @@ import (
"path/filepath"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
func TestMCPServerConfig_NewFormat(t *testing.T) {
@@ -542,3 +544,175 @@ func TestEnsureConfigExistsWhenFileExists(t *testing.T) {
t.Error("Existing config file was modified when it shouldn't have been")
}
}
func TestMCPServerConfig_OAuthFields_JSON(t *testing.T) {
jsonData := `{
"type": "remote",
"url": "https://api.githubcopilot.com/mcp/",
"oauthClientId": "Ov23liXXXXXXXXXXXXXX",
"oauthClientSecret": "secret123",
"oauthScopes": ["read:user", "repo"]
}`
var cfg MCPServerConfig
err := json.Unmarshal([]byte(jsonData), &cfg)
if err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if cfg.Type != "remote" {
t.Errorf("Expected type 'remote', got %q", cfg.Type)
}
if cfg.URL != "https://api.githubcopilot.com/mcp/" {
t.Errorf("Expected URL, got %q", cfg.URL)
}
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
}
if cfg.OAuthClientSecret != "secret123" {
t.Errorf("Expected OAuthClientSecret 'secret123', got %q", cfg.OAuthClientSecret)
}
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
}
}
func TestMCPServerConfig_OAuthFields_YAML(t *testing.T) {
yamlData := `
type: remote
url: https://api.githubcopilot.com/mcp/
oauthClientId: "Ov23liXXXXXXXXXXXXXX"
oauthScopes:
- read:user
- repo
`
var cfg MCPServerConfig
err := yaml.Unmarshal([]byte(yamlData), &cfg)
if err != nil {
t.Fatalf("Failed to unmarshal YAML: %v", err)
}
if cfg.Type != "remote" {
t.Errorf("Expected type 'remote', got %q", cfg.Type)
}
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
}
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
}
}
func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
// Verify that omitting OAuth fields still works (backward compat).
jsonData := `{
"type": "remote",
"url": "https://example.com/mcp"
}`
var cfg MCPServerConfig
err := json.Unmarshal([]byte(jsonData), &cfg)
if err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if cfg.OAuthClientID != "" {
t.Errorf("Expected empty OAuthClientID, got %q", cfg.OAuthClientID)
}
if cfg.OAuthClientSecret != "" {
t.Errorf("Expected empty OAuthClientSecret, got %q", cfg.OAuthClientSecret)
}
if len(cfg.OAuthScopes) != 0 {
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
}
}
func TestMCPServerConfig_TasksMode_NewFormat(t *testing.T) {
jsonData := `{
"type": "remote",
"url": "https://my-mcp-server.com",
"tasksMode": "always"
}`
var cfg MCPServerConfig
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if cfg.TasksMode != "always" {
t.Errorf("expected TasksMode 'always', got %q", cfg.TasksMode)
}
}
func TestMCPServerConfig_TasksMode_LegacyFormat(t *testing.T) {
// tasksMode also recognised in the legacy unmarshal path so users on
// the older command/args shape can opt in without migrating.
jsonData := `{
"command": "npx",
"args": ["@modelcontextprotocol/server-filesystem", "/path"],
"tasksMode": "never"
}`
var cfg MCPServerConfig
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if cfg.TasksMode != "never" {
t.Errorf("expected TasksMode 'never', got %q", cfg.TasksMode)
}
}
func TestMCPServerConfig_TasksMode_DefaultEmpty(t *testing.T) {
// When tasksMode is not set the field stays empty, which downstream
// resolves to "auto" via tools.ParseTaskMode.
jsonData := `{"type":"remote","url":"https://x.example"}`
var cfg MCPServerConfig
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if cfg.TasksMode != "" {
t.Errorf("expected default TasksMode to be empty, got %q", cfg.TasksMode)
}
}
func TestConfig_Validate_TasksMode(t *testing.T) {
t.Run("empty is valid", func(t *testing.T) {
cfg := &Config{
MCPServers: map[string]MCPServerConfig{
"a": {Type: "remote", URL: "https://x.example"},
},
}
if err := cfg.Validate(); err != nil {
t.Errorf("empty TasksMode should validate, got %v", err)
}
})
t.Run("known values are valid", func(t *testing.T) {
for _, mode := range []string{"auto", "never", "always", "AUTO", " always "} {
cfg := &Config{
MCPServers: map[string]MCPServerConfig{
"a": {Type: "remote", URL: "https://x.example", TasksMode: mode},
},
}
if err := cfg.Validate(); err != nil {
t.Errorf("TasksMode=%q should validate, got %v", mode, err)
}
}
})
t.Run("typo is rejected with a clear error", func(t *testing.T) {
cfg := &Config{
MCPServers: map[string]MCPServerConfig{
"buildbot": {Type: "remote", URL: "https://x.example", TasksMode: "alwasy"},
},
}
err := cfg.Validate()
if err == nil {
t.Fatal("expected validation error for invalid TasksMode")
}
// Error must mention the server name AND the bad value so the
// user knows where to look.
msg := err.Error()
if !strings.Contains(msg, "buildbot") || !strings.Contains(msg, `"alwasy"`) {
t.Errorf("error %q should mention both server name and bad value", msg)
}
})
}
+176 -6
View File
@@ -19,10 +19,18 @@ import (
// It receives tool call ID, tool name, output chunk, and whether it's stderr.
type ToolOutputCallback func(toolCallID, toolName, chunk string, isStderr bool)
// PasswordPromptCallback is the signature for password prompts.
// It receives a prompt message and returns the password and whether it was cancelled.
type PasswordPromptCallback func(prompt string) (password string, cancelled bool)
// contextKey is a custom type for context keys to avoid collisions.
type contextKey string
const toolOutputCallbackKey contextKey = "toolOutputCallback"
const (
toolOutputCallbackKey contextKey = "toolOutputCallback"
sudoPasswordKey contextKey = "sudoPassword"
passwordPromptKey contextKey = "passwordPrompt"
)
// ContextWithToolOutputCallback returns a new context with the tool output callback set.
func ContextWithToolOutputCallback(ctx context.Context, callback ToolOutputCallback) context.Context {
@@ -37,6 +45,34 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
return nil
}
// ContextWithPasswordPrompt returns a new context with the password prompt callback set.
// This allows the TUI to show a modal password prompt when sudo needs a password.
func ContextWithPasswordPrompt(ctx context.Context, callback PasswordPromptCallback) context.Context {
return context.WithValue(ctx, passwordPromptKey, callback)
}
// passwordPromptFromContext retrieves the password prompt callback from context.
func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
if cb, ok := ctx.Value(passwordPromptKey).(PasswordPromptCallback); ok {
return cb
}
return nil
}
// ContextWithSudoPassword returns a new context with the sudo password set.
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
return context.WithValue(ctx, sudoPasswordKey, password)
}
// sudoPasswordFromContext retrieves the sudo password from context.
func sudoPasswordFromContext(ctx context.Context) string {
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
return pw
}
return ""
}
const defaultBashTimeout = 120 * time.Second
const maxBashTimeout = 600 * time.Second
@@ -73,6 +109,66 @@ func NewBashTool(opts ...ToolOption) fantasy.AgentTool {
}
}
// sudoCommandRe matches sudo commands that need to be rewritten for -S mode.
// It matches "sudo" as a word boundary, optionally preceded by environment variables.
var sudoCommandRe = regexp.MustCompile(`(?i)(^|[&|;|]|\|\||&&)\s*(\w+=\S+\s+)?\bsudo\b`)
// truncateCommand truncates a long command for display.
func truncateCommand(cmd string, maxLen int) string {
if len(cmd) <= maxLen {
return cmd
}
return cmd[:maxLen-3] + "..."
}
// rewriteSudoForStdin rewrites sudo commands to use -S -p ” for stdin password input.
// It transforms: sudo cmd → sudo -S -p ” cmd
func rewriteSudoForStdin(command string) string {
// Find all matches and their positions
matches := sudoCommandRe.FindAllStringIndex(command, -1)
if matches == nil {
return command
}
// Build result from end to start to preserve indices
result := command
for i := len(matches) - 1; i >= 0; i-- {
match := matches[i]
start, end := match[0], match[1]
matchedText := result[start:end]
// Extract just the "sudo" part (after any prefix)
sudoIdx := strings.Index(strings.ToLower(matchedText), "sudo")
if sudoIdx == -1 {
continue
}
prefix := matchedText[:sudoIdx]
sudoPart := matchedText[sudoIdx:]
// Check if the text immediately after "sudo" in the result contains -S
afterSudo := result[end:]
if strings.HasPrefix(strings.TrimLeft(afterSudo, " \t"), "-S") {
// Already has -S flag, skip
continue
}
// Insert -S -p '' after "sudo"
newSudo := strings.Replace(sudoPart, "sudo", "sudo -S -p ''", 1)
result = result[:start] + prefix + newSudo + result[end:]
}
return result
}
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
// This is stored in tool response metadata to signal the TUI to prompt for password.
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
return resp.Metadata == SudoPasswordRequiredMetadata
}
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
var args bashArgs
if err := parseArgs(call.Input, &args); err != nil {
@@ -97,7 +193,47 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "bash", "-c", args.Command)
// Check for sudo password in context or environment
sudoPassword := sudoPasswordFromContext(ctx)
if sudoPassword == "" {
sudoPassword = os.Getenv("SUDO_PASSWORD")
}
command := args.Command
// If command contains sudo and we don't have a password, check if sudo needs one
if sudoPassword == "" && sudoCommandRe.MatchString(command) {
// Check if sudo credentials are cached using sudo -n (non-interactive)
testCmd := exec.CommandContext(cmdCtx, "sudo", "-n", "true")
testCmd.Dir = workDir
if err := testCmd.Run(); err != nil {
// Sudo needs a password - try to prompt via callback
if promptCallback := passwordPromptFromContext(ctx); promptCallback != nil {
pw, cancelled := promptCallback("Sudo password required for: " + truncateCommand(args.Command, 60))
if cancelled {
return fantasy.NewTextErrorResponse("sudo password prompt cancelled"), nil
}
if pw == "" {
return fantasy.NewTextErrorResponse("no sudo password provided"), nil
}
sudoPassword = pw
command = rewriteSudoForStdin(command)
} else {
// No callback available - return error with helpful message
return fantasy.NewTextErrorResponse(
"This command requires sudo access. " +
"Please run 'sudo -v' in your terminal first to cache credentials, " +
"or set the SUDO_PASSWORD environment variable."), nil
}
}
// Credentials are cached or password was provided, proceed
}
// If we have a sudo password, rewrite the command to use sudo -S
if sudoPassword != "" && sudoCommandRe.MatchString(command) {
command = rewriteSudoForStdin(command)
}
cmd := exec.CommandContext(cmdCtx, "bash", "-c", command)
if workDir != "" {
cmd.Dir = workDir
}
@@ -115,18 +251,18 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
if outputCallback != nil {
// Streaming mode: use pipes to capture output as it arrives
return executeBashStreaming(cmdCtx, call, cmd, outputCallback)
return executeBashStreaming(cmdCtx, call, cmd, outputCallback, sudoPassword)
}
// Non-streaming mode: collect all output at once (original behavior)
return executeBashBuffered(cmdCtx, call, cmd)
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
}
// executeBashBuffered collects all output before returning (original behavior).
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
// close them when grandchild processes hold pipe handles open after the
// direct child exits.
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd) (fantasy.ToolResponse, error) {
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
@@ -136,10 +272,27 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
}
// If we have a sudo password, create a stdin pipe and write the password
var stdinPipe io.WriteCloser
if sudoPassword != "" {
stdinPipe, err = cmd.StdinPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
}
}
if err := cmd.Start(); err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
}
// Write password to stdin if needed, then close stdin
if sudoPassword != "" && stdinPipe != nil {
go func() {
defer func() { _ = stdinPipe.Close() }()
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
}()
}
// Read pipes concurrently
var wg sync.WaitGroup
var stdout, stderr strings.Builder
@@ -181,7 +334,7 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
}
// executeBashStreaming streams output as it arrives via the callback.
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback) (fantasy.ToolResponse, error) {
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
@@ -191,11 +344,28 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
}
// If we have a sudo password, create a stdin pipe
var stdinPipe io.WriteCloser
if sudoPassword != "" {
stdinPipe, err = cmd.StdinPipe()
if err != nil {
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
}
}
// Start command execution
if err := cmd.Start(); err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
}
// Write password to stdin if needed, then close stdin
if sudoPassword != "" && stdinPipe != nil {
go func() {
defer func() { _ = stdinPipe.Close() }()
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
}()
}
// Stream stdout and stderr concurrently
var wg sync.WaitGroup
var mu sync.Mutex
+69
View File
@@ -127,3 +127,72 @@ func TestBash_EmptyCommand(t *testing.T) {
t.Fatal("expected error for empty command")
}
}
func TestRewriteSudoForStdin(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple sudo",
input: "sudo apt update",
expected: "sudo -S -p '' apt update",
},
{
name: "sudo with env var",
input: "DEBIAN_FRONTEND=noninteractive sudo apt update",
expected: "DEBIAN_FRONTEND=noninteractive sudo -S -p '' apt update",
},
{
name: "sudo in pipeline",
input: "echo test | sudo tee /etc/test.conf",
expected: "echo test | sudo -S -p '' tee /etc/test.conf",
},
{
name: "sudo after &&",
input: "apt update && sudo apt upgrade",
expected: "apt update && sudo -S -p '' apt upgrade",
},
{
name: "already has -S flag",
input: "sudo -S apt update",
expected: "sudo -S apt update",
},
{
name: "no sudo",
input: "apt update && apt upgrade",
expected: "apt update && apt upgrade",
},
{
name: "sudo in string (should not match)",
input: "echo 'use sudo carefully'",
expected: "echo 'use sudo carefully'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := rewriteSudoForStdin(tt.input)
if result != tt.expected {
t.Errorf("rewriteSudoForStdin(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestSudoPasswordFromContext(t *testing.T) {
// Test with password in context
ctx := ContextWithSudoPassword(context.Background(), "secret123")
pw := sudoPasswordFromContext(ctx)
if pw != "secret123" {
t.Errorf("expected password 'secret123', got %q", pw)
}
// Test without password
ctx = context.Background()
pw = sudoPasswordFromContext(ctx)
if pw != "" {
t.Errorf("expected empty password, got %q", pw)
}
}
+6 -42
View File
@@ -21,12 +21,9 @@ type Edit struct {
}
// editArgs holds the arguments for the edit tool.
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
type editArgs struct {
Path string `json:"path"`
OldText string `json:"old_text"` // Single-edit mode
NewText string `json:"new_text"` // Single-edit mode
Edits []Edit `json:"edits"` // Multi-edit mode
Path string `json:"path"`
Edits []Edit `json:"edits"`
}
// replacement represents a normalized edit ready for processing.
@@ -52,20 +49,12 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
return &coreTool{
info: fantasy.ToolInfo{
Name: "edit",
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
Description: "Edit a file by replacing exact text. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
Parameters: map[string]any{
"path": map[string]any{
"type": "string",
"description": "Path to the file to edit (relative or absolute)",
},
"old_text": map[string]any{
"type": "string",
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
},
"new_text": map[string]any{
"type": "string",
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
},
"edits": map[string]any{
"type": "array",
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
@@ -85,7 +74,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
},
},
},
Required: []string{"path"},
Required: []string{"path", "edits"},
},
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
return executeEdit(ctx, call, cfg.WorkDir)
@@ -163,36 +152,11 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
}
// normalizeEditInput validates and normalizes the edit input.
// Returns error if both single-edit and multi-edit modes are used.
func normalizeEditInput(args editArgs) ([]replacement, error) {
singleMode := args.OldText != "" || args.NewText != ""
multiMode := len(args.Edits) > 0
if singleMode && multiMode {
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
if len(args.Edits) == 0 {
return nil, fmt.Errorf("edits array is required and must not be empty")
}
if !singleMode && !multiMode {
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
}
if singleMode {
if args.OldText == "" {
return nil, fmt.Errorf("old_text is required when using single-edit mode")
}
if args.NewText == "" {
return nil, fmt.Errorf("new_text is required when using single-edit mode")
}
return []replacement{{
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
originalOld: args.OldText,
originalNew: args.NewText,
index: 0,
}}, nil
}
// Multi-edit mode
var reps []replacement
for i, edit := range args.Edits {
if edit.OldText == "" {
+62 -44
View File
@@ -389,9 +389,11 @@ func TestExecuteEdit_ExactMatch(t *testing.T) {
writeFileOrFail(t, path, original)
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "fmt.Println(\"hello\")",
NewText: "fmt.Println(\"world\")",
Path: path,
Edits: []Edit{{
OldText: "fmt.Println(\"hello\")",
NewText: "fmt.Println(\"world\")",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -426,9 +428,11 @@ func TestExecuteEdit_ExactMatch_DoesNotCorruptRest(t *testing.T) {
target := lines[49]
replacement := "REPLACED_LINE_50"
input, _ := json.Marshal(editArgs{
Path: path,
OldText: target,
NewText: replacement,
Path: path,
Edits: []Edit{{
OldText: target,
NewText: replacement,
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -470,9 +474,11 @@ func TestExecuteEdit_FuzzyMatch_TrailingWhitespace(t *testing.T) {
// Search without trailing whitespace (common LLM behavior)
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "func foo() {\n\treturn 1\n}",
NewText: "func foo() {\n\treturn 2\n}",
Path: path,
Edits: []Edit{{
OldText: "func foo() {\n\treturn 1\n}",
NewText: "func foo() {\n\treturn 2\n}",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -519,9 +525,11 @@ func TestExecuteEdit_FuzzyMatch_DoesNotCorruptRest(t *testing.T) {
search := strings.Repeat("x", 10) + "\n" + strings.Repeat("x", 10)
// But this matches lines 1-2, 2-3, etc. — should fail due to ambiguity.
input, _ := json.Marshal(editArgs{
Path: path,
OldText: search,
NewText: "REPLACED",
Path: path,
Edits: []Edit{{
OldText: search,
NewText: "REPLACED",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -546,9 +554,11 @@ func TestExecuteEdit_MultipleMatches_Fails(t *testing.T) {
writeFileOrFail(t, path, "hello\nworld\nhello\n")
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "hello",
NewText: "goodbye",
Path: path,
Edits: []Edit{{
OldText: "hello",
NewText: "goodbye",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -575,9 +585,11 @@ func TestExecuteEdit_NoMatch_Fails(t *testing.T) {
writeFileOrFail(t, path, "hello world\n")
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "nonexistent text",
NewText: "replacement",
Path: path,
Edits: []Edit{{
OldText: "nonexistent text",
NewText: "replacement",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -601,9 +613,11 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
writeFileOrFail(t, path, "line1\r\nline2\r\nline3\r\n")
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "line2",
NewText: "LINE2",
Path: path,
Edits: []Edit{{
OldText: "line2",
NewText: "LINE2",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -622,8 +636,10 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
func TestExecuteEdit_MissingPath(t *testing.T) {
input, _ := json.Marshal(editArgs{
OldText: "x",
NewText: "y",
Edits: []Edit{{
OldText: "x",
NewText: "y",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
if err != nil {
@@ -636,9 +652,11 @@ func TestExecuteEdit_MissingPath(t *testing.T) {
func TestExecuteEdit_NonexistentFile(t *testing.T) {
input, _ := json.Marshal(editArgs{
Path: "/tmp/nonexistent_edit_test_file_12345.go",
OldText: "x",
NewText: "y",
Path: "/tmp/nonexistent_edit_test_file_12345.go",
Edits: []Edit{{
OldText: "x",
NewText: "y",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
if err != nil {
@@ -661,9 +679,11 @@ func TestExecuteEdit_DiffContainsHunkHeader(t *testing.T) {
writeFileOrFail(t, path, strings.Join(lines, "\n")+"\n")
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "line_10_content",
NewText: "REPLACED",
Path: path,
Edits: []Edit{{
OldText: "line_10_content",
NewText: "REPLACED",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -684,9 +704,11 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
writeFileOrFail(t, path, "old content\n")
input, _ := json.Marshal(editArgs{
Path: path,
OldText: "old content",
NewText: "new content",
Path: path,
Edits: []Edit{{
OldText: "old content",
NewText: "new content",
}},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -905,18 +927,14 @@ func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
}
}
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
func TestExecuteEdit_EmptyEditsArray_Fails(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "mixed.txt")
path := filepath.Join(dir, "empty.txt")
writeFileOrFail(t, path, "hello\n")
input, _ := json.Marshal(map[string]any{
"path": path,
"old_text": "hello",
"new_text": "HELLO",
"edits": []Edit{
{OldText: "hello", NewText: "HI"},
},
input, _ := json.Marshal(editArgs{
Path: path,
Edits: []Edit{},
})
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
@@ -924,10 +942,10 @@ func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
t.Fatalf("executeEdit error: %v", err)
}
if !resp.IsError {
t.Error("expected error when mixing single and multi-edit modes")
t.Error("expected error for empty edits array")
}
if !strings.Contains(resp.Content, "cannot use") {
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
if !strings.Contains(resp.Content, "required") {
t.Errorf("expected 'required' in error, got: %s", resp.Content)
}
}
+2 -2
View File
@@ -86,7 +86,7 @@ Example use cases:
},
"model": map[string]any{
"type": "string",
"description": "Optional model override (e.g. 'anthropic/claude-haiku-3-5-20241022' for faster/cheaper tasks)",
"description": "Optional model override. Empty string uses the current model.",
},
"system_prompt": map[string]any{
"type": "string",
@@ -94,7 +94,7 @@ Example use cases:
},
"timeout_seconds": map[string]any{
"type": "number",
"description": "Maximum execution time in seconds (default: 300, max: 1800)",
"description": "Maximum execution time in seconds (default: 300, max: 1800, minimum recommended: 240)",
},
},
Required: []string{"task"},
+202 -1
View File
@@ -918,7 +918,7 @@ type ExtensionEntry struct {
type ContextMessage struct {
// Index is the position of this message in the original context array
// (0-based). When returning messages from a ContextPrepareResult,
// messages with Index >= 0 reuse the original fantasy.Message at that
// messages with Index >= 0 reuse the original LLM message at that
// position (preserving tool calls, reasoning, and other complex parts).
// Set Index to -1 for newly injected messages (created from Role + Content).
Index int
@@ -1063,6 +1063,9 @@ type PrintBlockOpts struct {
type API struct {
// Event-specific registration functions (wired by the loader).
onToolCall func(func(ToolCallEvent, Context) *ToolCallResult)
onToolCallInputStart func(func(ToolCallInputStartEvent, Context))
onToolCallInputDelta func(func(ToolCallInputDeltaEvent, Context))
onToolCallInputEnd func(func(ToolCallInputEndEvent, Context))
onToolExecStart func(func(ToolExecutionStartEvent, Context))
onToolExecEnd func(func(ToolExecutionEndEvent, Context))
onToolOutput func(func(ToolOutputEvent, Context))
@@ -1091,6 +1094,14 @@ type API struct {
onSubagentStart func(func(SubagentStartEvent, Context))
onSubagentChunk func(func(SubagentChunkEvent, Context))
onSubagentEnd func(func(SubagentEndEvent, Context))
onStepStart func(func(StepStartEvent, Context))
onStepFinish func(func(StepFinishEvent, Context))
onReasoningStart func(func(ReasoningStartEvent, Context))
onWarnings func(func(WarningsEvent, Context))
onSource func(func(SourceEvent, Context))
onError func(func(ErrorEvent, Context))
onRetry func(func(RetryEvent, Context))
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
}
// OnToolCall registers a handler that fires before a tool executes.
@@ -1099,6 +1110,26 @@ func (a *API) OnToolCall(handler func(ToolCallEvent, Context) *ToolCallResult) {
a.onToolCall(handler)
}
// OnToolCallInputStart registers a handler that fires when the LLM begins
// generating tool call arguments. The tool name is known but the full
// argument JSON is still being streamed. Useful for showing a "running"
// indicator immediately without waiting for the full arguments.
func (a *API) OnToolCallInputStart(handler func(ToolCallInputStartEvent, Context)) {
a.onToolCallInputStart(handler)
}
// OnToolCallInputDelta registers a handler that fires for each streamed
// fragment of tool call arguments as they arrive from the LLM.
func (a *API) OnToolCallInputDelta(handler func(ToolCallInputDeltaEvent, Context)) {
a.onToolCallInputDelta(handler)
}
// OnToolCallInputEnd registers a handler that fires when tool argument
// streaming is complete, before the tool call is parsed and execution begins.
func (a *API) OnToolCallInputEnd(handler func(ToolCallInputEndEvent, Context)) {
a.onToolCallInputEnd(handler)
}
// OnToolExecutionStart registers a handler for tool execution start.
func (a *API) OnToolExecutionStart(handler func(ToolExecutionStartEvent, Context)) {
a.onToolExecStart(handler)
@@ -1278,6 +1309,56 @@ func (a *API) OnBeforeCompact(handler func(BeforeCompactEvent, Context) *BeforeC
a.onBeforeCompact(handler)
}
// OnStepStart registers a handler that fires when a new LLM call begins
// within a multi-step agent turn.
func (a *API) OnStepStart(handler func(StepStartEvent, Context)) {
a.onStepStart(handler)
}
// OnStepFinish registers a handler that fires when a step completes,
// providing step number, finish reason, and decomposed token usage.
func (a *API) OnStepFinish(handler func(StepFinishEvent, Context)) {
a.onStepFinish(handler)
}
// OnReasoningStart registers a handler that fires when the LLM begins
// reasoning/thinking.
func (a *API) OnReasoningStart(handler func(ReasoningStartEvent, Context)) {
a.onReasoningStart(handler)
}
// OnWarnings registers a handler that fires when the LLM provider returns
// warnings about the request.
func (a *API) OnWarnings(handler func(WarningsEvent, Context)) {
a.onWarnings(handler)
}
// OnSource registers a handler that fires when the LLM references a source
// (e.g. from web search tools).
func (a *API) OnSource(handler func(SourceEvent, Context)) {
a.onSource(handler)
}
// OnError registers a handler that fires when an agent-level error occurs
// during streaming.
func (a *API) OnError(handler func(ErrorEvent, Context)) {
a.onError(handler)
}
// OnRetry registers a handler that fires when the LLM provider request is
// retried after a transient error.
func (a *API) OnRetry(handler func(RetryEvent, Context)) {
a.onRetry(handler)
}
// OnPrepareStep registers a handler that fires between steps within a
// multi-step agent turn, after steering messages are injected and before
// messages are sent to the LLM. Return a non-nil PrepareStepResult with
// Messages to replace the context window for this step.
func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStepResult) {
a.onPrepareStep(handler)
}
// RegisterToolRenderer registers a custom renderer for a specific tool's
// display in the TUI. The renderer controls the header (parameter summary)
// and/or body (result display) of the tool's output block. If multiple
@@ -1890,6 +1971,34 @@ type ToolCallResult struct {
func (ToolCallResult) isResult() {}
// ToolCallInputStartEvent fires when the LLM begins generating tool call
// arguments. The tool name is known but the full argument JSON is still
// being streamed.
type ToolCallInputStartEvent struct {
ToolCallID string
ToolName string
ToolKind string // Tool classification: "execute", "edit", "read", "search", "agent"
}
func (e ToolCallInputStartEvent) Type() EventType { return ToolCallInputStart }
// ToolCallInputDeltaEvent fires for each streamed fragment of tool call
// arguments as they arrive from the LLM.
type ToolCallInputDeltaEvent struct {
ToolCallID string
Delta string // JSON fragment of tool arguments
}
func (e ToolCallInputDeltaEvent) Type() EventType { return ToolCallInputDelta }
// ToolCallInputEndEvent fires when tool argument streaming is complete,
// before the tool call is parsed and execution begins.
type ToolCallInputEndEvent struct {
ToolCallID string
}
func (e ToolCallInputEndEvent) Type() EventType { return ToolCallInputEnd }
// ToolExecutionStartEvent fires when a tool begins executing.
type ToolExecutionStartEvent struct {
ToolCallID string
@@ -2202,6 +2311,98 @@ type SubagentEndEvent struct {
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
// ---------------------------------------------------------------------------
// Step lifecycle events (exposed to Yaegi — concrete structs)
// ---------------------------------------------------------------------------
// StepStartEvent fires when a new LLM call begins within a multi-step agent turn.
type StepStartEvent struct {
StepNumber int
}
func (e StepStartEvent) Type() EventType { return StepStart }
// StepFinishEvent fires when a step completes, providing step metadata and
// token usage. Usage fields are plain int64 (not LLMUsage) because Yaegi
// cannot handle fantasy types across the interpreter boundary.
type StepFinishEvent struct {
StepNumber int
HasToolCalls bool
FinishReason string
InputTokens int64
OutputTokens int64
CacheReadTokens int64
CacheWriteTokens int64
}
func (e StepFinishEvent) Type() EventType { return StepFinish }
// ReasoningStartEvent fires when the LLM begins reasoning/thinking.
type ReasoningStartEvent struct {
ID string
}
func (e ReasoningStartEvent) Type() EventType { return ReasoningStart }
// WarningsEvent fires when the LLM provider returns warnings about the request.
type WarningsEvent struct {
Warnings []string
}
func (e WarningsEvent) Type() EventType { return Warnings }
// SourceEvent fires when the LLM references a source (e.g. from web search).
type SourceEvent struct {
SourceType string
ID string
URL string
Title string
}
func (e SourceEvent) Type() EventType { return Source }
// ErrorEvent fires when an agent-level error occurs during streaming.
// Uses string instead of error because Yaegi cannot handle the error
// interface reliably across the interpreter boundary.
type ErrorEvent struct {
Error string
}
func (e ErrorEvent) Type() EventType { return Error }
// RetryEvent fires when the LLM provider request is retried after a
// transient error.
type RetryEvent struct {
Attempt int
Error string
}
func (e RetryEvent) Type() EventType { return Retry }
// PrepareStepEvent fires between steps within a multi-step agent turn,
// after steering messages are injected and before messages are sent to
// the LLM. Handlers can inspect and replace the context window.
type PrepareStepEvent struct {
// StepNumber is the zero-based step index within the current turn.
StepNumber int
// Messages is the current context window that will be sent to the LLM.
Messages []ContextMessage
}
func (e PrepareStepEvent) Type() EventType { return PrepareStep }
// PrepareStepResult allows extensions to replace the context window between
// steps. Return nil Messages to leave the context unchanged.
type PrepareStepResult struct {
// Messages replaces the entire context window for this step. If nil,
// the original messages are used unchanged. Messages with a non-negative
// Index reuse the original message at that position; messages with
// Index < 0 are created fresh from Role + Content.
Messages []ContextMessage
}
func (PrepareStepResult) isResult() {}
// ThemeColor is an adaptive color pair with light and dark hex values.
// Either field may be empty to inherit from the default theme.
type ThemeColor struct {
+46 -1
View File
@@ -13,6 +13,19 @@ const (
// ToolCall fires before a tool executes. Handlers can block execution.
ToolCall EventType = "tool_call"
// ToolCallInputStart fires when the LLM begins generating tool call
// arguments. The tool name is known but the full argument JSON is still
// being streamed.
ToolCallInputStart EventType = "tool_call_input_start"
// ToolCallInputDelta fires for each streamed fragment of tool call
// arguments as they arrive from the LLM.
ToolCallInputDelta EventType = "tool_call_input_delta"
// ToolCallInputEnd fires when tool argument streaming is complete,
// before the tool call is parsed and execution begins.
ToolCallInputEnd EventType = "tool_call_input_end"
// ToolExecutionStart fires when a tool begins executing.
ToolExecutionStart EventType = "tool_execution_start"
@@ -83,18 +96,50 @@ const (
// SubagentEnd fires when a subagent tool call completes (success
// or error). Carries the final response and any error message.
SubagentEnd EventType = "subagent_end"
// StepStart fires when a new LLM call begins within a multi-step
// agent turn.
StepStart EventType = "step_start"
// StepFinish fires when a step completes, providing step number,
// finish reason, and token usage.
StepFinish EventType = "step_finish"
// ReasoningStart fires when the LLM begins reasoning/thinking.
ReasoningStart EventType = "reasoning_start"
// Warnings fires when the LLM provider returns warnings.
Warnings EventType = "warnings"
// Source fires when the LLM references a source (e.g. web search).
Source EventType = "source"
// Error fires when an agent-level error occurs during streaming.
Error EventType = "error"
// Retry fires when the LLM provider request is retried after a
// transient error.
Retry EventType = "retry"
// PrepareStep fires between steps within a multi-step agent turn,
// after steering messages are injected and before messages are sent
// to the LLM. Handlers can replace the context window for this step.
PrepareStep EventType = "prepare_step"
)
// AllEventTypes returns every supported event type.
func AllEventTypes() []EventType {
return []EventType{
ToolCall, ToolExecutionStart, ToolExecutionEnd, ToolResult,
ToolCall, ToolCallInputStart, ToolCallInputDelta, ToolCallInputEnd,
ToolExecutionStart, ToolExecutionEnd, ToolResult,
Input, BeforeAgentStart, AgentStart, AgentEnd,
MessageStart, MessageUpdate, MessageEnd,
SessionStart, SessionShutdown,
ModelChange, ContextPrepare,
BeforeFork, BeforeSessionSwitch, BeforeCompact,
SubagentStart, SubagentChunk, SubagentEnd,
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
PrepareStep,
}
}
+5 -2
View File
@@ -4,8 +4,8 @@ import "testing"
func TestAllEventTypes_Count(t *testing.T) {
all := AllEventTypes()
if len(all) != 21 {
t.Fatalf("expected 21 event types, got %d", len(all))
if len(all) != 32 {
t.Fatalf("expected 32 event types, got %d", len(all))
}
}
@@ -38,6 +38,9 @@ func TestEventType_TypeMethod(t *testing.T) {
want EventType
}{
{ToolCallEvent{ToolName: "test"}, ToolCall},
{ToolCallInputStartEvent{ToolCallID: "x", ToolName: "test"}, ToolCallInputStart},
{ToolCallInputDeltaEvent{ToolCallID: "x", Delta: "{"}, ToolCallInputDelta},
{ToolCallInputEndEvent{ToolCallID: "x"}, ToolCallInputEnd},
{ToolExecutionStartEvent{ToolName: "test"}, ToolExecutionStart},
{ToolExecutionEndEvent{ToolName: "test"}, ToolExecutionEnd},
{ToolResultEvent{ToolName: "test"}, ToolResult},
+69
View File
@@ -429,6 +429,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
return *r
})
},
onToolCallInputStart: func(h func(ToolCallInputStartEvent, Context)) {
reg(ToolCallInputStart, func(e Event, c Context) Result {
h(e.(ToolCallInputStartEvent), c)
return nil
})
},
onToolCallInputDelta: func(h func(ToolCallInputDeltaEvent, Context)) {
reg(ToolCallInputDelta, func(e Event, c Context) Result {
h(e.(ToolCallInputDeltaEvent), c)
return nil
})
},
onToolCallInputEnd: func(h func(ToolCallInputEndEvent, Context)) {
reg(ToolCallInputEnd, func(e Event, c Context) Result {
h(e.(ToolCallInputEndEvent), c)
return nil
})
},
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
reg(ToolExecutionStart, func(e Event, c Context) Result {
h(e.(ToolExecutionStartEvent), c)
@@ -600,6 +618,57 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
return nil
})
},
onStepStart: func(h func(StepStartEvent, Context)) {
reg(StepStart, func(e Event, c Context) Result {
h(e.(StepStartEvent), c)
return nil
})
},
onStepFinish: func(h func(StepFinishEvent, Context)) {
reg(StepFinish, func(e Event, c Context) Result {
h(e.(StepFinishEvent), c)
return nil
})
},
onReasoningStart: func(h func(ReasoningStartEvent, Context)) {
reg(ReasoningStart, func(e Event, c Context) Result {
h(e.(ReasoningStartEvent), c)
return nil
})
},
onWarnings: func(h func(WarningsEvent, Context)) {
reg(Warnings, func(e Event, c Context) Result {
h(e.(WarningsEvent), c)
return nil
})
},
onSource: func(h func(SourceEvent, Context)) {
reg(Source, func(e Event, c Context) Result {
h(e.(SourceEvent), c)
return nil
})
},
onError: func(h func(ErrorEvent, Context)) {
reg(Error, func(e Event, c Context) Result {
h(e.(ErrorEvent), c)
return nil
})
},
onRetry: func(h func(RetryEvent, Context)) {
reg(Retry, func(e Event, c Context) Result {
h(e.(RetryEvent), c)
return nil
})
},
onPrepareStep: func(h func(PrepareStepEvent, Context) *PrepareStepResult) {
reg(PrepareStep, func(e Event, c Context) Result {
r := h(e.(PrepareStepEvent), c)
if r == nil {
return nil
}
return *r
})
},
}
// Call Init — the extension registers its handlers, tools, commands.
+92 -3
View File
@@ -1,21 +1,93 @@
package extensions
import (
"bytes"
"fmt"
"log"
"os"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"github.com/spf13/viper"
)
// ---------------------------------------------------------------------------
// reentrantMu — a per-extension mutex that allows the same goroutine to
// re-enter (e.g. handler → ctx.EmitCustomEvent → handler in same extension).
// Different goroutines are serialized, preventing concurrent state mutation.
// ---------------------------------------------------------------------------
type reentrantMu struct {
mu sync.Mutex
cond *sync.Cond
owner int64 // goroutine ID that holds the lock, or 0
depth int // re-entrancy depth
}
// initReentrantMu initializes the reentrant mutex in-place. Must be called
// after the struct is at its final memory location (not before copying).
func (r *reentrantMu) init() {
r.cond = sync.NewCond(&r.mu)
}
// lock acquires the mutex. If the calling goroutine already holds it, the
// call succeeds immediately (re-entrant). Every call to lock must be paired
// with a call to unlock.
func (r *reentrantMu) lock() {
gid := goroutineID()
r.mu.Lock()
if r.owner == gid {
// Re-entrant: same goroutine already holds the lock.
r.depth++
r.mu.Unlock()
return
}
// Wait for the current owner to release.
for r.owner != 0 {
r.cond.Wait() // releases mu, blocks, re-acquires mu on wake
}
r.owner = gid
r.depth = 1
r.mu.Unlock()
}
// unlock releases the mutex (or decrements re-entrancy depth).
func (r *reentrantMu) unlock() {
r.mu.Lock()
r.depth--
if r.depth == 0 {
r.owner = 0
r.cond.Signal()
}
r.mu.Unlock()
}
// goroutineID extracts the current goroutine's ID from runtime.Stack output.
// This is a well-known technique used by Go testing infrastructure.
func goroutineID() int64 {
var buf [64]byte
n := runtime.Stack(buf[:], false)
// Stack output starts with "goroutine NNN ["
s := buf[:n]
s = s[len("goroutine "):]
s = s[:bytes.IndexByte(s, ' ')]
id, _ := strconv.ParseInt(string(s), 10, 64)
return id
}
// Runner manages loaded extensions and dispatches events to their handlers
// sequentially. Handlers execute in extension
// load order; for cancellable events the first blocking result wins.
//
// Each extension has a dedicated reentrant mutex so that handlers for the
// same extension are serialized (preventing data races on shared package-level
// state), while handlers for different extensions may execute concurrently.
type Runner struct {
extensions []LoadedExtension
extMu []reentrantMu // per-extension reentrant mutex, indexed by extension position
ctx Context
widgets map[string]WidgetConfig // keyed by widget ID
statusEntries map[string]StatusBarEntry // keyed by status key
@@ -52,7 +124,11 @@ type LoadedExtension struct {
// NewRunner creates a Runner from a set of loaded extensions.
func NewRunner(exts []LoadedExtension) *Runner {
return &Runner{extensions: exts}
mus := make([]reentrantMu, len(exts))
for i := range mus {
mus[i].init()
}
return &Runner{extensions: exts, extMu: mus}
}
// SetContext updates the runtime context (session ID, model, etc.) that is
@@ -367,6 +443,11 @@ func (r *Runner) Emit(event Event) (Result, error) {
for i := range r.extensions {
ext := &r.extensions[i]
handlers := ext.Handlers[event.Type()]
if len(handlers) == 0 {
continue
}
r.extMu[i].lock()
for _, handler := range handlers {
result, err := safeCall(handler, event, ctx)
if err != nil {
@@ -379,6 +460,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
// Check for blocking/short-circuit results.
if isBlocking(result) {
r.extMu[i].unlock()
return result, nil
}
@@ -386,6 +468,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
// the caller is responsible for applying the modifications.
accumulated = result
}
r.extMu[i].unlock()
}
return accumulated, nil
}
@@ -712,11 +795,17 @@ func (r *Runner) EmitCustomEvent(name, data string) {
// Extension-registered handlers first (in load order).
for i := range r.extensions {
for _, h := range r.extensions[i].CustomEventHandlers[name] {
extHandlers := r.extensions[i].CustomEventHandlers[name]
if len(extHandlers) == 0 {
continue
}
r.extMu[i].lock()
for _, h := range extHandlers {
safeInvoke(h)
}
r.extMu[i].unlock()
}
// Then dynamic subscriptions.
// Then dynamic subscriptions (not extension-scoped, no per-ext lock).
for _, h := range dynamicHandlers {
safeInvoke(h)
}
+140
View File
@@ -1,6 +1,7 @@
package extensions
import (
"sync"
"testing"
)
@@ -571,3 +572,142 @@ func TestRunner_ContextPrintNilSafe(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestRunner_ConcurrentEmitSameExtension(t *testing.T) {
// Verify that concurrent Emit calls for the same extension are serialized
// and don't cause data races on shared handler state.
var counter int
ext := makeHandlerExt("shared-state.go", map[EventType][]HandlerFunc{
SubagentStart: {
func(e Event, c Context) Result {
// Read-modify-write: racy without serialization.
v := counter
counter = v + 1
return nil
},
},
SubagentChunk: {
func(e Event, c Context) Result {
v := counter
counter = v + 1
return nil
},
},
})
r := makeRunner(ext)
var wg sync.WaitGroup
const goroutines = 20
const iterations = 50
wg.Add(goroutines)
for range goroutines {
go func() {
defer wg.Done()
for range iterations {
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
_, _ = r.Emit(SubagentChunkEvent{ToolCallID: "x"})
}
}()
}
wg.Wait()
if counter != goroutines*iterations*2 {
t.Errorf("expected counter=%d, got %d (race detected)", goroutines*iterations*2, counter)
}
}
func TestRunner_ConcurrentEmitDifferentExtensions(t *testing.T) {
// Two extensions with independent state should not block each other
// and should both run correctly under concurrent Emit calls.
var counter1, counter2 int
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
SubagentStart: {
func(e Event, c Context) Result {
v := counter1
counter1 = v + 1
return nil
},
},
})
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
SubagentStart: {
func(e Event, c Context) Result {
v := counter2
counter2 = v + 1
return nil
},
},
})
r := makeRunner(ext1, ext2)
var wg sync.WaitGroup
const goroutines = 20
const iterations = 50
wg.Add(goroutines)
for range goroutines {
go func() {
defer wg.Done()
for range iterations {
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
}
}()
}
wg.Wait()
expected := goroutines * iterations
if counter1 != expected {
t.Errorf("ext1 counter: expected %d, got %d", expected, counter1)
}
if counter2 != expected {
t.Errorf("ext2 counter: expected %d, got %d", expected, counter2)
}
}
func TestRunner_ReentrantEmitCustomEvent(t *testing.T) {
// Verify that a handler can call EmitCustomEvent (which dispatches to
// the same extension's custom event handlers) without deadlocking.
var order []string
ext := LoadedExtension{
Path: "reentrant.go",
Handlers: map[EventType][]HandlerFunc{
SessionStart: {
func(e Event, c Context) Result {
order = append(order, "session_start")
// This triggers EmitCustomEvent for the same extension
// via a direct runner call (simulating ctx.EmitCustomEvent).
return nil
},
},
},
CustomEventHandlers: map[string][]func(string){
"test-event": {
func(data string) {
order = append(order, "custom:"+data)
},
},
},
}
r := makeRunner(ext)
// Wire up the handler to call EmitCustomEvent re-entrantly.
ext.Handlers[SessionStart] = []HandlerFunc{
func(e Event, c Context) Result {
order = append(order, "session_start")
r.EmitCustomEvent("test-event", "hello")
return nil
},
}
r.extensions[0] = ext
// Rebuild mutexes after modifying extensions slice.
r.extMu = make([]reentrantMu, len(r.extensions))
for i := range r.extMu {
r.extMu[i].init()
}
_, err := r.Emit(SessionStartEvent{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(order) != 2 || order[0] != "session_start" || order[1] != "custom:hello" {
t.Errorf("expected [session_start, custom:hello], got %v", order)
}
}
+14
View File
@@ -152,6 +152,9 @@ func Symbols() interp.Exports {
// Event structs
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
"ToolCallInputStartEvent": reflect.ValueOf((*ToolCallInputStartEvent)(nil)),
"ToolCallInputDeltaEvent": reflect.ValueOf((*ToolCallInputDeltaEvent)(nil)),
"ToolCallInputEndEvent": reflect.ValueOf((*ToolCallInputEndEvent)(nil)),
"ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)),
"ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)),
"ToolOutputEvent": reflect.ValueOf((*ToolOutputEvent)(nil)),
@@ -169,6 +172,17 @@ func Symbols() interp.Exports {
"SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)),
"SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)),
"ModelChangeEvent": reflect.ValueOf((*ModelChangeEvent)(nil)),
// Step lifecycle events
"StepStartEvent": reflect.ValueOf((*StepStartEvent)(nil)),
"StepFinishEvent": reflect.ValueOf((*StepFinishEvent)(nil)),
"ReasoningStartEvent": reflect.ValueOf((*ReasoningStartEvent)(nil)),
"WarningsEvent": reflect.ValueOf((*WarningsEvent)(nil)),
"SourceEvent": reflect.ValueOf((*SourceEvent)(nil)),
"ErrorEvent": reflect.ValueOf((*ErrorEvent)(nil)),
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
},
}
}
+8 -10
View File
@@ -28,11 +28,11 @@ func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantas
return wrapped
}
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
// into fantasy.AgentTool implementations so the LLM can invoke them.
// ExtensionToolsAsLLMTools converts ToolDef values registered by extensions
// into LLM agent tool implementations so the LLM can invoke them.
// The runner is optional; if provided, ToolContext.OnProgress routes
// progress messages through the runner's Print function.
func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
tools := make([]fantasy.AgentTool, 0, len(defs))
for _, def := range defs {
tools = append(tools, &extensionTool{def: def, runner: runner})
@@ -90,8 +90,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
// 0. Check if tool is disabled via SetActiveTools.
if w.runner.IsToolDisabled(toolName) {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("Error: tool %q is currently disabled", toolName)),
fmt.Errorf("tool %q disabled by extension", toolName)
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
}
kind := toolKindFor(toolName)
@@ -111,8 +110,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
if reason == "" {
reason = "blocked by extension"
}
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
fmt.Errorf("tool blocked by extension: %s", reason)
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)), nil
}
}
@@ -154,7 +152,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
}
// ---------------------------------------------------------------------------
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
// extensionTool — wraps a ToolDef into an LLM agent tool
// ---------------------------------------------------------------------------
type extensionTool struct {
@@ -182,7 +180,7 @@ func (t *extensionTool) Info() fantasy.ToolInfo {
info.Parameters = props
} else {
// Schema doesn't have "properties" — use as-is (may be
// a flat property map already matching fantasy's format).
// a flat property map already matching the expected format).
info.Parameters = schema
}
// Extract required fields if present.
@@ -238,7 +236,7 @@ func (t *extensionTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy
}
if err != nil {
return fantasy.NewTextErrorResponse(err.Error()), err
return fantasy.NewTextErrorResponse(err.Error()), nil
}
return fantasy.NewTextResponse(result), nil
}
+12 -12
View File
@@ -142,8 +142,8 @@ func TestWrappedTool_BlockExecution(t *testing.T) {
if toolRan {
t.Error("tool should not have run after block")
}
if err == nil {
t.Error("expected error from blocked tool")
if err != nil {
t.Error("expected nil error for blocked tool (error is conveyed via IsError response)")
}
if resp.IsError != true {
t.Error("expected IsError=true from blocked response")
@@ -192,7 +192,7 @@ func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
}
}
func TestExtensionToolsAsFantasy(t *testing.T) {
func TestExtensionToolsAsLLMTools(t *testing.T) {
defs := []ToolDef{
{
Name: "greet",
@@ -202,7 +202,7 @@ func TestExtensionToolsAsFantasy(t *testing.T) {
},
}
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
@@ -232,10 +232,10 @@ func TestExtensionTool_Error(t *testing.T) {
},
}
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
if err == nil {
t.Error("expected error")
if err != nil {
t.Error("expected nil error (error is conveyed via IsError response)")
}
if !resp.IsError {
t.Error("expected IsError=true")
@@ -259,7 +259,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
}
// Without runner, OnProgress is a no-op.
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "test"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
@@ -285,7 +285,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
},
},
}
tools2 := ExtensionToolsAsFantasy(defs2, runner)
tools2 := ExtensionToolsAsLLMTools(defs2, runner)
_, err = tools2[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
if err != nil {
t.Fatalf("unexpected error: %v", err)
@@ -306,7 +306,7 @@ func TestExtensionTool_ExecuteWithContextPriority(t *testing.T) {
},
},
}
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
if err != nil {
t.Fatalf("unexpected error: %v", err)
@@ -330,7 +330,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
},
},
}
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
_, _ = tools[0].Run(ctx, fantasy.ToolCall{Input: ""})
if !sawCancelled {
t.Error("expected IsCancelled=true for cancelled context")
@@ -339,7 +339,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
func TestExtensionTool_ProviderOptions(t *testing.T) {
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
tools := ExtensionToolsAsFantasy(defs, nil)
tools := ExtensionToolsAsLLMTools(defs, nil)
// Initially nil.
opts := tools[0].ProviderOptions()
+248
View File
@@ -0,0 +1,248 @@
// Package fences provides utilities for detecting markdown code regions
// (fenced code blocks and inline code spans) and applying transformations
// only to text outside those regions.
//
// This prevents special tokens like $1, $@, or @file from being interpreted
// when they appear inside ``` fences, ~~~ fences, or `inline` code spans.
package fences
import "strings"
// Ranges returns byte ranges [start, end) of fenced code blocks in content.
// Recognises both backtick (```) and tilde (~~~) fences, with optional
// leading indentation (up to 3 spaces) and optional info strings.
// An unclosed fence extends to the end of content.
func Ranges(content string) [][2]int {
var result [][2]int
var inFence bool
var fenceChar byte
var fenceCount int
var fenceStart int
pos := 0
for pos < len(content) {
// Find the end of the current line.
lineEnd := strings.IndexByte(content[pos:], '\n')
var line string
var nextPos int
if lineEnd < 0 {
line = content[pos:]
nextPos = len(content)
} else {
line = content[pos : pos+lineEnd]
nextPos = pos + lineEnd + 1
}
trimmed := strings.TrimLeft(line, " ")
indent := len(line) - len(trimmed)
if !inFence {
if indent <= 3 {
if ch, n := parseFenceOpen(trimmed); n > 0 {
inFence = true
fenceChar = ch
fenceCount = n
fenceStart = pos
}
}
} else {
if indent <= 3 && isFenceClose(trimmed, fenceChar, fenceCount) {
result = append(result, [2]int{fenceStart, nextPos})
inFence = false
}
}
pos = nextPos
}
// Unclosed fence extends to end of content.
if inFence {
result = append(result, [2]int{fenceStart, len(content)})
}
return result
}
// ReplaceOutside applies fn to each text segment that is outside fenced code
// blocks and inline code spans, leaving code content unchanged. This is the
// primary entry point for callers that need to do regex replacement only on
// non-code text.
func ReplaceOutside(content string, fn func(string) string) string {
ranges := Ranges(content)
if len(ranges) == 0 {
return replaceOutsideInline(content, fn)
}
var b strings.Builder
b.Grow(len(content))
pos := 0
for _, r := range ranges {
if pos < r[0] {
// Within non-fenced segments, also skip inline code spans.
b.WriteString(replaceOutsideInline(content[pos:r[0]], fn))
}
// Preserve fenced content verbatim.
b.WriteString(content[r[0]:r[1]])
pos = r[1]
}
if pos < len(content) {
b.WriteString(replaceOutsideInline(content[pos:], fn))
}
return b.String()
}
// StripCode returns content with fenced code blocks and inline code spans
// removed. Useful for detection/matching where only non-code text matters.
func StripCode(content string) string {
// First strip fenced blocks.
stripped := StripFenced(content)
// Then strip inline code spans from what remains.
return stripInlineCode(stripped)
}
// StripFenced returns content with fenced code block regions removed.
// Useful for detection/matching where only non-fenced text matters.
// NOTE: this does NOT strip inline code spans; use StripCode for both.
func StripFenced(content string) string {
ranges := Ranges(content)
if len(ranges) == 0 {
return content
}
var b strings.Builder
b.Grow(len(content))
pos := 0
for _, r := range ranges {
b.WriteString(content[pos:r[0]])
pos = r[1]
}
b.WriteString(content[pos:])
return b.String()
}
// parseFenceOpen checks whether trimmed (leading spaces already removed)
// starts a fenced code block. Returns the fence character and count, or
// (0, 0) if it is not a fence opener.
func parseFenceOpen(trimmed string) (byte, int) {
if len(trimmed) == 0 {
return 0, 0
}
ch := trimmed[0]
if ch != '`' && ch != '~' {
return 0, 0
}
count := 0
for count < len(trimmed) && trimmed[count] == ch {
count++
}
if count < 3 {
return 0, 0
}
// Per CommonMark: backtick fences cannot have backticks in the info string.
if ch == '`' && strings.ContainsRune(trimmed[count:], '`') {
return 0, 0
}
return ch, count
}
// isFenceClose checks whether trimmed is a closing fence matching fenceChar
// with at least minCount characters. A closing fence line contains only the
// fence characters and optional trailing spaces.
func isFenceClose(trimmed string, fenceChar byte, minCount int) bool {
if len(trimmed) == 0 || trimmed[0] != fenceChar {
return false
}
count := 0
for count < len(trimmed) && trimmed[count] == fenceChar {
count++
}
if count < minCount {
return false
}
// Closing fence must contain only fence chars (and optional trailing spaces).
return strings.TrimRight(trimmed[count:], " ") == ""
}
// --------------------------------------------------------------------------
// Inline code span handling
// --------------------------------------------------------------------------
// inlineCodeRanges returns byte ranges [start, end) of inline code spans
// in segment. Per CommonMark, a code span opens with N backticks and closes
// with exactly N backticks.
func inlineCodeRanges(s string) [][2]int {
var result [][2]int
i := 0
for i < len(s) {
if s[i] != '`' {
i++
continue
}
// Count opening backticks.
start := i
n := 0
for i < len(s) && s[i] == '`' {
n++
i++
}
// Scan for a closing run of exactly n backticks.
for j := i; j < len(s); {
if s[j] != '`' {
j++
continue
}
m := 0
for j < len(s) && s[j] == '`' {
m++
j++
}
if m == n {
result = append(result, [2]int{start, j})
i = j
break
}
}
// If no closing run was found, i is already past the opening
// backticks so the outer loop advances naturally.
}
return result
}
// replaceOutsideInline applies fn only to text outside inline code spans.
func replaceOutsideInline(segment string, fn func(string) string) string {
ranges := inlineCodeRanges(segment)
if len(ranges) == 0 {
return fn(segment)
}
var b strings.Builder
b.Grow(len(segment))
pos := 0
for _, r := range ranges {
if pos < r[0] {
b.WriteString(fn(segment[pos:r[0]]))
}
b.WriteString(segment[r[0]:r[1]])
pos = r[1]
}
if pos < len(segment) {
b.WriteString(fn(segment[pos:]))
}
return b.String()
}
// stripInlineCode removes inline code spans from s.
func stripInlineCode(s string) string {
ranges := inlineCodeRanges(s)
if len(ranges) == 0 {
return s
}
var b strings.Builder
b.Grow(len(s))
pos := 0
for _, r := range ranges {
b.WriteString(s[pos:r[0]])
pos = r[1]
}
b.WriteString(s[pos:])
return b.String()
}
+313
View File
@@ -0,0 +1,313 @@
package fences
import (
"testing"
)
func TestRanges(t *testing.T) {
tests := []struct {
name string
content string
want [][2]int
}{
{
name: "no fences",
content: "hello world\nno code here",
want: nil,
},
{
name: "single backtick fence",
content: "before\n```\ncode\n```\nafter",
want: [][2]int{{7, 20}},
},
{
name: "single tilde fence",
content: "before\n~~~\ncode\n~~~\nafter",
want: [][2]int{{7, 20}},
},
{
name: "fence with info string",
content: "before\n```go\ncode\n```\nafter",
want: [][2]int{{7, 22}},
},
{
name: "multiple fences",
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
want: [][2]int{{2, 12}, {14, 24}},
},
{
name: "unclosed fence",
content: "before\n```\ncode\nmore code",
want: [][2]int{{7, 25}},
},
{
name: "longer closing fence",
content: "before\n```\ncode\n`````\nafter",
want: [][2]int{{7, 22}},
},
{
name: "shorter closing fence ignored",
content: "before\n`````\ncode\n```\nmore\n`````\nafter",
want: [][2]int{{7, 33}},
},
{
name: "indented fence up to 3 spaces",
content: "before\n ```\ncode\n ```\nafter",
want: [][2]int{{7, 26}},
},
{
name: "4 space indent is not a fence",
content: "before\n ```\ncode\n ```\nafter",
want: nil,
},
{
name: "backtick in info string rejects open",
// The ```foo`bar line is not a valid opener (backtick in info).
// The standalone ``` becomes an opener with no close.
content: "before\n```foo`bar\ncode\n```\nafter",
want: [][2]int{{23, 32}},
},
{
name: "empty content",
content: "",
want: nil,
},
{
name: "fence only",
content: "```\ncode\n```",
want: [][2]int{{0, 12}},
},
{
name: "fence at end without trailing newline",
content: "```\ncode\n```",
want: [][2]int{{0, 12}},
},
{
name: "tilde fence does not close with backticks",
content: "~~~\ncode\n```\nmore\n~~~\nafter",
want: [][2]int{{0, 22}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := Ranges(tt.content)
if len(got) != len(tt.want) {
t.Fatalf("Ranges() = %v, want %v", got, tt.want)
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("Ranges()[%d] = %v, want %v", i, got[i], tt.want[i])
}
}
})
}
}
func TestReplaceOutside(t *testing.T) {
upper := func(s string) string {
b := []byte(s)
for i, c := range b {
if c >= 'a' && c <= 'z' {
b[i] = c - 32
}
}
return string(b)
}
tests := []struct {
name string
content string
want string
}{
{
name: "no fences",
content: "hello world",
want: "HELLO WORLD",
},
{
name: "text around fence",
content: "before\n```\ncode\n```\nafter",
want: "BEFORE\n```\ncode\n```\nAFTER",
},
{
name: "multiple fences",
content: "aaa\n```\nxxx\n```\nbbb\n~~~\nyyy\n~~~\nccc",
want: "AAA\n```\nxxx\n```\nBBB\n~~~\nyyy\n~~~\nCCC",
},
{
name: "unclosed fence preserves code",
content: "before\n```\ncode",
want: "BEFORE\n```\ncode",
},
{
name: "only fenced content",
content: "```\ncode\n```",
want: "```\ncode\n```",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ReplaceOutside(tt.content, upper)
if got != tt.want {
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
}
})
}
}
func TestStripFenced(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "no fences",
content: "hello $1 world",
want: "hello $1 world",
},
{
name: "strips fenced code",
content: "before $1\n```\n$2 inside\n```\nafter $3",
want: "before $1\nafter $3",
},
{
name: "multiple fences",
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
want: "a\nb\nc",
},
{
name: "unclosed fence",
content: "before\n```\n$1 inside",
want: "before\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := StripFenced(tt.content)
if got != tt.want {
t.Errorf("StripFenced() = %q, want %q", got, tt.want)
}
})
}
}
func TestInlineCodeRanges(t *testing.T) {
tests := []struct {
name string
s string
want [][2]int
}{
{"no backticks", "hello world", nil},
{"single backtick span", "use `$1` here", [][2]int{{4, 8}}},
{"double backtick span", "use ``$1`` here", [][2]int{{4, 10}}},
{"multiple spans", "`$1` and `$2`", [][2]int{{0, 4}, {9, 13}}},
{"unmatched backtick", "use `$1 here", nil},
{"mismatched backtick counts", "use ``$1` here", nil},
{"empty inline content", "use `` `` here", [][2]int{{4, 9}}},
{"backticks inside double", "use ``foo`bar`` here", [][2]int{{4, 15}}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := inlineCodeRanges(tt.s)
if len(got) != len(tt.want) {
t.Fatalf("inlineCodeRanges() = %v, want %v", got, tt.want)
}
for i := range got {
if got[i] != tt.want[i] {
t.Errorf("inlineCodeRanges()[%d] = %v, want %v", i, got[i], tt.want[i])
}
}
})
}
}
func TestReplaceOutside_InlineCode(t *testing.T) {
upper := func(s string) string {
b := []byte(s)
for i, c := range b {
if c >= 'a' && c <= 'z' {
b[i] = c - 32
}
}
return string(b)
}
tests := []struct {
name string
content string
want string
}{
{
name: "inline code preserved",
content: "use `code` here",
want: "USE `code` HERE",
},
{
name: "double backtick inline code",
content: "use ``co`de`` here",
want: "USE ``co`de`` HERE",
},
{
name: "mixed fenced and inline",
content: "before `x` mid\n```\nfenced\n```\nafter `y` end",
want: "BEFORE `x` MID\n```\nfenced\n```\nAFTER `y` END",
},
{
name: "only inline code",
content: "`code`",
want: "`code`",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ReplaceOutside(tt.content, upper)
if got != tt.want {
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
}
})
}
}
func TestStripCode(t *testing.T) {
tests := []struct {
name string
content string
want string
}{
{
name: "no code",
content: "hello $1 world",
want: "hello $1 world",
},
{
name: "strips inline code",
content: "use `$1` and `$2` for positional args",
want: "use and for positional args",
},
{
name: "strips fenced and inline",
content: "before `$1`\n```\n$2 inside\n```\nafter",
want: "before \nafter",
},
{
name: "real world prompt template",
content: "Use $@ for all args.\n`$1`, `$2` for positional.\n```bash\necho $1\n```\n",
want: "Use $@ for all args.\n, for positional.\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := StripCode(tt.content)
if got != tt.want {
t.Errorf("StripCode() = %q, want %q", got, tt.want)
}
})
}
}
+5 -1
View File
@@ -72,6 +72,9 @@ type AgentSetupOptions struct {
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
// loading (successfully or with error). Called from the background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// MCPTaskConfig configures task-augmented tools/call execution. The
// zero value preserves historical synchronous-only behaviour.
MCPTaskConfig tools.MCPTaskConfig
}
// AgentSetupResult bundles the created agent and any debug logger so the caller
@@ -229,6 +232,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
ToolWrapper: toolWrapper,
ExtraTools: extraTools,
OnMCPServerLoaded: opts.OnMCPServerLoaded,
MCPTaskConfig: opts.MCPTaskConfig,
})
if err != nil {
return nil, fmt.Errorf("failed to create agent: %w", err)
@@ -267,7 +271,7 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
return extensions.WrapToolsWithExtensions(tools, runner)
}
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools(), runner)
extTools := extensions.ExtensionToolsAsLLMTools(runner.RegisteredTools(), runner)
return runner, extensionCreationOpts{
toolWrapper: wrapper,
-13
View File
@@ -325,12 +325,6 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
// mixed TextPart and ToolCallPart content. Tool-role messages produce
// ToolResultPart entries.
func (m *Message) ToLLMMessages() []fantasy.Message {
return m.ToFantasyMessages()
}
// Deprecated: Use ToLLMMessages instead.
// ToFantasyMessages converts a Message to one or more LLM message values.
func (m *Message) ToFantasyMessages() []fantasy.Message {
switch m.Role {
case RoleAssistant:
var parts []fantasy.MessagePart
@@ -431,13 +425,6 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
// FromLLMMessage converts an LLM message into our Message type,
// extracting all content parts into the appropriate block types.
func FromLLMMessage(msg fantasy.Message) Message {
return FromFantasyMessage(msg)
}
// Deprecated: Use FromLLMMessage instead.
// FromFantasyMessage converts an LLM message into our Message type,
// extracting all content parts into the appropriate block types.
func FromFantasyMessage(msg fantasy.Message) Message {
m := Message{
Role: MessageRole(msg.Role),
Parts: make([]ContentPart, 0),
File diff suppressed because one or more lines are too long
+120 -17
View File
@@ -25,7 +25,6 @@ import (
openaisdk "github.com/charmbracelet/openai-go"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/ui/progress"
)
const (
@@ -86,6 +85,7 @@ type ThinkingLevel string
const (
ThinkingOff ThinkingLevel = "off"
ThinkingNone ThinkingLevel = "none"
ThinkingMinimal ThinkingLevel = "minimal"
ThinkingLow ThinkingLevel = "low"
ThinkingMedium ThinkingLevel = "medium"
@@ -94,12 +94,14 @@ const (
// ThinkingLevels returns the ordered list of available thinking levels for cycling.
func ThinkingLevels() []ThinkingLevel {
return []ThinkingLevel{ThinkingOff, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh}
return []ThinkingLevel{ThinkingOff, ThinkingNone, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh}
}
// thinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off".
// thinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off" or "none".
func thinkingBudgetTokens(level ThinkingLevel) int64 {
switch level {
case ThinkingNone:
return 1024
case ThinkingMinimal:
return 1024
case ThinkingLow:
@@ -118,6 +120,8 @@ func ThinkingLevelDescription(level ThinkingLevel) string {
switch level {
case ThinkingOff:
return "No reasoning"
case ThinkingNone:
return "Minimal reasoning (OpenAI 'none')"
case ThinkingMinimal:
return "Very brief reasoning (~1k tokens)"
case ThinkingLow:
@@ -134,7 +138,7 @@ func ThinkingLevelDescription(level ThinkingLevel) string {
// ParseThinkingLevel converts a string to a ThinkingLevel, defaulting to ThinkingOff.
func ParseThinkingLevel(s string) ThinkingLevel {
switch ThinkingLevel(s) {
case ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh:
case ThinkingNone, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh:
return ThinkingLevel(s)
default:
return ThinkingOff
@@ -159,6 +163,12 @@ type ProviderConfig struct {
TLSSkipVerify bool
ThinkingLevel ThinkingLevel
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
// for long operations like Ollama model pulls. The returned io.ReadCloser
// must be closed when done. When nil, the raw reader is consumed directly
// with no progress UI.
ProgressReaderFunc func(io.Reader) io.ReadCloser
}
// ProviderResult contains the result of provider creation.
@@ -246,6 +256,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
// via CLI flag or global config.
ApplyModelSettings(config, modelInfo)
// Auto-raise MaxTokens toward the model's known output ceiling when the
// user hasn't explicitly set --max-tokens and no per-model override
// applied. Runs after ApplyModelSettings so explicit modelSettings win.
rightSizeMaxTokens(config, modelInfo)
// Create the base provider
var result *ProviderResult
var createErr error
@@ -290,9 +305,18 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
// Only add cache options for providers that don't already have
// options set, to avoid type conflicts (e.g., Anthropic has
// different types for regular options vs cache control options).
for k, v := range cacheOpts {
if _, exists := result.ProviderOptions[k]; !exists {
result.ProviderOptions[k] = v
//
// For OpenAI Responses API models, we skip merging entirely because
// ResponsesProviderOptions and ProviderOptions are incompatible types.
skipMerge := false
if provider == "openai" && openai.IsResponsesModel(modelName) {
skipMerge = true
}
if !skipMerge {
for k, v := range cacheOpts {
if _, exists := result.ProviderOptions[k]; !exists {
result.ProviderOptions[k] = v
}
}
}
}
@@ -484,6 +508,37 @@ func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) {
}
}
// defaultRightSizeCap bounds auto-raised MaxTokens so that we don't silently
// allocate enormous output budgets for models with very high ceilings (e.g.
// Devstral at 262144, Mistral at 128000). Users who genuinely want more can
// pass --max-tokens explicitly or set modelSettings[...].maxTokens in config.
const defaultRightSizeCap = 32768
// rightSizeMaxTokens raises config.MaxTokens toward the model's known output
// ceiling when:
// - the user has not explicitly set --max-tokens (or the KIT_MAX_TOKENS env
// var, or the top-level max-tokens key in config.yaml), AND
// - no per-model override already bumped MaxTokens (ApplyModelSettings runs
// before this function), AND
// - modelInfo.Limit.Output is known and larger than the current MaxTokens.
//
// The raised value is capped at defaultRightSizeCap to keep accidental
// allocations reasonable on very-large-output models. This prevents the
// common "ghost" where the agent's reply is silently truncated at the 8192
// default even though the selected model supports 64k or 262k output tokens.
func rightSizeMaxTokens(config *ProviderConfig, modelInfo *ModelInfo) {
if modelInfo == nil || modelInfo.Limit.Output <= 0 {
return
}
if isExplicitlySet("max-tokens") {
return
}
target := min(modelInfo.Limit.Output, defaultRightSizeCap)
if config.MaxTokens < target {
config.MaxTokens = target
}
}
// clearConflictingAnthropicSamplingParams ensures that temperature and top_p are
// not both sent to the Anthropic API, which rejects requests containing both.
// When both are set (typically from defaults), top_p is cleared so that
@@ -530,6 +585,8 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
// Returns nil for ThinkingOff (use the model's default).
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
switch level {
case ThinkingNone:
return new(openai.ReasoningEffortNone)
case ThinkingMinimal:
return new(openai.ReasoningEffortMinimal)
case ThinkingLow:
@@ -543,6 +600,56 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
}
}
// IsValidThinkingLevelForModel checks if a thinking level is valid for the given
// model. Some OpenAI models like gpt-5.4 don't support "minimal" and require
// "none" instead.
func IsValidThinkingLevelForModel(level ThinkingLevel, modelName string) bool {
if level == ThinkingOff {
return true
}
// Check if this is an OpenAI model that doesn't support "minimal"
// gpt-5.4 and newer gpt-5.x models use "none" instead of "minimal"
if level == ThinkingMinimal {
if strings.Contains(modelName, "gpt-5.4") ||
strings.Contains(modelName, "gpt-5-pro") ||
strings.Contains(modelName, "gpt-5-chat") {
return false
}
}
// Check if this is an OpenAI model that doesn't support "none"
// Older gpt-5 models only support "minimal", not "none"
if level == ThinkingNone {
if strings.Contains(modelName, "gpt-5") &&
!strings.Contains(modelName, "gpt-5.4") &&
!strings.Contains(modelName, "gpt-5-pro") &&
!strings.Contains(modelName, "gpt-5-chat") {
// Older gpt-5 models might not support "none"
// They only added "none" support in newer versions
return false
}
}
// All other levels are generally valid for reasoning models
return true
}
// SuggestThinkingLevelFallback returns a recommended fallback level when the
// requested level is not valid for the model. Returns ThinkingOff if no
// suitable fallback exists.
func SuggestThinkingLevelFallback(level ThinkingLevel, modelName string) ThinkingLevel {
if level == ThinkingMinimal && !IsValidThinkingLevelForModel(level, modelName) {
// For models that don't support "minimal", suggest "none" (~same token budget)
return ThinkingNone
}
if level == ThinkingNone && !IsValidThinkingLevelForModel(level, modelName) {
// For models that don't support "none", suggest "minimal" (~same token budget)
return ThinkingMinimal
}
return ThinkingOff
}
// buildAnthropicProviderOptions returns fantasy.ProviderOptions configured for
// Anthropic models with extended thinking. When thinking is enabled, it sets
// SendReasoning to true and configures the thinking budget. For thinking-off
@@ -1128,7 +1235,7 @@ func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string,
// Phase 1: Check if model exists locally
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
// Phase 2: Pull model if not found
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
if err := pullOllamaModel(ctx, client, baseURL, modelName, config.ProgressReaderFunc); err != nil {
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
}
}
@@ -1217,11 +1324,7 @@ func checkOllamaModelExists(client *http.Client, baseURL, modelName string) erro
return nil
}
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
}
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string, progressFn func(io.Reader) io.ReadCloser) error {
reqBody := map[string]string{"name": modelName}
jsonBody, _ := json.Marshal(reqBody)
@@ -1245,10 +1348,10 @@ func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseU
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
}
if showProgress {
progressReader := progress.NewProgressReader(resp.Body)
defer func() { _ = progressReader.Close() }()
_, err = io.ReadAll(progressReader)
if progressFn != nil {
pr := progressFn(resp.Body)
defer func() { _ = pr.Close() }()
_, err = io.ReadAll(pr)
} else {
_, err = io.ReadAll(resp.Body)
}
+24 -11
View File
@@ -4,6 +4,7 @@ import (
_ "embed"
"encoding/json"
"fmt"
"maps"
"os"
"strings"
@@ -111,13 +112,30 @@ func NewModelsRegistry() *ModelsRegistry {
}
// buildFromModelsDB converts models.dev provider data into our internal format.
// It tries the on-disk cache first and falls back to the embedded database.
// It starts from the compile-time embedded database and merges on-disk cached
// data from `kit update-models` on top. Cached provider metadata replaces
// embedded metadata, and model entries are merged with cached models taking
// precedence. This means newly synced models are available while embedded
// models that haven't been synced yet are still reachable.
func buildFromModelsDB() map[string]ProviderInfo {
// Try cached data first (from `kit update-models`)
dbProviders, _ := LoadCachedProviders()
if len(dbProviders) == 0 {
// Fall back to compile-time embedded data
dbProviders = loadEmbeddedProviders()
// Start with compile-time embedded data as the base.
dbProviders := loadEmbeddedProviders()
if dbProviders == nil {
dbProviders = make(ModelsDBProviders)
}
// Merge on-disk cached data on top (cached takes precedence).
if cached, _ := LoadCachedProviders(); len(cached) > 0 {
for providerID, cp := range cached {
if existing, ok := dbProviders[providerID]; ok {
// Merge models: embedded base + cached overrides.
mergedModels := make(map[string]modelsDBModel, len(existing.Models)+len(cp.Models))
maps.Copy(mergedModels, existing.Models)
maps.Copy(mergedModels, cp.Models)
cp.Models = mergedModels
}
dbProviders[providerID] = cp
}
}
providers := make(map[string]ProviderInfo, len(dbProviders))
@@ -379,11 +397,6 @@ func (r *ModelsRegistry) GetLLMProviders() []string {
return providers
}
// Deprecated: Use GetLLMProviders instead.
func (r *ModelsRegistry) GetFantasyProviders() []string {
return r.GetLLMProviders()
}
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
// Ollama and custom are always supported (model names are user-defined).
+148
View File
@@ -0,0 +1,148 @@
package models
import (
"testing"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
// bindMaxTokensFlag wires a fresh pflag-backed "max-tokens" key into viper so
// isExplicitlySet behaves the same way it does in production. Returns a
// cleanup function that removes the binding so sibling tests see a clean
// state.
func bindMaxTokensFlag(t *testing.T, args []string) func() {
t.Helper()
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
fs.Int("max-tokens", 8192, "")
if err := viper.BindPFlag("max-tokens", fs.Lookup("max-tokens")); err != nil {
t.Fatalf("BindPFlag: %v", err)
}
if err := fs.Parse(args); err != nil {
t.Fatalf("fs.Parse: %v", err)
}
return func() {
viper.Reset()
}
}
func TestRightSizeMaxTokens_RaisesWhenBelowCeiling(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil) // no args → flag.Changed = false
defer cleanup()
config := &ProviderConfig{MaxTokens: 8192}
modelInfo := &ModelInfo{
ID: "claude-sonnet-4-5",
Limit: Limit{Context: 200000, Output: 64000},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != 32768 {
t.Errorf("expected MaxTokens raised to defaultRightSizeCap (32768), got %d", config.MaxTokens)
}
}
func TestRightSizeMaxTokens_CapsAtDefaultRightSizeCap(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil)
defer cleanup()
config := &ProviderConfig{MaxTokens: 8192}
// Mistral Devstral has 262144 output — we should still cap at 32768.
modelInfo := &ModelInfo{
ID: "devstral-medium-latest",
Limit: Limit{Context: 262144, Output: 262144},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != defaultRightSizeCap {
t.Errorf("expected MaxTokens capped at %d, got %d", defaultRightSizeCap, config.MaxTokens)
}
}
func TestRightSizeMaxTokens_UsesExactOutputWhenBelowCap(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil)
defer cleanup()
config := &ProviderConfig{MaxTokens: 4096}
// Model with output limit smaller than the cap.
modelInfo := &ModelInfo{
ID: "gpt-4",
Limit: Limit{Context: 8192, Output: 8192},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != 8192 {
t.Errorf("expected MaxTokens raised to model output ceiling (8192), got %d", config.MaxTokens)
}
}
func TestRightSizeMaxTokens_DoesNotLowerCurrentValue(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil)
defer cleanup()
// User (via per-model settings, applied earlier) already bumped MaxTokens
// above the cap — we must not clobber their choice.
config := &ProviderConfig{MaxTokens: 100000}
modelInfo := &ModelInfo{
ID: "devstral-medium-latest",
Limit: Limit{Context: 262144, Output: 262144},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != 100000 {
t.Errorf("expected MaxTokens preserved at 100000, got %d", config.MaxTokens)
}
}
func TestRightSizeMaxTokens_RespectsExplicitFlag(t *testing.T) {
// Simulate `--max-tokens 4096` on the command line.
cleanup := bindMaxTokensFlag(t, []string{"--max-tokens", "4096"})
defer cleanup()
config := &ProviderConfig{MaxTokens: 4096}
modelInfo := &ModelInfo{
ID: "claude-sonnet-4-5",
Limit: Limit{Context: 200000, Output: 64000},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != 4096 {
t.Errorf("expected explicit --max-tokens to be preserved (4096), got %d", config.MaxTokens)
}
}
func TestRightSizeMaxTokens_NilModelInfo(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil)
defer cleanup()
config := &ProviderConfig{MaxTokens: 8192}
// Custom model / Ollama / unknown provider → no model info.
rightSizeMaxTokens(config, nil)
if config.MaxTokens != 8192 {
t.Errorf("expected MaxTokens unchanged with nil modelInfo, got %d", config.MaxTokens)
}
}
func TestRightSizeMaxTokens_ZeroOutputLimit(t *testing.T) {
cleanup := bindMaxTokensFlag(t, nil)
defer cleanup()
config := &ProviderConfig{MaxTokens: 8192}
// Model present in catalog but with no known output limit.
modelInfo := &ModelInfo{
ID: "unknown-model",
Limit: Limit{Context: 0, Output: 0},
}
rightSizeMaxTokens(config, modelInfo)
if config.MaxTokens != 8192 {
t.Errorf("expected MaxTokens unchanged with zero output limit, got %d", config.MaxTokens)
}
}
+59 -6
View File
@@ -7,10 +7,12 @@ import (
"regexp"
"strconv"
"strings"
"github.com/mark3labs/kit/internal/fences"
)
// PromptTemplate is a named prompt template with shell-style argument placeholders.
// It supports Pi-style $1, $2, $@, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
// It supports Pi-style $1, $2, $@, $+, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
type PromptTemplate struct {
// Name is the human-readable identifier for this template.
Name string
@@ -120,19 +122,28 @@ func ParseCommandArgs(input string) []string {
// argPlaceholder matches shell-style argument placeholders:
// - $1, $2, etc. - positional arguments
// - $@ - all arguments
// - $@ - all arguments (zero or more)
// - $+ - all arguments (one or more required)
// - $ARGUMENTS - all arguments (alias for $@)
// - ${@:N} - arguments from N onwards
// - ${@:N:L} - L arguments starting from N
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$ARGUMENTS`)
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$\+|\$ARGUMENTS`)
// SubstituteArgs replaces argument placeholders in content with values from args.
// Supported placeholders:
// - $N, ${N} - the Nth argument (1-indexed)
// - $@, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
// - $@, $+, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
// - ${@:N} - arguments from index N onwards (0-indexed)
// - ${@:N:L} - L arguments starting from index N (0-indexed)
func SubstituteArgs(content string, args []string) string {
return fences.ReplaceOutside(content, func(segment string) string {
return substituteArgsInSegment(segment, args)
})
}
// substituteArgsInSegment performs argument substitution on a single text
// segment that is known to be outside fenced code blocks.
func substituteArgsInSegment(content string, args []string) string {
return argPlaceholder.ReplaceAllStringFunc(content, func(match string) string {
// Check for ${N} or ${N:M} format
if strings.HasPrefix(match, "${") && strings.Contains(match, "}") {
@@ -191,8 +202,8 @@ func SubstituteArgs(content string, args []string) string {
if strings.HasPrefix(match, "$") && !strings.HasPrefix(match, "${") {
suffix := match[1:]
// $@ or $ARGUMENTS
if suffix == "@" || suffix == "ARGUMENTS" {
// $@, $+, or $ARGUMENTS
if suffix == "@" || suffix == "+" || suffix == "ARGUMENTS" {
return strings.Join(args, " ")
}
@@ -266,6 +277,48 @@ func joinArgsRange(args []string, start, length int) string {
return strings.Join(args[start:end], " ")
}
// HasArgPlaceholders reports whether the template content contains any
// argument placeholders ($1, $@, $ARGUMENTS, ${@:...}, etc.).
// Placeholders inside fenced code blocks and inline code spans are ignored.
func (t *PromptTemplate) HasArgPlaceholders() bool {
return argPlaceholder.MatchString(fences.StripCode(t.Content))
}
// RequiredArgs returns the number of positional arguments the template
// expects. This is determined by the highest $N or ${N} placeholder found
// in the content (1-indexed, so $2 means 2 args required). The $+
// placeholder (required variadic) ensures at least 1. Optional wildcards
// ($@, $ARGUMENTS) do not contribute to the count.
func (t *PromptTemplate) RequiredArgs() int {
content := fences.StripCode(t.Content)
maxN := 0
hasRequiredVariadic := strings.Contains(content, "$+")
for _, match := range argPlaceholder.FindAllStringSubmatch(content, -1) {
// Group 1: ${N} format — the N value.
if match[1] != "" {
if n, err := strconv.Atoi(match[1]); err == nil && n > maxN {
maxN = n
}
}
// Group 2: ${N:M} format — the N value (start index).
if match[2] != "" {
if n, err := strconv.Atoi(match[2]); err == nil && n > maxN {
maxN = n
}
}
// Group 6: $N format (no braces) — the N value.
if match[6] != "" {
if n, err := strconv.Atoi(match[6]); err == nil && n > maxN {
maxN = n
}
}
}
if hasRequiredVariadic && maxN < 1 {
maxN = 1
}
return maxN
}
// Expand substitutes arguments into the template content and returns the result.
// It first parses args from the input string, then substitutes them into the template.
func (t *PromptTemplate) Expand(argsInput string) string {
+117
View File
@@ -129,6 +129,48 @@ func TestSubstituteArgs(t *testing.T) {
args: []string{},
expected: "Args: ",
},
{
name: "$1 inside code block preserved",
content: "Use $1 here\n```bash\necho $1\n```\ndone",
args: []string{"foo"},
expected: "Use foo here\n```bash\necho $1\n```\ndone",
},
{
name: "$@ inside code block preserved",
content: "Run $@\n```\necho $@\n```\n",
args: []string{"a", "b"},
expected: "Run a b\n```\necho $@\n```\n",
},
{
name: "all placeholders inside code block",
content: "Prompt\n```\n$1 $2 $@\n```\n",
args: []string{"x"},
expected: "Prompt\n```\n$1 $2 $@\n```\n",
},
{
name: "$1 inside inline code preserved",
content: "Use `$1` here and $1 outside",
args: []string{"foo"},
expected: "Use `$1` here and foo outside",
},
{
name: "$+ required variadic",
content: "Args: $+",
args: []string{"a", "b", "c"},
expected: "Args: a b c",
},
{
name: "$+ with empty args",
content: "Args: $+",
args: []string{},
expected: "Args: ",
},
{
name: "all placeholders in inline code",
content: "Use `$1` and `$@` for args",
args: []string{"x"},
expected: "Use `$1` and `$@` for args",
},
}
for _, tt := range tests {
@@ -213,3 +255,78 @@ func TestPromptTemplateExpand(t *testing.T) {
})
}
}
func TestHasArgPlaceholders(t *testing.T) {
tests := []struct {
name string
content string
want bool
}{
{"no placeholders", "Just a plain prompt with no args", false},
{"$1 placeholder", "Create a $1 component", true},
{"$@ placeholder", "Run with args: $@", true},
{"$ARGUMENTS placeholder", "Features: $ARGUMENTS", true},
{"${1} placeholder", "Name: ${1}", true},
{"${ARGUMENTS} placeholder", "All: ${ARGUMENTS}", true},
{"${@:1} placeholder", "Rest: ${@:1}", true},
{"${@:1:2} placeholder", "Slice: ${@:1:2}", true},
{"dollar in text", "Cost is one hundred dollars", false},
{"empty content", "", false},
{"$1 inside code block only", "Prompt\n```\necho $1\n```\n", false},
{"$1 outside and inside code block", "Use $1 here\n```\necho $1\n```\n", true},
{"$@ inside code block only", "Prompt\n```bash\necho $@\n```\n", false},
{"$+ placeholder", "Run with args: $+", true},
{"$+ inside inline code only", "Use `$+` for required args", false},
{"$1 inside inline code only", "Use `$1` for positional args", false},
{"$1 outside and in inline code", "Create $1 (see `$1` syntax)", true},
{"$@ outside $1 in inline code", "Run $@ with `$1` syntax", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tpl := &PromptTemplate{Content: tt.content}
if got := tpl.HasArgPlaceholders(); got != tt.want {
t.Errorf("HasArgPlaceholders() = %v, want %v", got, tt.want)
}
})
}
}
func TestRequiredArgs(t *testing.T) {
tests := []struct {
name string
content string
want int
}{
{"no placeholders", "Just a plain prompt", 0},
{"$1 only", "Create a $1 component", 1},
{"$1 and $2", "Create $1 with $2", 2},
{"$3 skipping $2", "Use $1 and $3", 3},
{"${1} braced", "Name: ${1}", 1},
{"${2} braced", "Name: ${1} Desc: ${2}", 2},
{"$@ only", "Run with: $@", 0},
{"$ARGUMENTS only", "Features: $ARGUMENTS", 0},
{"${ARGUMENTS} only", "All: ${ARGUMENTS}", 0},
{"$1 and $@", "Create $1 with extras: $@", 1},
{"${@:1} slice only", "Rest: ${@:1}", 0},
{"${@:1:2} slice only", "Slice: ${@:1:2}", 0},
{"mixed $1 $2 and $@", "Create $1 named $2: $@", 2},
{"empty content", "", 0},
{"$2 inside code block only", "Prompt\n```\n$1 $2\n```\n", 0},
{"$1 outside $2 inside code block", "Use $1\n```\n$2 inside\n```\n", 1},
{"$+ only", "Run with: $+", 1},
{"$+ and $2", "Create $2 with: $+", 2},
{"$+ inside inline code only", "Use `$+` for required args", 0},
{"$1 and $2 in inline code only", "Use `$1` and `$2` for args", 0},
{"$1 outside $2 in inline code", "Create $1 (see `$2`)", 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tpl := &PromptTemplate{Content: tt.content}
if got := tpl.RequiredArgs(); got != tt.want {
t.Errorf("RequiredArgs() = %d, want %d", got, tt.want)
}
})
}
}
+66
View File
@@ -0,0 +1,66 @@
package session
import (
"testing"
"github.com/mark3labs/kit/internal/message"
)
// TestCompactionParentCycleRegression tests that after multiple compactions,
// newly appended messages always have a valid parent chain and BuildContext
// returns the correct messages.
func TestCompactionParentCycleRegression(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Simulate a long conversation with multiple compactions.
msg1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
msg2, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
// First compaction
comp1, _ := tm.AppendCompaction("Summary 1", msg1, 1000, 500, 1, []string{}, []string{})
msg3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
msg4, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg4"}}})
// Second compaction
comp2, _ := tm.AppendCompaction("Summary 2", msg3, 1000, 500, 1, []string{}, []string{})
msg5, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg5"}}})
msg6, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg6"}}})
// Verify parent chain integrity
for _, id := range []string{msg1, msg2, comp1, msg3, msg4, comp2, msg5, msg6} {
entry := tm.GetEntry(id)
if entry == nil {
t.Fatalf("entry %s not found in index", id)
}
}
// Walk parent chain from msg6 — must reach root without cycles
visited := make(map[string]bool)
current := msg6
for current != "" {
if visited[current] {
t.Fatalf("cycle detected at entry %s", current)
}
visited[current] = true
entry := tm.GetEntry(current)
if entry == nil {
t.Fatalf("entry %s missing from index during parent walk", current)
}
parent := ""
switch e := entry.(type) {
case *MessageEntry:
parent = e.ParentID
case *CompactionEntry:
parent = e.ParentID
}
current = parent
}
// BuildContext should return: Summary2 + msg6 + msg5 + msg3 + msg4 = 5 messages
msgs, _, _ := tm.BuildContext()
if len(msgs) != 5 {
t.Fatalf("expected 5 messages, got %d: %+v", len(msgs), msgs)
}
}
+70
View File
@@ -0,0 +1,70 @@
package session
import (
"strings"
"testing"
)
// TestEncodeCwdForDir verifies the working-directory → session-directory
// name encoding strips characters that are illegal on Windows (notably the
// drive-letter colon, see issue #18) while preserving the previous output
// for the typical Unix paths.
func TestEncodeCwdForDir(t *testing.T) {
tests := []struct {
name string
cwd string
want string
}{
{
name: "unix absolute path",
cwd: "/home/user/proj",
want: "home--user--proj",
},
{
name: "unix relative path",
cwd: "proj/sub",
want: "proj--sub",
},
{
name: "windows drive root",
cwd: `C:\test`,
want: "C--test",
},
{
name: "windows nested path",
cwd: `C:\Users\User\code`,
want: "C--Users--User--code",
},
{
name: "windows secondary drive",
cwd: `S:\work\repo`,
want: "S--work--repo",
},
{
name: "windows mixed separators",
cwd: `C:\Users/User\code`,
want: "C--Users--User--code",
},
{
name: "windows other illegal chars stripped",
cwd: `C:\a<b>c|d?e*f"g`,
want: "C--abcdefg",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := encodeCwdForDir(tc.cwd)
if got != tc.want {
t.Errorf("encodeCwdForDir(%q) = %q, want %q", tc.cwd, got, tc.want)
}
// Encoded directory must never contain characters that are
// illegal in Windows directory names.
for _, bad := range []string{":", "<", ">", "\"", "|", "?", "*", "\\", "/"} {
if strings.Contains(got, bad) {
t.Errorf("encodeCwdForDir(%q) = %q contains illegal char %q", tc.cwd, got, bad)
}
}
})
}
}
+109
View File
@@ -0,0 +1,109 @@
package session
import (
"testing"
"github.com/mark3labs/kit/internal/message"
)
// TestDetectCycleWithCorruptedParentChain tests that cycle detection works
// when a corrupted session has circular parent references.
func TestDetectCycleWithCorruptedParentChain(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Create normal chain: msg1 -> msg2 -> msg3
id1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
_, _ = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
id3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
// Simulate corruption: manually set msg1's parent to msg3, creating cycle
// This simulates the condition seen in the user's session
for _, entry := range tm.entries {
if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
e.ParentID = id3 // Create cycle: msg1 -> msg3 -> ... -> msg1
break
}
}
// DetectCycle should find the cycle
// The cycle is: id1 -> id3 -> id2 -> id1
// So detecting from id3 should find id1 as the repeat
cycle, entry := tm.DetectCycle(id3)
if !cycle {
t.Fatal("expected to detect cycle, but none found")
}
// The cycle entry could be id1 or id3 depending on where we start
if entry != id1 && entry != id3 {
t.Fatalf("expected cycle at %s or %s, got %s", id1, id3, entry)
}
// BuildContext should still work (it has its own cycle detection)
// but will truncate at the cycle point
msgs, _, _ := tm.BuildContext()
if len(msgs) == 0 {
t.Fatal("BuildContext returned no messages")
}
}
// TestAppendMessageRejectsInvalidParent tests that AppendMessage rejects
// appending when the current leaf has a broken parent chain.
func TestAppendMessageRejectsInvalidParent(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Create normal message
id1, err := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
if err != nil {
t.Fatalf("failed to append msg1: %v", err)
}
// Simulate corruption: set leafID to a non-existent ID
tm.leafID = "non-existent-id"
// Next append should fail validation
_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
if err == nil {
t.Fatal("expected error when appending with invalid leafID, got nil")
}
// Restore valid leafID
tm.leafID = id1
// Append should succeed now
_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
if err != nil {
t.Fatalf("failed to append msg3 after restoring leafID: %v", err)
}
}
// TestBuildContextHandlesCycleGracefully tests that BuildContext handles
// cycles gracefully by truncating the branch.
func TestBuildContextHandlesCycleGracefully(t *testing.T) {
tm := InMemoryTreeSession("/test")
// Create messages
id1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
_, _ = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
id3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
// Verify normal case works
msgs, _, _ := tm.BuildContext()
if len(msgs) != 3 {
t.Fatalf("expected 3 messages, got %d", len(msgs))
}
// Simulate cycle: set msg1's parent to msg3
for _, entry := range tm.entries {
if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
e.ParentID = id3
break
}
}
// BuildContext should handle cycle gracefully (getBranchLocked has cycle detection)
msgs, _, _ = tm.BuildContext()
// Should only include messages from the cycle: msg3, msg2, msg1
// (msg3 is leaf, walks to msg2 -> msg1 -> msg3 (cycle detected, stops))
if len(msgs) != 3 {
t.Fatalf("expected 3 messages in cycle case, got %d: %+v", len(msgs), msgs)
}
}
+150 -9
View File
@@ -63,6 +63,11 @@ type TreeManager struct {
// file is the open file handle for appending entries. Nil for in-memory.
file *os.File
// writer is a buffered writer wrapping file. Writes go through this
// buffer and are flushed to disk at explicit sync points (after each
// public Append* call, in Close, etc.) to reduce syscall overhead.
writer *bufio.Writer
}
// --- Constructors ---
@@ -105,11 +110,16 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
return nil, fmt.Errorf("failed to create session file: %w", err)
}
tm.file = f
tm.writer = bufio.NewWriter(f)
if err := tm.writeEntry(&header); err != nil {
_ = f.Close()
return nil, fmt.Errorf("failed to write session header: %w", err)
}
if err := tm.flushLocked(); err != nil {
_ = f.Close()
return nil, fmt.Errorf("failed to flush session header: %w", err)
}
return tm, nil
}
@@ -150,6 +160,7 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
return nil, fmt.Errorf("failed to recreate session file: %w", err)
}
newTm.file = f
newTm.writer = bufio.NewWriter(f)
if err := newTm.writeEntry(&newTm.header); err != nil {
_ = f.Close()
@@ -289,6 +300,12 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
}
}
// Flush all buffered writes from the fork in a single syscall.
if err := newTm.flushLocked(); err != nil {
_ = f.Close()
return nil, fmt.Errorf("failed to flush forked session: %w", err)
}
// Set the leaf to the last entry in the new session.
newTm.leafID = prevNewID
@@ -365,12 +382,16 @@ func OpenTreeSession(path string) (*TreeManager, error) {
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
}
// Validate tree integrity and log diagnostics
tm.LogTreeDiagnostics()
// Open file for appending.
f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open session file for append: %w", err)
}
tm.file = f
tm.writer = bufio.NewWriter(f)
return tm, nil
}
@@ -410,6 +431,12 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
tm.mu.Lock()
defer tm.mu.Unlock()
// Validate parent chain before appending to detect/prevent cycles
// that could be caused by external file corruption or race conditions.
if err := tm.validateParentChainLocked(tm.leafID, ""); err != nil {
return "", fmt.Errorf("parent chain validation failed: %w", err)
}
entry, err := NewMessageEntry(tm.leafID, msg)
if err != nil {
return "", err
@@ -418,6 +445,9 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush message: %w", err)
}
tm.leafID = entry.ID
return entry.ID, nil
@@ -442,6 +472,9 @@ func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, erro
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush model change: %w", err)
}
tm.leafID = entry.ID
return entry.ID, nil
@@ -456,6 +489,9 @@ func (tm *TreeManager) AppendBranchSummary(fromID, summary string) (string, erro
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush branch summary: %w", err)
}
tm.leafID = entry.ID
return entry.ID, nil
@@ -470,6 +506,9 @@ func (tm *TreeManager) AppendLabel(targetID, label string) (string, error) {
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush label: %w", err)
}
tm.labels[targetID] = label
tm.leafID = entry.ID
@@ -485,6 +524,9 @@ func (tm *TreeManager) AppendSessionInfo(name string) (string, error) {
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush session info: %w", err)
}
tm.sessionName = name
tm.leafID = entry.ID
@@ -501,6 +543,9 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush extension data: %w", err)
}
tm.leafID = entry.ID
return entry.ID, nil
@@ -518,6 +563,13 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens
tm.mu.Lock()
defer tm.mu.Unlock()
// Validate that firstKeptEntryID exists if provided
if firstKeptEntryID != "" {
if _, ok := tm.index[firstKeptEntryID]; !ok {
return "", fmt.Errorf("first kept entry %q does not exist", firstKeptEntryID)
}
}
// The compaction entry has no parent, making it a new "root" for the
// post-compaction branch. This ensures old compacted messages are not
// traversed when walking from the current leaf.
@@ -525,6 +577,9 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens
if err := tm.appendAndPersist(entry); err != nil {
return "", err
}
if err := tm.flushLocked(); err != nil {
return "", fmt.Errorf("failed to flush compaction: %w", err)
}
tm.leafID = entry.ID
return entry.ID, nil
@@ -910,11 +965,31 @@ func (tm *TreeManager) IsEmpty() bool {
return tm.MessageCount() == 0
}
// Close closes the underlying file handle.
// Flush writes any buffered data to the underlying file.
func (tm *TreeManager) Flush() error {
tm.mu.Lock()
defer tm.mu.Unlock()
return tm.flushLocked()
}
// flushLocked writes buffered data to disk. Caller must hold the lock.
func (tm *TreeManager) flushLocked() error {
if tm.writer != nil {
return tm.writer.Flush()
}
return nil
}
// Close flushes any buffered writes and closes the underlying file handle.
func (tm *TreeManager) Close() error {
tm.mu.Lock()
defer tm.mu.Unlock()
if tm.file != nil {
// Flush buffered data before closing.
if tm.writer != nil {
_ = tm.writer.Flush()
tm.writer = nil
}
err := tm.file.Close()
tm.file = nil
return err
@@ -1074,13 +1149,22 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
// AddLLMMessages appends multiple LLM messages as entries. This is
// used when syncing from the agent's ConversationMessages after a step.
// All entries are buffered and flushed to disk in a single batch.
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
tm.mu.Lock()
defer tm.mu.Unlock()
for _, msg := range msgs {
if _, err := tm.AppendLLMMessage(msg); err != nil {
entry, err := NewMessageEntry(tm.leafID, message.FromLLMMessage(msg))
if err != nil {
return err
}
if err := tm.appendAndPersist(entry); err != nil {
return err
}
tm.leafID = entry.ID
}
return nil
return tm.flushLocked()
}
// Deprecated: Use AddLLMMessages instead.
@@ -1132,12 +1216,20 @@ func (tm *TreeManager) appendAndPersist(entry any) error {
return nil
}
// writeEntry serializes an entry and appends it as a line to the file.
// writeEntry serializes an entry and appends it to the buffered writer.
// The data is not flushed to disk until flushLocked is called.
func (tm *TreeManager) writeEntry(entry any) error {
data, err := json.Marshal(entry)
if err != nil {
return fmt.Errorf("failed to marshal entry: %w", err)
}
if tm.writer != nil {
if _, err := tm.writer.Write(data); err != nil {
return err
}
return tm.writer.WriteByte('\n')
}
// Fallback for direct file writes (shouldn't happen in normal flow).
data = append(data, '\n')
_, err = tm.file.Write(data)
return err
@@ -1213,12 +1305,32 @@ func (tm *TreeManager) getBranchLocked(fromID string) []any {
}
// buildTreeNode recursively builds a TreeNode from an entry ID.
// It includes a depth limit to prevent infinite recursion in case of
// corrupted parent-child relationships.
func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
return tm.buildTreeNodeDepth(id, 0, make(map[string]bool))
}
// buildTreeNodeDepth is the internal implementation with depth tracking.
func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[string]bool) *TreeNode {
const maxDepth = 1000
if depth > maxDepth {
// Cycle or extremely deep tree detected, stop recursing
return nil
}
if visited[id] {
// Cycle detected, stop recursing
return nil
}
entry, ok := tm.index[id]
if !ok {
return nil
}
visited[id] = true
defer delete(visited, id)
node := &TreeNode{
Entry: entry,
ID: id,
@@ -1226,7 +1338,7 @@ func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
}
for _, childID := range tm.childIndex[id] {
child := tm.buildTreeNode(childID)
child := tm.buildTreeNodeDepth(childID, depth+1, visited)
if child != nil {
node.Children = append(node.Children, child)
}
@@ -1238,15 +1350,44 @@ func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
// --- Path conventions ---
// DefaultSessionDir returns the default session storage directory for a cwd.
// Convention: ~/.kit/sessions/--<cwd-path>--/
// Convention: ~/.kit/sessions/<encoded-cwd>, where path separators are
// encoded as "--" with no leading or trailing dashes — e.g.
// /home/user/proj becomes home--user--proj. See encodeCwdForDir for the
// full encoding rules (including Windows path handling).
func DefaultSessionDir(cwd string) string {
home, err := os.UserHomeDir()
if err != nil {
home = "."
}
// Convert path separators to double dashes.
safeCwd := strings.ReplaceAll(cwd, string(filepath.Separator), "--")
return filepath.Join(home, ".kit", "sessions", encodeCwdForDir(cwd))
}
// encodeCwdForDir converts a working-directory path into a single, filesystem-
// safe directory name. Path separators are replaced with double dashes and
// characters that are illegal in Windows directory names — most importantly
// the colon that follows the drive letter (e.g. `C:\foo` → `C--foo`) — are
// stripped. The result is identical to the previous Unix-only encoding for
// paths that do not contain such characters, so existing session directories
// are preserved.
func encodeCwdForDir(cwd string) string {
// Convert both `/` and `\` to double dashes so encoding is stable across
// platforms and remains correct on Windows where `filepath.Separator`
// would otherwise miss forward-slash style paths.
safeCwd := strings.ReplaceAll(cwd, "\\", "--")
safeCwd = strings.ReplaceAll(safeCwd, "/", "--")
// Remove leading separator replacement.
safeCwd = strings.TrimPrefix(safeCwd, "--")
return filepath.Join(home, ".kit", "sessions", safeCwd)
// Strip characters that are illegal in directory names on Windows
// (`< > : " | ? *`). On Unix these characters are legal but rare in
// practice; stripping them keeps the encoding portable.
replacer := strings.NewReplacer(
":", "",
"<", "",
">", "",
"\"", "",
"|", "",
"?", "",
"*", "",
)
return replacer.Replace(safeCwd)
}
+143
View File
@@ -0,0 +1,143 @@
package session
import (
"fmt"
"log"
)
// ValidateParentChain checks that the parent ID points to an existing entry
// and that appending this entry would not create a cycle. This should be called
// before appending any entry to the tree.
// Returns an error if the parent is invalid or would create a cycle.
func (tm *TreeManager) ValidateParentChain(parentID string, newEntryID string) error {
if parentID == "" {
// Empty parent is valid (root entry)
return nil
}
// Check that parent exists
if _, ok := tm.index[parentID]; !ok {
return fmt.Errorf("parent entry %q does not exist in index", parentID)
}
// Check that we're not creating a cycle by walking up the parent chain
// from parentID and ensuring we don't hit newEntryID (or any node that
// has newEntryID as an ancestor, but since newEntryID is new, just check
// that parentID isn't newEntryID, which it can't be since we check existence)
visited := make(map[string]bool)
current := parentID
for current != "" {
if visited[current] {
return fmt.Errorf("existing cycle detected at entry %q", current)
}
visited[current] = true
// Safety check: if somehow we reach the new entry ID, that's a cycle
if current == newEntryID {
return fmt.Errorf("would create cycle: entry %q cannot be its own ancestor", newEntryID)
}
entry, ok := tm.index[current]
if !ok {
return fmt.Errorf("broken parent chain: entry %q not found", current)
}
current = tm.entryParentID(entry)
}
return nil
}
// DetectCycle walks the parent chain from the given entry ID and returns true
// if a cycle is detected. This is used for diagnostics.
func (tm *TreeManager) DetectCycle(fromID string) (cycleDetected bool, cycleEntry string) {
visited := make(map[string]bool)
current := fromID
for current != "" {
if visited[current] {
return true, current
}
visited[current] = true
entry, ok := tm.index[current]
if !ok {
return false, ""
}
current = tm.entryParentID(entry)
}
return false, ""
}
// LogTreeDiagnostics logs information about the tree structure for debugging.
// Call this after OpenTreeSession or when anomalies are detected.
func (tm *TreeManager) LogTreeDiagnostics() {
tm.mu.RLock()
defer tm.mu.RUnlock()
log.Printf("[TreeManager] Entry count: %d, Leaf ID: %s", len(tm.entries), tm.leafID)
// Check for cycles from leaf
if tm.leafID != "" {
if cycle, entry := tm.detectCycleLocked(tm.leafID); cycle {
log.Printf("[TreeManager] WARNING: Cycle detected in tree at entry %s", entry)
}
}
// Count entries by type
counts := make(map[EntryType]int)
for _, entry := range tm.entries {
var et EntryType
switch e := entry.(type) {
case *MessageEntry:
et = e.Type
case *ModelChangeEntry:
et = e.Type
case *BranchSummaryEntry:
et = e.Type
case *LabelEntry:
et = e.Type
case *SessionInfoEntry:
et = e.Type
case *ExtensionDataEntry:
et = e.Type
case *CompactionEntry:
et = e.Type
default:
et = "unknown"
}
counts[et]++
}
log.Printf("[TreeManager] Entry types: %+v", counts)
}
// detectCycleLocked is the internal version of DetectCycle (must hold read lock)
func (tm *TreeManager) detectCycleLocked(fromID string) (bool, string) {
visited := make(map[string]bool)
current := fromID
for current != "" {
if visited[current] {
return true, current
}
visited[current] = true
entry, ok := tm.index[current]
if !ok {
return false, ""
}
current = tm.entryParentID(entry)
}
return false, ""
}
// validateParentChainLocked is the internal version used by append methods.
// Must be called with the write lock held.
func (tm *TreeManager) validateParentChainLocked(parentID string, newEntryID string) error {
if parentID == "" {
return nil
}
if _, ok := tm.index[parentID]; !ok {
return fmt.Errorf("parent entry %q does not exist", parentID)
}
// Check for existing cycles in the parent chain
if cycle, entry := tm.detectCycleLocked(parentID); cycle {
return fmt.Errorf("existing cycle detected at entry %q in parent chain", entry)
}
return nil
}
+129 -36
View File
@@ -8,11 +8,11 @@ import (
"sync"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
@@ -47,6 +47,7 @@ type MCPConnection struct {
client client.MCPClient
serverName string
serverConfig config.MCPServerConfig
initResult *mcp.InitializeResult // captured at handshake; nil before initialize
lastUsed time.Time
isHealthy bool
errorCount int
@@ -63,7 +64,6 @@ type MCPConnectionPool struct {
connections map[string]*MCPConnection
config *ConnectionPoolConfig
mu sync.RWMutex
model fantasy.LanguageModel
ctx context.Context
cancel context.CancelFunc
debug bool
@@ -75,9 +75,8 @@ type MCPConnectionPool struct {
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
// If config is nil, default configuration values will be used. The pool starts a background
// goroutine for periodic health checks that runs until Close is called.
// The model parameter is used for MCP servers that require sampling support.
// Thread-safe for concurrent use immediately after creation.
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
func NewMCPConnectionPool(config *ConnectionPoolConfig, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
if config == nil {
config = DefaultConnectionPoolConfig()
}
@@ -86,7 +85,6 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
pool := &MCPConnectionPool{
connections: make(map[string]*MCPConnection),
config: config,
model: model,
ctx: ctx,
cancel: cancel,
debug: debug,
@@ -246,10 +244,12 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
// createConnection creates a new connection
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
oauthEnabled := p.oauthFlow != nil && !serverConfig.NoOAuth
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
if err != nil {
// SSE transport can return OAuth error during Start()
if p.oauthFlow != nil && IsOAuthError(err) {
if oauthEnabled && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
@@ -263,15 +263,17 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
}
}
if err := p.initializeClient(ctx, mcpClient); err != nil {
conn := &MCPConnection{}
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
// Streamable HTTP transport returns OAuth error during Initialize()
if p.oauthFlow != nil && IsOAuthError(err) {
if oauthEnabled && IsOAuthError(err) {
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
_ = mcpClient.Close()
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
}
// Retry initialization after successful auth
if err := p.initializeClient(ctx, mcpClient); err != nil {
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
_ = mcpClient.Close()
return nil, err
}
@@ -281,15 +283,11 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
}
}
conn := &MCPConnection{
client: mcpClient,
serverName: serverName,
serverConfig: serverConfig,
lastUsed: time.Now(),
isHealthy: true,
errorCount: 0,
lastError: nil,
}
conn.client = mcpClient
conn.serverName = serverName
conn.serverConfig = serverConfig
conn.lastUsed = time.Now()
conn.isHealthy = true
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Created connection for %s", serverName))
@@ -308,6 +306,8 @@ func (p *MCPConnectionPool) createMCPClient(ctx context.Context, serverName stri
return p.createSSEClient(ctx, serverConfig)
case "streamable":
return p.createStreamableClient(ctx, serverConfig)
case "inprocess":
return p.createInProcessClient(serverConfig)
default:
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
}
@@ -364,20 +364,30 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
// Enable OAuth for remote transports when an auth handler is configured
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
// PubMed) set NoOAuth to skip dynamic client registration and token
// exchange, which would otherwise fail with a 404.
if p.oauthFlow != nil && !serverConfig.NoOAuth {
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithOAuth(transport.OAuthConfig{
oauthCfg := transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
if serverConfig.OAuthClientID != "" {
oauthCfg.ClientID = serverConfig.OAuthClientID
}
if serverConfig.OAuthClientSecret != "" {
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
}
if len(serverConfig.OAuthScopes) > 0 {
oauthCfg.Scopes = serverConfig.OAuthScopes
}
options = append(options, transport.WithOAuth(oauthCfg))
}
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
@@ -411,20 +421,28 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
}
}
// Enable OAuth for remote transports when an auth handler is configured.
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
// scopes are discovered automatically via dynamic client registration and
// server metadata (RFC 9728).
if p.oauthFlow != nil {
// Enable OAuth for remote transports when an auth handler is configured
// and the server hasn't opted out via NoOAuth.
if p.oauthFlow != nil && !serverConfig.NoOAuth {
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
if tsErr != nil {
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
}
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
oauthCfg := transport.OAuthConfig{
RedirectURI: p.oauthFlow.handler.RedirectURI(),
PKCEEnabled: true,
TokenStore: tokenStore,
}))
}
if serverConfig.OAuthClientID != "" {
oauthCfg.ClientID = serverConfig.OAuthClientID
}
if serverConfig.OAuthClientSecret != "" {
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
}
if len(serverConfig.OAuthScopes) > 0 {
oauthCfg.Scopes = serverConfig.OAuthScopes
}
options = append(options, transport.WithHTTPOAuth(oauthCfg))
}
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
@@ -439,6 +457,22 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
return streamableClient, nil
}
// createInProcessClient creates an in-process MCP client that communicates
// directly with an *server.MCPServer in the same process. No subprocess is
// spawned and no network I/O occurs — calls go through JSON marshal →
// MCPServer.HandleMessage → JSON unmarshal, all in-memory.
func (p *MCPConnectionPool) createInProcessClient(serverConfig config.MCPServerConfig) (client.MCPClient, error) {
srv, ok := serverConfig.InProcessServer.(*server.MCPServer)
if !ok {
return nil, fmt.Errorf("InProcessServer must be *server.MCPServer, got %T", serverConfig.InProcessServer)
}
inProcessClient, err := client.NewInProcessClient(srv)
if err != nil {
return nil, fmt.Errorf("failed to create in-process client: %w", err)
}
return inProcessClient, nil
}
// createTokenStore creates a token store for the given server URL.
// If a custom TokenStoreFactory is configured, it is used; otherwise the
// default file-backed token store is created.
@@ -449,8 +483,10 @@ func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenS
return NewFileTokenStore(serverURL)
}
// initializeClient initializes the client
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
// initializeClient initializes the client and captures the server's
// initialize result on the supplied connection so callers can later
// inspect advertised capabilities (e.g. task support).
func (p *MCPConnectionPool) initializeClient(ctx context.Context, c client.MCPClient, conn *MCPConnection) error {
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
@@ -460,12 +496,21 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
Name: "kit",
Version: "1.0.0",
}
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
// Advertise task support so servers may return CreateTaskResult for
// long-running tools/call requests instead of blocking the connection
// until completion. The client is responsible for polling tasks/get and
// tasks/result until the task reaches a terminal state.
initRequest.Params.Capabilities = mcp.ClientCapabilities{
Tasks: mcp.NewTasksCapability(),
}
_, err := client.Initialize(initCtx, initRequest)
initResult, err := c.Initialize(initCtx, initRequest)
if err != nil {
return fmt.Errorf("initialization timeout or failed: %w", err)
}
if conn != nil {
conn.initResult = initResult
}
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
p.debugLogger.LogDebug("[POOL] Initialized MCP client")
@@ -580,6 +625,54 @@ func (c *MCPConnection) ServerName() string {
return c.serverName
}
// InitializeResult returns the result captured from the server's initialize
// response, or nil if the connection was created before initialize completed.
// Callers can inspect ServerCapabilities.Tasks to discover task-related
// capability advertisements.
func (c *MCPConnection) InitializeResult() *mcp.InitializeResult {
c.mu.RLock()
defer c.mu.RUnlock()
return c.initResult
}
// SupportsToolTasks reports whether the server advertised support for
// task-augmented tools/call requests. Returns false when the connection has
// not yet completed initialization or when the server omitted task
// capabilities.
func (c *MCPConnection) SupportsToolTasks() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return supportsToolTasksFromInit(c.initResult)
}
// supportsToolTasksFromInit reports whether the supplied InitializeResult
// advertises task-augmented tools/call support. Extracted to a free function
// for unit testing without standing up a connection.
func supportsToolTasksFromInit(init *mcp.InitializeResult) bool {
if init == nil || init.Capabilities.Tasks == nil {
return false
}
req := init.Capabilities.Tasks.Requests
if req == nil || req.Tools == nil {
return false
}
return req.Tools.Call != nil
}
// ServerSupportsToolTasks reports whether the named server's connection
// advertises task-augmented tools/call support. Returns false when no
// connection exists for the server or when the server didn't advertise the
// capability.
func (p *MCPConnectionPool) ServerSupportsToolTasks(serverName string) bool {
p.mu.RLock()
conn, ok := p.connections[serverName]
p.mu.RUnlock()
if !ok {
return false
}
return conn.SupportsToolTasks()
}
// GetClients returns a map of all MCP clients currently in the pool.
// The map keys are server names and values are the corresponding MCP client instances.
// The returned map is a copy and modifications won't affect the pool.
-109
View File
@@ -1,109 +0,0 @@
package tools
import (
"context"
"encoding/json"
"fmt"
"charm.land/fantasy"
"github.com/mark3labs/mcp-go/mcp"
)
// mcpFantasyTool adapts an MCP tool to the fantasy.AgentTool interface.
// It bridges the MCP tool protocol with fantasy's agent tool system, handling
// name prefixing, schema conversion, connection pooling, and result marshaling.
type mcpFantasyTool struct {
toolInfo fantasy.ToolInfo
mapping *toolMapping
providerOptions fantasy.ProviderOptions
}
// Info returns the fantasy tool info including name, description, and parameter schema.
func (t *mcpFantasyTool) Info() fantasy.ToolInfo {
return t.toolInfo
}
// Run executes the MCP tool by routing through the connection pool.
// It maps the prefixed tool name back to the original name, retrieves a healthy
// connection, invokes the tool, and converts the MCP result to a fantasy ToolResponse.
func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
// Parse and validate JSON arguments
var arguments any
input := call.Input
if input == "" || input == "{}" {
arguments = nil
} else {
var temp any
if err := json.Unmarshal([]byte(input), &temp); err != nil {
return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid JSON arguments: %v", err)), nil
}
arguments = json.RawMessage(input)
}
// Get connection from pool with health check
conn, err := t.mapping.manager.connectionPool.GetConnectionWithHealthCheck(
ctx, t.mapping.serverName, t.mapping.serverConfig,
)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("failed to get healthy connection from pool: %w", err)
}
// Call the MCP tool using the original (unprefixed) name
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{
Method: "tools/call",
},
Params: mcp.CallToolParams{
Name: t.mapping.originalName,
Arguments: arguments,
},
})
if err != nil {
// Handle OAuth re-authorization: token may have expired mid-session.
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
}
// Retry the tool call after successful re-auth.
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{
Method: "tools/call",
},
Params: mcp.CallToolParams{
Name: t.mapping.originalName,
Arguments: arguments,
},
})
if err != nil {
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
}
} else {
// Mark connection as unhealthy for automatic recovery
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
}
}
// Marshal the MCP result to JSON string
marshaledResult, err := json.Marshal(result)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal mcp tool result: %w", err)
}
// Return as text response, preserving error status from MCP
if result.IsError {
return fantasy.NewTextErrorResponse(string(marshaledResult)), nil
}
return fantasy.NewTextResponse(string(marshaledResult)), nil
}
// ProviderOptions returns provider-specific options for this tool.
func (t *mcpFantasyTool) ProviderOptions() fantasy.ProviderOptions {
return t.providerOptions
}
// SetProviderOptions sets provider-specific options for this tool.
func (t *mcpFantasyTool) SetProviderOptions(opts fantasy.ProviderOptions) {
t.providerOptions = opts
}
+244
View File
@@ -0,0 +1,244 @@
package tools
import (
"context"
"encoding/json"
"strings"
"testing"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// newTestInProcessServer creates a simple MCP server with one tool for testing.
func newTestInProcessServer() *server.MCPServer {
srv := server.NewMCPServer("test-server", "1.0.0",
server.WithToolCapabilities(true),
)
srv.AddTool(
mcp.NewTool("greet",
mcp.WithDescription("Say hello"),
mcp.WithString("name", mcp.Required(), mcp.Description("Name to greet")),
),
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
name, _ := req.GetArguments()["name"].(string)
return mcp.NewToolResultText("Hello, " + name + "!"), nil
},
)
return srv
}
func TestInProcessTransportType(t *testing.T) {
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: newTestInProcessServer(),
}
if got := cfg.GetTransportType(); got != "inprocess" {
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
}
}
func TestInProcessTransportTypeInferred(t *testing.T) {
// When Type is empty but InProcessServer is set, infer "inprocess".
cfg := config.MCPServerConfig{
InProcessServer: newTestInProcessServer(),
}
if got := cfg.GetTransportType(); got != "inprocess" {
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
}
}
func TestInProcessValidation(t *testing.T) {
// Valid: InProcessServer is set.
validCfg := &config.Config{
MCPServers: map[string]config.MCPServerConfig{
"test": {
Type: "inprocess",
InProcessServer: newTestInProcessServer(),
},
},
}
if err := validCfg.Validate(); err != nil {
t.Errorf("expected valid config, got error: %v", err)
}
// Invalid: type is inprocess but InProcessServer is nil.
invalidCfg := &config.Config{
MCPServers: map[string]config.MCPServerConfig{
"test": {
Type: "inprocess",
},
},
}
if err := invalidCfg.Validate(); err == nil {
t.Error("expected validation error for nil InProcessServer, got nil")
}
}
func TestConnectionPoolInProcessClient(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
ctx := context.Background()
srv := newTestInProcessServer()
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: srv,
}
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
if err != nil {
t.Fatalf("GetConnection failed: %v", err)
}
// Verify the connection is healthy and functional.
if !conn.isHealthy {
t.Error("expected connection to be healthy")
}
// List tools to verify the connection works end-to-end.
toolsResp, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
t.Fatalf("ListTools failed: %v", err)
}
if len(toolsResp.Tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(toolsResp.Tools))
}
if toolsResp.Tools[0].Name != "greet" {
t.Errorf("expected tool name 'greet', got %q", toolsResp.Tools[0].Name)
}
}
func TestConnectionPoolInProcessToolExecution(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
ctx := context.Background()
srv := newTestInProcessServer()
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: srv,
}
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
if err != nil {
t.Fatalf("GetConnection failed: %v", err)
}
// Call the tool.
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
Request: mcp.Request{Method: "tools/call"},
Params: mcp.CallToolParams{
Name: "greet",
Arguments: map[string]any{"name": "World"},
},
})
if err != nil {
t.Fatalf("CallTool failed: %v", err)
}
if result.IsError {
t.Error("expected non-error result")
}
if len(result.Content) == 0 {
t.Fatal("expected at least one content block")
}
text, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatalf("expected TextContent, got %T", result.Content[0])
}
if text.Text != "Hello, World!" {
t.Errorf("expected 'Hello, World!', got %q", text.Text)
}
}
func TestMCPToolManagerInProcess(t *testing.T) {
ctx := context.Background()
srv := newTestInProcessServer()
mgr := NewMCPToolManager()
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: srv,
}
count, err := mgr.AddServer(ctx, "myserver", cfg)
if err != nil {
t.Fatalf("AddServer failed: %v", err)
}
if count != 1 {
t.Errorf("expected 1 tool, got %d", count)
}
tools := mgr.GetTools()
if len(tools) != 1 {
t.Fatalf("expected 1 tool, got %d", len(tools))
}
if tools[0].Name != "myserver__greet" {
t.Errorf("expected tool name 'myserver__greet', got %q", tools[0].Name)
}
// Execute the tool.
input, _ := json.Marshal(map[string]any{"name": "SDK"})
result, err := mgr.ExecuteTool(ctx, "myserver__greet", string(input))
if err != nil {
t.Fatalf("ExecuteTool failed: %v", err)
}
if result.IsError {
t.Error("expected non-error result")
}
if result.Content == "" {
t.Error("expected non-empty result content")
}
// Verify result contains our greeting.
if !strings.Contains(result.Content, "Hello, SDK!") {
t.Errorf("expected 'Hello, SDK!' in result, got %q", result.Content)
}
}
func TestConnectionPoolInProcessInvalidServer(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
ctx := context.Background()
// Pass a non-*server.MCPServer value.
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: "not a server",
}
_, err := pool.GetConnection(ctx, "bad", cfg)
if err == nil {
t.Fatal("expected error for invalid InProcessServer type")
}
}
func TestConnectionPoolInProcessReuse(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
ctx := context.Background()
srv := newTestInProcessServer()
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: srv,
}
// Get connection twice — should reuse.
conn1, err := pool.GetConnection(ctx, "reuse-test", cfg)
if err != nil {
t.Fatalf("first GetConnection failed: %v", err)
}
conn2, err := pool.GetConnection(ctx, "reuse-test", cfg)
if err != nil {
t.Fatalf("second GetConnection failed: %v", err)
}
if conn1 != conn2 {
t.Error("expected same connection object on reuse")
}
}
+888 -46
View File
File diff suppressed because it is too large Load Diff
@@ -101,7 +101,7 @@ func TestMCPToolManager_AddServer_Integration(t *testing.T) {
// Verify tool names are prefixed.
toolNames := make(map[string]bool)
for _, tool := range tools {
toolNames[tool.Info().Name] = true
toolNames[tool.Name] = true
}
if !toolNames["echo__echo"] {
t.Error("Expected tool 'echo__echo'")
@@ -234,8 +234,8 @@ func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
// Remaining tools should all be from server-b.
for _, tool := range tools {
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
if !strings.HasPrefix(tool.Name, "server-b__") {
t.Errorf("Expected tool from server-b, got: %s", tool.Name)
}
}
+1 -1
View File
@@ -122,7 +122,7 @@ func TestMCPToolManager_Close_NilPool(t *testing.T) {
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
// non-existent connection returns an error.
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
err := pool.RemoveConnection("nonexistent")
+691
View File
@@ -0,0 +1,691 @@
package tools
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
mcpclient "github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// newTestPromptServer creates an in-process MCP server with prompt capabilities
// and the specified prompts + handlers. Returns an initialized MCPClient.
func newTestPromptServer(t *testing.T, prompts ...server.ServerPrompt) mcpclient.MCPClient {
t.Helper()
mcpServer := server.NewMCPServer(
"test-prompt-server", "1.0.0",
server.WithPromptCapabilities(true),
server.WithToolCapabilities(true),
)
if len(prompts) > 0 {
mcpServer.AddPrompts(prompts...)
}
// Add a dummy tool so loadServerTools has something to list.
mcpServer.AddTool(
mcp.NewTool("noop", mcp.WithDescription("no-op tool")),
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("ok"), nil
},
)
client, err := mcpclient.NewInProcessClient(mcpServer)
if err != nil {
t.Fatalf("NewInProcessClient: %v", err)
}
ctx := context.Background()
if err := client.Start(ctx); err != nil {
t.Fatalf("client.Start: %v", err)
}
initReq := mcp.InitializeRequest{}
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
if _, err := client.Initialize(ctx, initReq); err != nil {
t.Fatalf("client.Initialize: %v", err)
}
t.Cleanup(func() { _ = client.Close() })
return client
}
// injectClientIntoManager sets up an MCPToolManager with a pre-connected
// in-process client, bypassing the normal connection pool flow.
func injectClientIntoManager(t *testing.T, serverName string, client mcpclient.MCPClient) *MCPToolManager {
t.Helper()
m := NewMCPToolManager()
// Create a minimal connection pool and inject our client.
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
pool.mu.Lock()
pool.connections[serverName] = &MCPConnection{
client: client,
serverName: serverName,
isHealthy: true,
}
pool.mu.Unlock()
m.connectionPool = pool
return m
}
func TestLoadServerPrompts_Basic(t *testing.T) {
ctx := context.Background()
client := newTestPromptServer(t,
server.ServerPrompt{
Prompt: mcp.NewPrompt("review-pr",
mcp.WithPromptDescription("Review a pull request"),
mcp.WithArgument("pr_number",
mcp.ArgumentDescription("The PR number to review"),
mcp.RequiredArgument(),
),
mcp.WithArgument("focus",
mcp.ArgumentDescription("Area to focus on"),
),
),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
prNum := req.Params.Arguments["pr_number"]
return &mcp.GetPromptResult{
Description: "PR review prompt",
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Type: "text",
Text: fmt.Sprintf("Please review PR #%s", prNum),
},
},
},
}, nil
},
},
server.ServerPrompt{
Prompt: mcp.NewPrompt("explain-code",
mcp.WithPromptDescription("Explain a piece of code"),
),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Type: "text",
Text: "Please explain the following code.",
},
},
},
}, nil
},
},
)
m := injectClientIntoManager(t, "github", client)
conn := &MCPConnection{
client: client,
serverName: "github",
isHealthy: true,
}
m.loadServerPrompts(ctx, "github", conn)
prompts := m.GetPrompts()
if len(prompts) != 2 {
t.Fatalf("expected 2 prompts, got %d", len(prompts))
}
// Find review-pr prompt.
var reviewPR *MCPPrompt
for i := range prompts {
if prompts[i].Name == "review-pr" {
reviewPR = &prompts[i]
break
}
}
if reviewPR == nil {
t.Fatal("review-pr prompt not found")
}
if reviewPR.Description != "Review a pull request" {
t.Errorf("unexpected description: %q", reviewPR.Description)
}
if reviewPR.ServerName != "github" {
t.Errorf("unexpected server name: %q", reviewPR.ServerName)
}
if len(reviewPR.Arguments) != 2 {
t.Fatalf("expected 2 arguments, got %d", len(reviewPR.Arguments))
}
// Verify argument metadata.
arg0 := reviewPR.Arguments[0]
if arg0.Name != "pr_number" {
t.Errorf("expected first arg name 'pr_number', got %q", arg0.Name)
}
if !arg0.Required {
t.Error("expected first arg to be required")
}
arg1 := reviewPR.Arguments[1]
if arg1.Name != "focus" {
t.Errorf("expected second arg name 'focus', got %q", arg1.Name)
}
if arg1.Required {
t.Error("expected second arg to be optional")
}
}
func TestGetPrompt_ExpandsWithArgs(t *testing.T) {
ctx := context.Background()
client := newTestPromptServer(t,
server.ServerPrompt{
Prompt: mcp.NewPrompt("greet",
mcp.WithPromptDescription("Greet someone"),
mcp.WithArgument("name", mcp.RequiredArgument()),
),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
name := req.Params.Arguments["name"]
return &mcp.GetPromptResult{
Description: "Greeting",
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Type: "text",
Text: fmt.Sprintf("Hello, %s!", name),
},
},
},
}, nil
},
},
)
m := injectClientIntoManager(t, "myserver", client)
result, err := m.GetPrompt(ctx, "myserver", "greet", map[string]string{"name": "World"})
if err != nil {
t.Fatalf("GetPrompt error: %v", err)
}
if result.Description != "Greeting" {
t.Errorf("unexpected description: %q", result.Description)
}
if len(result.Messages) != 1 {
t.Fatalf("expected 1 message, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" {
t.Errorf("unexpected role: %q", result.Messages[0].Role)
}
if result.Messages[0].Content != "Hello, World!" {
t.Errorf("unexpected content: %q", result.Messages[0].Content)
}
}
func TestGetPrompt_MultipleMessages(t *testing.T) {
ctx := context.Background()
client := newTestPromptServer(t,
server.ServerPrompt{
Prompt: mcp.NewPrompt("chat-starter"),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{Type: "text", Text: "What is Go?"},
},
{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{Type: "text", Text: "Go is a programming language."},
},
{
Role: mcp.RoleUser,
Content: mcp.TextContent{Type: "text", Text: "Tell me more."},
},
},
}, nil
},
},
)
m := injectClientIntoManager(t, "server", client)
result, err := m.GetPrompt(ctx, "server", "chat-starter", nil)
if err != nil {
t.Fatalf("GetPrompt error: %v", err)
}
if len(result.Messages) != 3 {
t.Fatalf("expected 3 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != "user" {
t.Errorf("msg[0] role: got %q, want 'user'", result.Messages[0].Role)
}
if result.Messages[1].Role != "assistant" {
t.Errorf("msg[1] role: got %q, want 'assistant'", result.Messages[1].Role)
}
if result.Messages[2].Content != "Tell me more." {
t.Errorf("msg[2] content: got %q, want 'Tell me more.'", result.Messages[2].Content)
}
}
func TestGetPrompt_ServerNotFound(t *testing.T) {
m := NewMCPToolManager()
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
m.connectionPool = pool
_, err := m.GetPrompt(context.Background(), "nonexistent", "foo", nil)
if err == nil {
t.Fatal("expected error for nonexistent server")
}
}
func TestGetPrompt_NoPool(t *testing.T) {
m := NewMCPToolManager()
_, err := m.GetPrompt(context.Background(), "any", "foo", nil)
if err == nil {
t.Fatal("expected error with no pool")
}
}
func TestRemoveServer_RemovesPrompts(t *testing.T) {
ctx := context.Background()
client := newTestPromptServer(t,
server.ServerPrompt{
Prompt: mcp.NewPrompt("my-prompt",
mcp.WithPromptDescription("A test prompt"),
),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{
Messages: []mcp.PromptMessage{
{Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "hi"}},
},
}, nil
},
},
)
m := injectClientIntoManager(t, "testsvr", client)
// Manually populate tools and prompts as loadServerTools would.
conn := m.connectionPool.connections["testsvr"]
m.loadServerPrompts(ctx, "testsvr", conn)
// Also add a fake tool mapping so RemoveServer finds the server.
m.toolMap["testsvr__noop"] = &toolMapping{
serverName: "testsvr",
originalName: "noop",
}
m.tools = append(m.tools, MCPTool{
Name: "testsvr__noop",
ServerName: "testsvr",
})
// Verify prompts exist before removal.
if got := len(m.GetPrompts()); got != 1 {
t.Fatalf("expected 1 prompt before removal, got %d", got)
}
// Remove the server.
err := m.RemoveServer("testsvr")
if err != nil {
t.Fatalf("RemoveServer error: %v", err)
}
// Verify prompts are gone.
if got := len(m.GetPrompts()); got != 0 {
t.Fatalf("expected 0 prompts after removal, got %d", got)
}
}
func TestLoadServerPrompts_NoPromptCapability(t *testing.T) {
// Server without prompt capabilities — ListPrompts should fail gracefully.
mcpServer := server.NewMCPServer("no-prompts", "1.0.0",
server.WithToolCapabilities(true),
// No WithPromptCapabilities
)
mcpServer.AddTool(
mcp.NewTool("noop"),
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
return mcp.NewToolResultText("ok"), nil
},
)
client, err := mcpclient.NewInProcessClient(mcpServer)
if err != nil {
t.Fatalf("NewInProcessClient: %v", err)
}
ctx := context.Background()
_ = client.Start(ctx)
initReq := mcp.InitializeRequest{}
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
_, _ = client.Initialize(ctx, initReq)
t.Cleanup(func() { _ = client.Close() })
m := NewMCPToolManager()
conn := &MCPConnection{
client: client,
serverName: "no-prompts",
isHealthy: true,
}
// Should not panic or error — just silently skip.
m.loadServerPrompts(ctx, "no-prompts", conn)
if got := len(m.GetPrompts()); got != 0 {
t.Fatalf("expected 0 prompts from server without prompt capability, got %d", got)
}
}
func TestExtractPromptContent(t *testing.T) {
t.Run("TextContent", func(t *testing.T) {
text, parts := extractPromptContent(mcp.TextContent{Type: "text", Text: "hello world"})
if text != "hello world" {
t.Errorf("text = %q, want %q", text, "hello world")
}
if len(parts) != 0 {
t.Errorf("expected 0 file parts, got %d", len(parts))
}
})
t.Run("ImageContent", func(t *testing.T) {
// base64 of "fake image"
encoded := base64.StdEncoding.EncodeToString([]byte("fake image"))
text, parts := extractPromptContent(mcp.ImageContent{
Type: "image",
Data: encoded,
MIMEType: "image/png",
})
if text != "" {
t.Errorf("expected empty text, got %q", text)
}
if len(parts) != 1 {
t.Fatalf("expected 1 file part, got %d", len(parts))
}
if parts[0].MediaType != "image/png" {
t.Errorf("media type = %q, want %q", parts[0].MediaType, "image/png")
}
if parts[0].Filename != "image.png" {
t.Errorf("filename = %q, want %q", parts[0].Filename, "image.png")
}
if string(parts[0].Data) != "fake image" {
t.Errorf("data = %q, want %q", string(parts[0].Data), "fake image")
}
})
t.Run("ImageContent_DefaultMIME", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("img"))
_, parts := extractPromptContent(mcp.ImageContent{
Type: "image",
Data: encoded,
// no MIMEType → should default to image/png
})
if len(parts) != 1 {
t.Fatalf("expected 1 file part, got %d", len(parts))
}
if parts[0].MediaType != "image/png" {
t.Errorf("default MIME = %q, want %q", parts[0].MediaType, "image/png")
}
})
t.Run("AudioContent", func(t *testing.T) {
encoded := base64.StdEncoding.EncodeToString([]byte("fake audio"))
text, parts := extractPromptContent(mcp.AudioContent{
Type: "audio",
Data: encoded,
MIMEType: "audio/mp3",
})
if text != "" {
t.Errorf("expected empty text, got %q", text)
}
if len(parts) != 1 {
t.Fatalf("expected 1 file part, got %d", len(parts))
}
if parts[0].MediaType != "audio/mp3" {
t.Errorf("media type = %q, want %q", parts[0].MediaType, "audio/mp3")
}
if parts[0].Filename != "audio.wav" {
t.Errorf("filename = %q, want %q", parts[0].Filename, "audio.wav")
}
})
t.Run("EmbeddedResource_Text", func(t *testing.T) {
text, parts := extractPromptContent(mcp.EmbeddedResource{
Type: "resource",
Resource: mcp.TextResourceContents{
URI: "file:///project/main.go",
MIMEType: "text/x-go",
Text: "package main",
},
})
if text == "" {
t.Fatal("expected non-empty text for text resource")
}
if !strings.Contains(text, "package main") {
t.Errorf("text should contain resource content, got %q", text)
}
if !strings.Contains(text, "file:///project/main.go") {
t.Errorf("text should contain URI, got %q", text)
}
if len(parts) != 0 {
t.Errorf("expected 0 file parts for text resource, got %d", len(parts))
}
})
t.Run("EmbeddedResource_Blob", func(t *testing.T) {
blobData := []byte("binary content")
encoded := base64.StdEncoding.EncodeToString(blobData)
text, parts := extractPromptContent(mcp.EmbeddedResource{
Type: "resource",
Resource: mcp.BlobResourceContents{
URI: "file:///project/data.bin",
MIMEType: "application/octet-stream",
Blob: encoded,
},
})
if text != "" {
t.Errorf("expected empty text for blob resource, got %q", text)
}
if len(parts) != 1 {
t.Fatalf("expected 1 file part for blob resource, got %d", len(parts))
}
if parts[0].Filename != "data.bin" {
t.Errorf("filename = %q, want %q", parts[0].Filename, "data.bin")
}
if parts[0].MediaType != "application/octet-stream" {
t.Errorf("media type = %q, want %q", parts[0].MediaType, "application/octet-stream")
}
if string(parts[0].Data) != "binary content" {
t.Errorf("data = %q, want %q", string(parts[0].Data), "binary content")
}
})
t.Run("ResourceLink", func(t *testing.T) {
text, parts := extractPromptContent(mcp.ResourceLink{
Type: "resource_link",
URI: "file:///docs/readme.md",
Name: "readme.md",
})
if text == "" {
t.Fatal("expected non-empty text for resource link")
}
if !strings.Contains(text, "file:///docs/readme.md") {
t.Errorf("text should contain URI, got %q", text)
}
if !strings.Contains(text, "readme.md") {
t.Errorf("text should contain name, got %q", text)
}
if len(parts) != 0 {
t.Errorf("expected 0 file parts for resource link, got %d", len(parts))
}
})
t.Run("InvalidBase64", func(t *testing.T) {
_, parts := extractPromptContent(mcp.ImageContent{
Type: "image",
Data: "not-valid-base64!!!",
MIMEType: "image/png",
})
if len(parts) != 0 {
t.Errorf("expected 0 file parts for invalid base64, got %d", len(parts))
}
})
t.Run("NilContent", func(t *testing.T) {
text, parts := extractPromptContent((*mcp.TextContent)(nil))
if text != "" {
t.Errorf("expected empty text for nil, got %q", text)
}
if len(parts) != 0 {
t.Errorf("expected 0 parts for nil, got %d", len(parts))
}
})
}
func TestFilenameFromURI(t *testing.T) {
tests := []struct {
uri string
want string
}{
{"file:///path/to/image.png", "image.png"},
{"file:///single.txt", "single.txt"},
{"resource://server/data.json", "data.json"},
{"nopath", "nopath"},
{"", "resource"},
}
for _, tt := range tests {
t.Run(tt.uri, func(t *testing.T) {
got := filenameFromURI(tt.uri)
if got != tt.want {
t.Errorf("filenameFromURI(%q) = %q, want %q", tt.uri, got, tt.want)
}
})
}
}
func TestGetPrompt_EmbeddedResources(t *testing.T) {
ctx := context.Background()
imgData := base64.StdEncoding.EncodeToString([]byte("fake-png"))
blobData := base64.StdEncoding.EncodeToString([]byte("binary-blob"))
client := newTestPromptServer(t,
server.ServerPrompt{
Prompt: mcp.NewPrompt("review-with-files",
mcp.WithPromptDescription("Review with embedded resources"),
),
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
return &mcp.GetPromptResult{
Description: "Review prompt with embedded files",
Messages: []mcp.PromptMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{Type: "text", Text: "Please review these files:"},
},
{
Role: mcp.RoleUser,
Content: mcp.EmbeddedResource{
Type: "resource",
Resource: mcp.TextResourceContents{
URI: "file:///src/main.go",
MIMEType: "text/x-go",
Text: "package main\n\nfunc main() {}",
},
},
},
{
Role: mcp.RoleUser,
Content: mcp.ImageContent{
Type: "image",
Data: imgData,
MIMEType: "image/png",
},
},
{
Role: mcp.RoleUser,
Content: mcp.EmbeddedResource{
Type: "resource",
Resource: mcp.BlobResourceContents{
URI: "file:///data/model.bin",
MIMEType: "application/octet-stream",
Blob: blobData,
},
},
},
},
}, nil
},
},
)
m := injectClientIntoManager(t, "test", client)
result, err := m.GetPrompt(ctx, "test", "review-with-files", nil)
if err != nil {
t.Fatalf("GetPrompt error: %v", err)
}
if result.Description != "Review prompt with embedded files" {
t.Errorf("unexpected description: %q", result.Description)
}
// Should have 4 messages: text, embedded text resource, image, embedded blob
if len(result.Messages) != 4 {
t.Fatalf("expected 4 messages, got %d", len(result.Messages))
}
// Message 0: plain text
msg0 := result.Messages[0]
if msg0.Content != "Please review these files:" {
t.Errorf("msg[0] content = %q", msg0.Content)
}
if len(msg0.FileParts) != 0 {
t.Errorf("msg[0] expected 0 file parts, got %d", len(msg0.FileParts))
}
// Message 1: embedded text resource → inlined as text
msg1 := result.Messages[1]
if !strings.Contains(msg1.Content, "package main") {
t.Errorf("msg[1] should contain resource text, got %q", msg1.Content)
}
if len(msg1.FileParts) != 0 {
t.Errorf("msg[1] expected 0 file parts (text resource), got %d", len(msg1.FileParts))
}
// Message 2: image → file part
msg2 := result.Messages[2]
if msg2.Content != "" {
t.Errorf("msg[2] expected empty text for image, got %q", msg2.Content)
}
if len(msg2.FileParts) != 1 {
t.Fatalf("msg[2] expected 1 file part, got %d", len(msg2.FileParts))
}
if msg2.FileParts[0].MediaType != "image/png" {
t.Errorf("msg[2] file part MIME = %q", msg2.FileParts[0].MediaType)
}
if string(msg2.FileParts[0].Data) != "fake-png" {
t.Errorf("msg[2] file part data = %q", string(msg2.FileParts[0].Data))
}
// Message 3: embedded blob resource → file part
msg3 := result.Messages[3]
if msg3.Content != "" {
t.Errorf("msg[3] expected empty text for blob resource, got %q", msg3.Content)
}
if len(msg3.FileParts) != 1 {
t.Fatalf("msg[3] expected 1 file part, got %d", len(msg3.FileParts))
}
if msg3.FileParts[0].Filename != "model.bin" {
t.Errorf("msg[3] filename = %q, want %q", msg3.FileParts[0].Filename, "model.bin")
}
if string(msg3.FileParts[0].Data) != "binary-blob" {
t.Errorf("msg[3] file part data = %q", string(msg3.FileParts[0].Data))
}
}
+404
View File
@@ -0,0 +1,404 @@
package tools
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
// MCPTaskMode controls when the connection pool augments tools/call requests
// with MCP task metadata. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks.
type MCPTaskMode string
const (
// MCPTaskModeAuto augments tools/call with task metadata only when the
// server advertises tasks/toolCalls capability during initialize.
MCPTaskModeAuto MCPTaskMode = "auto"
// MCPTaskModeNever forces every tools/call to be issued synchronously
// (no Task field in the request), regardless of server capability.
MCPTaskModeNever MCPTaskMode = "never"
// MCPTaskModeAlways always sets a Task field on the tools/call request,
// even when the server didn't advertise task support. The server may
// still respond synchronously; this just opts in unconditionally on
// the client side.
MCPTaskModeAlways MCPTaskMode = "always"
)
// ParseTaskMode normalises a per-server tasks-mode string from
// configuration. Empty input maps to MCPTaskModeAuto. Unknown values are
// also treated as MCPTaskModeAuto so a stray config typo never breaks
// existing flows.
func ParseTaskMode(s string) MCPTaskMode {
switch strings.ToLower(strings.TrimSpace(s)) {
case "", "auto":
return MCPTaskModeAuto
case "never", "off", "disabled":
return MCPTaskModeNever
case "always", "force":
return MCPTaskModeAlways
default:
return MCPTaskModeAuto
}
}
// MCPTaskInfo is the connection-layer view of an MCP Task. It mirrors the
// upstream mcp.Task but exposes Go-native types and includes the originating
// server name. SDK-level wrappers re-export this under public-facing names.
type MCPTaskInfo struct {
// Server is the configured MCP server name this task lives on.
Server string
// TaskID is the server-assigned identifier for the task.
TaskID string
// Status is the current task lifecycle state.
Status mcp.TaskStatus
// StatusMessage is an optional human-readable description.
StatusMessage string
// CreatedAt is the wall-clock time the task was created (best-effort
// parsed from the server's ISO-8601 timestamp; zero on parse failure).
CreatedAt time.Time
// UpdatedAt is the wall-clock time the task was last updated (best-
// effort parsed; zero on parse failure).
UpdatedAt time.Time
// TTL is the time-to-live the server intends to retain the task after
// creation. Zero means the server did not advertise a TTL.
TTL time.Duration
// PollInterval is the suggested polling interval. Zero means use the
// client's default.
PollInterval time.Duration
}
// MCPTaskProgress is emitted while the connection pool is waiting on a
// task-augmented tool call. It provides minimal feedback for SDK consumers
// that want to render progress widgets without subscribing to the full
// notifications/tasks/status channel (Phase 2).
type MCPTaskProgress struct {
Server string
TaskID string
Status mcp.TaskStatus
Message string
}
// MCPTaskProgressHandler is invoked once after a task is accepted and on
// every status transition observed by the polling loop. The final
// invocation always carries a terminal status. Implementations must not
// block; long work should be queued on a goroutine.
type MCPTaskProgressHandler func(MCPTaskProgress)
// MCPTaskConfig configures task-aware tool execution on the manager.
// All fields are optional; the zero value disables progress callbacks and
// applies sensible defaults.
type MCPTaskConfig struct {
// PerServerMode overrides the per-server TasksMode resolved from
// MCPServerConfig. Keys are server names. Missing entries fall back
// to the value from config. Used by SDK consumers that want to set
// modes programmatically.
PerServerMode map[string]MCPTaskMode
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
// tools/call. Zero means omit the TTL — let the server pick its own.
DefaultTTL time.Duration
// PollInterval is the fallback interval between tasks/get requests
// when the server does not suggest one. Zero defaults to 1 second.
PollInterval time.Duration
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
MaxPollInterval time.Duration
// Timeout is the maximum wall-clock duration to wait for a task to
// reach a terminal state. Zero defaults to 15 minutes. Independent
// of the per-call context deadline; whichever fires first wins.
Timeout time.Duration
// Progress, if non-nil, receives every status transition observed by
// the polling loop.
Progress MCPTaskProgressHandler
}
func (c MCPTaskConfig) resolved() MCPTaskConfig {
if c.PollInterval <= 0 {
c.PollInterval = 1 * time.Second
}
if c.MaxPollInterval <= 0 {
c.MaxPollInterval = 5 * time.Second
}
if c.Timeout <= 0 {
c.Timeout = 15 * time.Minute
}
return c
}
// requestIDCounter generates monotonically increasing JSON-RPC request IDs
// for low-level tools/call invocations that bypass the upstream client's
// ParseCallToolResult helper (necessary because that helper rejects task
// responses for lacking a "content" field).
//
// The counter is process-wide rather than per-manager so multiple managers
// or repeated calls within the same connection produce unique IDs.
var requestIDCounter atomic.Int64
func nextRequestID() mcp.RequestId {
return mcp.NewRequestId(requestIDCounter.Add(1))
}
// callToolWithTask issues tools/call directly on the transport so we can
// observe both response shapes:
//
// - {"content": [...], ...} — synchronous CallToolResult.
// - {"task": {...}, ...} — asynchronous CreateTaskResult.
//
// On success exactly one of (callResult, taskResult) is non-nil. The
// upstream client.CallTool helper parses the response with
// mcp.ParseCallToolResult which requires a "content" field, so it cannot
// be used for task-augmented calls.
func callToolWithTask(
ctx context.Context,
c *client.Client,
params mcp.CallToolParams,
) (callResult *mcp.CallToolResult, taskResult *mcp.CreateTaskResult, err error) {
tr := c.GetTransport()
if tr == nil {
return nil, nil, errors.New("mcp client has no transport")
}
req := transport.JSONRPCRequest{
JSONRPC: mcp.JSONRPC_VERSION,
ID: nextRequestID(),
Method: string(mcp.MethodToolsCall),
Params: params,
}
resp, sendErr := tr.SendRequest(ctx, req)
if sendErr != nil {
return nil, nil, sendErr
}
if resp.Error != nil {
return nil, nil, resp.Error.AsError()
}
// Peek at the raw result to decide which shape we got.
var probe struct {
Task json.RawMessage `json:"task"`
Content json.RawMessage `json:"content"`
}
raw := resp.Result
if len(raw) == 0 {
return nil, nil, errors.New("empty tools/call result")
}
if uErr := json.Unmarshal(raw, &probe); uErr != nil {
return nil, nil, fmt.Errorf("decode tools/call result: %w", uErr)
}
if len(probe.Task) > 0 && string(probe.Task) != "null" {
// Task-augmented response.
var ct mcp.CreateTaskResult
if uErr := json.Unmarshal(raw, &ct); uErr != nil {
return nil, nil, fmt.Errorf("decode CreateTaskResult: %w", uErr)
}
return nil, &ct, nil
}
// Synchronous response — defer to the upstream parser so content blocks
// are typed correctly (TextContent, ImageContent, ResourceLink, etc.).
cr, pErr := mcp.ParseCallToolResult(&raw)
if pErr != nil {
return nil, nil, fmt.Errorf("parse CallToolResult: %w", pErr)
}
return cr, nil, nil
}
// pollTaskUntilTerminal blocks until the task reaches a terminal status,
// the context is cancelled, or the configured timeout elapses. On
// cancellation it best-effort issues tasks/cancel before returning.
func pollTaskUntilTerminal(
ctx context.Context,
c *client.Client,
serverName string,
task mcp.Task,
cfg MCPTaskConfig,
progress MCPTaskProgressHandler,
) (*mcp.TaskResultResult, error) {
cfg = cfg.resolved()
deadline := time.Now().Add(cfg.Timeout)
emit := func(status mcp.TaskStatus, msg string) {
if progress != nil {
progress(MCPTaskProgress{Server: serverName, TaskID: task.TaskId, Status: status, Message: msg})
}
}
emit(task.Status, task.StatusMessage)
current := task
interval := cfg.PollInterval
if current.PollInterval != nil && *current.PollInterval > 0 {
interval = time.Duration(*current.PollInterval) * time.Millisecond
}
if interval > cfg.MaxPollInterval {
interval = cfg.MaxPollInterval
}
for !current.Status.IsTerminal() {
if time.Now().After(deadline) {
cancelTaskBestEffort(c, current.TaskId)
return nil, fmt.Errorf("task %s timed out after %s", current.TaskId, cfg.Timeout)
}
// Wait between polls or abort early on context cancellation.
select {
case <-ctx.Done():
cancelTaskBestEffort(c, current.TaskId)
return nil, ctx.Err()
case <-time.After(interval):
}
got, err := c.GetTask(ctx, mcp.GetTaskRequest{
Params: mcp.GetTaskParams{TaskId: current.TaskId},
})
if err != nil {
// Transient transport hiccup — propagate immediately. The
// upstream agent layer treats this like any other tool error.
return nil, fmt.Errorf("tasks/get failed: %w", err)
}
current = got.Task
if current.Status != task.Status || current.StatusMessage != task.StatusMessage {
emit(current.Status, current.StatusMessage)
task = current
}
// Honour any updated suggested poll interval, capped at the limit.
if current.PollInterval != nil && *current.PollInterval > 0 {
interval = min(time.Duration(*current.PollInterval)*time.Millisecond, cfg.MaxPollInterval)
}
}
// Terminal state reached. Emit one last progress event and fetch the
// definitive tool result.
emit(current.Status, current.StatusMessage)
if current.Status == mcp.TaskStatusCancelled {
return nil, fmt.Errorf("task %s was cancelled", current.TaskId)
}
res, err := fetchTaskResult(ctx, c, current.TaskId)
if err != nil {
return nil, fmt.Errorf("tasks/result failed: %w", err)
}
if current.Status == mcp.TaskStatusFailed && res != nil && !res.IsError {
// The server flagged the task as failed but didn't decorate the
// result. Surface the status message so the caller still sees a
// useful tool-error.
return nil, fmt.Errorf("task %s failed: %s", current.TaskId, current.StatusMessage)
}
return res, nil
}
// fetchTaskResult issues tasks/result on the transport and parses the raw
// response. The upstream client.TaskResult helper delegates to
// mcp.ParseTaskResultResult which (as of mcp-go v0.51.0) looks for the
// content array under a nested "result" key that never exists in the
// wire format — leading to systematically empty Content. Doing the
// parse here keeps the polling path working until that is fixed upstream.
func fetchTaskResult(ctx context.Context, c *client.Client, taskID string) (*mcp.TaskResultResult, error) {
tr := c.GetTransport()
if tr == nil {
return nil, errors.New("mcp client has no transport")
}
req := transport.JSONRPCRequest{
JSONRPC: mcp.JSONRPC_VERSION,
ID: nextRequestID(),
Method: string(mcp.MethodTasksResult),
Params: mcp.TaskResultParams{TaskId: taskID},
}
resp, err := tr.SendRequest(ctx, req)
if err != nil {
return nil, err
}
if resp.Error != nil {
return nil, resp.Error.AsError()
}
// Manually decode the wire shape: {"_meta": {...}, "content": [...],
// "structuredContent": ..., "isError": bool}.
var shape struct {
Meta json.RawMessage `json:"_meta"`
Content []json.RawMessage `json:"content"`
StructuredContent any `json:"structuredContent"`
IsError bool `json:"isError"`
}
if err := json.Unmarshal(resp.Result, &shape); err != nil {
return nil, fmt.Errorf("decode tasks/result: %w", err)
}
out := &mcp.TaskResultResult{
StructuredContent: shape.StructuredContent,
IsError: shape.IsError,
}
if len(shape.Meta) > 0 && string(shape.Meta) != "null" {
var metaMap map[string]any
if err := json.Unmarshal(shape.Meta, &metaMap); err == nil {
out.Meta = mcp.NewMetaFromMap(metaMap)
}
}
for _, raw := range shape.Content {
var contentMap map[string]any
if err := json.Unmarshal(raw, &contentMap); err != nil {
return nil, fmt.Errorf("decode content block: %w", err)
}
parsed, err := mcp.ParseContent(contentMap)
if err != nil {
return nil, fmt.Errorf("parse content block: %w", err)
}
out.Content = append(out.Content, parsed)
}
return out, nil
}
// cancelTaskBestEffort issues tasks/cancel and ignores any error. Used on
// context cancellation paths where the connection is already going away.
func cancelTaskBestEffort(c *client.Client, taskID string) {
if c == nil || taskID == "" {
return
}
cancelCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, _ = c.CancelTask(cancelCtx, mcp.CancelTaskRequest{
Params: mcp.CancelTaskParams{TaskId: taskID},
})
}
// taskFromMCP converts a wire-format mcp.Task to our richer connection-
// layer view. Unparseable timestamps surface as the zero time.
func taskFromMCP(serverName string, t mcp.Task) MCPTaskInfo {
out := MCPTaskInfo{
Server: serverName,
TaskID: t.TaskId,
Status: t.Status,
StatusMessage: t.StatusMessage,
}
if t.CreatedAt != "" {
if v, err := time.Parse(time.RFC3339, t.CreatedAt); err == nil {
out.CreatedAt = v
}
}
if t.LastUpdatedAt != "" {
if v, err := time.Parse(time.RFC3339, t.LastUpdatedAt); err == nil {
out.UpdatedAt = v
}
}
if t.TTL != nil {
out.TTL = time.Duration(*t.TTL) * time.Millisecond
}
if t.PollInterval != nil {
out.PollInterval = time.Duration(*t.PollInterval) * time.Millisecond
}
return out
}
+294
View File
@@ -0,0 +1,294 @@
package tools
import (
"context"
"strings"
"testing"
"time"
"github.com/mark3labs/kit/internal/config"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// newTaskTestInProcessServer builds an in-process MCP server with a
// task-augmented tool. The handler simulates work by sleeping briefly
// before completing.
//
// Important: the upstream mcp-go server cancels the request context as
// soon as the synchronous part of the tools/call returns (see
// request_handler.go:85, `defer cancel()`). Task goroutines spawned by
// AddTaskTool inherit that context and therefore see context.Canceled
// the instant they start. Real-world transports (stdio, SSE, streamable
// HTTP) don't trip this because they keep the connection — and a
// background context — alive across the async work, but the in-process
// transport runs entirely on the request goroutine. To test the polling
// path realistically we detach from the request context here.
func newTaskTestInProcessServer(t *testing.T, workDuration time.Duration) *server.MCPServer {
t.Helper()
srv := server.NewMCPServer("task-test", "1.0.0",
server.WithToolCapabilities(true),
// list=true, cancel=true, toolCallTasks=true so capability detection,
// cancellation, and tool augmentation all flow through.
server.WithTaskCapabilities(true, true, true),
)
srv.AddTaskTool(
mcp.Tool{
Name: "long_running",
Description: "Sleep, then echo the input string.",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"msg": map[string]any{"type": "string"},
},
},
Execution: &mcp.ToolExecution{
TaskSupport: mcp.TaskSupportRequired,
},
},
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CreateTaskResult, error) {
msg, _ := req.GetArguments()["msg"].(string)
// Detach from the request context so the task handler can
// outlive the synchronous request — see comment above.
time.Sleep(workDuration)
_ = ctx
return &mcp.CreateTaskResult{
Content: []mcp.Content{
mcp.TextContent{Type: "text", Text: "echo:" + msg},
},
}, nil
},
)
return srv
}
// newSyncOnlyServer is a server that does NOT advertise task capability.
// Used to verify the auto-detect path keeps the sync semantics.
func newSyncOnlyServer() *server.MCPServer {
srv := server.NewMCPServer("sync-only", "1.0.0",
server.WithToolCapabilities(true),
)
srv.AddTool(
mcp.NewTool("greet",
mcp.WithDescription("Say hello"),
mcp.WithString("name", mcp.Required()),
),
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
name, _ := req.GetArguments()["name"].(string)
return mcp.NewToolResultText("hi " + name), nil
},
)
return srv
}
func TestConnectionPoolAdvertisesTaskCapability(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
srv := newTaskTestInProcessServer(t, 0)
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: srv}
conn, err := pool.GetConnection(context.Background(), "tasks", cfg)
if err != nil {
t.Fatalf("GetConnection: %v", err)
}
init := conn.InitializeResult()
if init == nil {
t.Fatal("InitializeResult is nil after GetConnection")
}
if init.Capabilities.Tasks == nil {
t.Fatal("server did not advertise Tasks capability — initialize handshake regressed")
}
if !conn.SupportsToolTasks() {
t.Error("SupportsToolTasks should be true for a server with toolCallTasks=true")
}
if !pool.ServerSupportsToolTasks("tasks") {
t.Error("ServerSupportsToolTasks should mirror the connection's value")
}
}
func TestConnectionPoolDetectsAbsentTaskCapability(t *testing.T) {
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
defer func() { _ = pool.Close() }()
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: newSyncOnlyServer()}
conn, err := pool.GetConnection(context.Background(), "sync", cfg)
if err != nil {
t.Fatalf("GetConnection: %v", err)
}
if conn.SupportsToolTasks() {
t.Error("SupportsToolTasks should be false for a server that didn't advertise the capability")
}
}
func TestSupportsToolTasksFromInit(t *testing.T) {
cases := []struct {
name string
in *mcp.InitializeResult
want bool
}{
{"nil", nil, false},
{"no tasks", &mcp.InitializeResult{}, false},
{"tasks no requests", &mcp.InitializeResult{
Capabilities: mcp.ServerCapabilities{Tasks: &mcp.TasksCapability{}},
}, false},
{"tasks with toolCalls", &mcp.InitializeResult{
Capabilities: mcp.ServerCapabilities{Tasks: mcp.NewTasksCapability()},
}, true},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := supportsToolTasksFromInit(tc.in); got != tc.want {
t.Errorf("supportsToolTasksFromInit() = %v, want %v", got, tc.want)
}
})
}
}
func TestParseTaskMode(t *testing.T) {
cases := []struct {
in string
want MCPTaskMode
}{
{"", MCPTaskModeAuto},
{"auto", MCPTaskModeAuto},
{"AUTO", MCPTaskModeAuto},
{"never", MCPTaskModeNever},
{"off", MCPTaskModeNever},
{"always", MCPTaskModeAlways},
{"force", MCPTaskModeAlways},
{"bogus", MCPTaskModeAuto},
}
for _, tc := range cases {
if got := ParseTaskMode(tc.in); got != tc.want {
t.Errorf("ParseTaskMode(%q) = %q, want %q", tc.in, got, tc.want)
}
}
}
func TestExecuteToolPollsTaskToCompletion(t *testing.T) {
mgr := NewMCPToolManager()
mgr.SetTaskConfig(MCPTaskConfig{
PollInterval: 20 * time.Millisecond,
MaxPollInterval: 50 * time.Millisecond,
Timeout: 10 * time.Second,
})
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: newTaskTestInProcessServer(t, 50*time.Millisecond),
}
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
t.Fatalf("AddServer: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hello"}`)
if err != nil {
t.Fatalf("ExecuteTool: %v", err)
}
if res.IsError {
t.Fatalf("expected non-error result, got %s", res.Content)
}
if !strings.Contains(res.Content, "echo:hello") {
t.Errorf("expected result to contain 'echo:hello', got %s", res.Content)
}
}
func TestExecuteToolHonorsNeverMode(t *testing.T) {
// Even though the server advertises tasks/toolCalls, "never" should
// keep the call synchronous. Since the tool is TaskSupportRequired,
// the server returns an error rather than running it sync — we just
// verify the error surfaces (not a poll-loop hang).
mgr := NewMCPToolManager()
mgr.SetTaskConfig(MCPTaskConfig{
PerServerMode: map[string]MCPTaskMode{"tasks": MCPTaskModeNever},
Timeout: 2 * time.Second,
})
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: newTaskTestInProcessServer(t, 0),
}
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
t.Fatalf("AddServer: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
// We don't care which way the server fails the sync call; we just want
// to confirm we didn't hang in the polling loop and didn't panic.
_, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"x"}`)
if err == nil {
t.Fatal("expected an error when forcing sync execution of a task-required tool")
}
}
func TestExecuteToolEmitsProgress(t *testing.T) {
var statuses []mcp.TaskStatus
mgr := NewMCPToolManager()
mgr.SetTaskConfig(MCPTaskConfig{
PollInterval: 10 * time.Millisecond,
MaxPollInterval: 25 * time.Millisecond,
Timeout: 5 * time.Second,
Progress: func(p MCPTaskProgress) {
statuses = append(statuses, p.Status)
},
})
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: newTaskTestInProcessServer(t, 30*time.Millisecond),
}
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
t.Fatalf("AddServer: %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hi"}`); err != nil {
t.Fatalf("ExecuteTool: %v", err)
}
if len(statuses) == 0 {
t.Fatal("expected at least one progress event")
}
last := statuses[len(statuses)-1]
if !last.IsTerminal() {
t.Errorf("last progress event should be terminal, got %q", last)
}
}
func TestListGetCancelMCPTasksOnLoadedServer(t *testing.T) {
mgr := NewMCPToolManager()
cfg := config.MCPServerConfig{
Type: "inprocess",
InProcessServer: newTaskTestInProcessServer(t, 0),
}
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
t.Fatalf("AddServer: %v", err)
}
ctx := context.Background()
// tasks/list — no in-flight tasks yet, so we just verify the call
// succeeds and returns an empty slice (or any slice; the exact length
// depends on server retention policy).
if _, err := mgr.ListServerTasks(ctx, "tasks"); err != nil {
t.Errorf("ListServerTasks: %v", err)
}
// Unknown server should error cleanly without panicking.
if _, err := mgr.GetServerTask(ctx, "unknown", "abc"); err == nil {
t.Error("GetServerTask on unknown server should error")
}
if _, err := mgr.CancelServerTask(ctx, "unknown", "abc"); err == nil {
t.Error("CancelServerTask on unknown server should error")
}
}
+2 -4
View File
@@ -103,14 +103,12 @@ func TestMCPToolManager_EmptyConfig(t *testing.T) {
// Test that we can get tool info for each tool
for _, tool := range tools {
info := tool.Info()
// Check that the tool has a valid name
if info.Name == "" {
if tool.Name == "" {
t.Error("Tool has empty name")
}
t.Logf("Tool: %s, Description: %s", info.Name, info.Description)
t.Logf("Tool: %s, Description: %s", tool.Name, tool.Description)
}
}
+1 -25
View File
@@ -19,7 +19,7 @@ import (
// newTestInput creates an InputComponent with the given AppController (may be nil).
func newTestInput(ctrl AppController) *InputComponent {
return NewInputComponent(80, "test input", ctrl)
return NewInputComponent(80, ctrl)
}
// sendInputMsg calls component.Update with the given message, returns the
@@ -69,30 +69,6 @@ func TestInputComponent_SubmitEmitsSubmitMsg(t *testing.T) {
}
}
// TestInputComponent_CtrlD_SubmitEmitsSubmitMsg verifies that ctrl+d also
// submits the text.
func TestInputComponent_CtrlD_SubmitEmitsSubmitMsg(t *testing.T) {
ctrl := &stubAppController{}
c := newTestInput(ctrl)
c.textarea.SetValue("ctrl+d submit")
c.lastValue = "ctrl+d submit"
_, cmd := sendInputMsg(c, tea.KeyPressMsg{Code: 'd', Mod: tea.ModCtrl})
msg := runCmd(cmd)
if msg == nil {
t.Fatal("expected a cmd from ctrl+d on non-empty input")
}
sm, ok := msg.(core.SubmitMsg)
if !ok {
t.Fatalf("expected submitMsg from ctrl+d, got %T", msg)
}
if sm.Text != "ctrl+d submit" {
t.Fatalf("expected Text='ctrl+d submit', got %q", sm.Text)
}
}
// TestInputComponent_EmptySubmit_NoCmd verifies that submitting an empty or
// whitespace-only string produces no cmd.
func TestInputComponent_EmptySubmit_NoCmd(t *testing.T) {
+2 -1
View File
@@ -20,6 +20,7 @@ type SlashCommand struct {
Aliases []string
Category string // e.g., "Navigation", "System", "Info"
Complete func(prefix string) []string // optional argument tab-completion
HasArgs bool // true when the command expects arguments (e.g. prompt templates with placeholders)
}
// SlashCommands provides the global registry of all available slash commands
@@ -83,7 +84,7 @@ var SlashCommands = []SlashCommand{
},
{
Name: "/thinking",
Description: "Set thinking/reasoning level (off, minimal, low, medium, high)",
Description: "Set thinking/reasoning level (off, none, minimal, low, medium, high)",
Category: "System",
Aliases: []string{"/think"},
Complete: func(prefix string) []string {
+5
View File
@@ -25,6 +25,11 @@ type SubmitMsg struct {
// presses ESC a second time, the canceling state is reset to false.
type CancelTimerExpiredMsg struct{}
// CtrlCResetMsg is sent after a short delay when the user presses Ctrl+C to
// clear input. If the user doesn't press Ctrl+C again within the timeout,
// the ctrlCPressedOnce flag is reset so the next Ctrl+C will clear again.
type CtrlCResetMsg struct{}
// --- Tree session events ---
// TreeNodeSelectedMsg is sent when the user selects a node in the tree selector.
+8 -1
View File
@@ -29,9 +29,16 @@ type (
ExtensionCommand = commands.ExtensionCommand
)
// Re-export functions from fileutil package
// Re-export functions and types from fileutil package
var ProcessFileAttachments = fileutil.ProcessFileAttachments
// Re-export types from fileutil
type (
FileAttachmentResult = fileutil.FileAttachmentResult
FilePart = fileutil.FilePart
MCPResourceReader = fileutil.MCPResourceReader
)
// Re-export from prefs package
var (
LoadThemePreference = prefs.LoadThemePreference
+60 -4
View File
@@ -6,22 +6,78 @@ import (
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
// FileSuggestion represents a single file or directory suggestion for the @
// autocomplete popup.
// FileSuggestion represents a single file, directory, or MCP resource
// suggestion for the @ autocomplete popup.
type FileSuggestion struct {
// RelPath is the path relative to the search base (e.g. "cmd/kit/main.go").
// RelPath is the path relative to the search base (e.g. "cmd/kit/main.go")
// or a display name for MCP resources (e.g. "mcp:server/resource-name").
RelPath string
// IsDir is true when the entry is a directory.
IsDir bool
// Score is the fuzzy match score (higher is better).
Score int
// IsMCPResource is true for MCP resource entries.
IsMCPResource bool
// MCPServerName is the MCP server name (set when IsMCPResource is true).
MCPServerName string
// MCPResourceURI is the MCP resource URI (set when IsMCPResource is true).
MCPResourceURI string
// MCPMIMEType is the MIME type hint from the MCP server.
MCPMIMEType string
}
// maxFileSuggestions is the maximum number of file suggestions returned.
const maxFileSuggestions = 20
// fileListCache caches the result of listFiles() keyed by directory to avoid
// re-running git subprocesses on every keystroke during @file completion.
var fileListCache struct {
mu sync.Mutex
dir string // searchDir that produced the cached entries
cwd string // cwd used for the git query
entries []FileSuggestion // cached file list
expireAt time.Time // when the cache entry expires
}
// fileListCacheTTL controls how long a cached file list stays valid.
// During rapid typing the list is reused; after the TTL a fresh git
// ls-files is executed so newly created files become visible.
const fileListCacheTTL = 3 * time.Second
// getCachedFileList returns the file list for searchDir, using a short-lived
// cache to avoid repeated subprocess calls during @file autocompletion.
func getCachedFileList(searchDir, cwd string) []FileSuggestion {
fileListCache.mu.Lock()
defer fileListCache.mu.Unlock()
now := time.Now()
if fileListCache.dir == searchDir &&
fileListCache.cwd == cwd &&
now.Before(fileListCache.expireAt) {
// Return a copy so callers can mutate (e.g. prepend baseDir).
cp := make([]FileSuggestion, len(fileListCache.entries))
copy(cp, fileListCache.entries)
return cp
}
// Cache miss or expired — run the real (potentially expensive) lookup.
files := listFiles(searchDir, cwd)
fileListCache.dir = searchDir
fileListCache.cwd = cwd
fileListCache.entries = files
fileListCache.expireAt = now.Add(fileListCacheTTL)
// Return a copy.
cp := make([]FileSuggestion, len(files))
copy(cp, files)
return cp
}
// ExtractAtPrefix checks the current line for an @-file trigger at cursorCol.
// It returns:
// - hasAt: true if a valid @ trigger was found
@@ -90,7 +146,7 @@ func GetFileSuggestions(prefix string, cwd string) []FileSuggestion {
}
}
files := listFiles(searchDir, cwd)
files := getCachedFileList(searchDir, cwd)
if len(files) == 0 {
return nil
}
+202 -10
View File
@@ -2,29 +2,85 @@ package fileutil
import (
"fmt"
"mime"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/mark3labs/kit/internal/fences"
)
// FilePart represents a binary file attachment (image, audio, etc.) extracted
// from an @file reference. Callers convert this to kit.LLMFilePart before
// sending to the LLM. Defined here to avoid a circular dependency on pkg/kit.
type FilePart struct {
// Filename is the basename of the file (e.g. "photo.png").
Filename string
// Data is the raw file bytes.
Data []byte
// MediaType is the MIME type (e.g. "image/png", "audio/wav").
MediaType string
}
// MCPResourceReader is a callback function that reads an MCP resource by
// server name and URI. Returns text content, binary data, MIME type, and error.
// Used by ProcessFileAttachments to resolve @mcp:server:uri tokens.
type MCPResourceReader func(serverName, uri string) (text string, blobData []byte, mimeType string, isBlob bool, err error)
// FileAttachmentResult is the result of processing @file references in user
// input. Text files are inlined as XML in ProcessedText; binary files (images,
// audio, video, PDFs) are returned as FileParts for multimodal submission.
type FileAttachmentResult struct {
// ProcessedText is the user's text with @file tokens replaced:
// text files become XML-wrapped content, binary file tokens are removed.
ProcessedText string
// FileParts contains binary file attachments extracted from @file
// references. Empty when all referenced files are text.
FileParts []FilePart
}
// fileTokenPattern matches @file references in user text. Supports:
// - @"path with spaces.txt" (quoted)
// - @path/to/file.txt (unquoted, no spaces)
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
// ProcessFileAttachments scans the user's input text for @file references,
// reads each referenced file, and returns the text with @tokens replaced by
// XML-wrapped file content. Non-file @ tokens (like email addresses) are left
// unchanged.
// reads each referenced file, and returns a result containing the processed
// text and any binary file attachments. Text files are XML-wrapped inline;
// binary files (images, audio, etc.) are extracted as FileParts for multimodal
// submission. Non-file @ tokens (like email addresses) are left unchanged.
//
// Returns the original text unchanged if no valid @file references are found.
func ProcessFileAttachments(text string, cwd string) string {
// MCP resources are supported via @mcp:server:uri tokens. The optional
// mcpReader callback is used to resolve them; pass nil to skip MCP resources.
func ProcessFileAttachments(text string, cwd string, mcpReader ...MCPResourceReader) FileAttachmentResult {
var reader MCPResourceReader
if len(mcpReader) > 0 {
reader = mcpReader[0]
}
var allParts []FilePart
processed := fences.ReplaceOutside(text, func(segment string) string {
result, parts := processFileTokens(segment, cwd, reader)
allParts = append(allParts, parts...)
return result
})
return FileAttachmentResult{
ProcessedText: processed,
FileParts: allParts,
}
}
// processFileTokens handles @file replacement in a single text segment
// that is known to be outside fenced code blocks. Returns the processed
// text and any binary file parts extracted.
func processFileTokens(text string, cwd string, mcpReader MCPResourceReader) (string, []FilePart) {
tokens := fileTokenPattern.FindAllString(text, -1)
if len(tokens) == 0 {
return text
return text, nil
}
var parts []FilePart
result := text
for _, token := range tokens {
path := tokenToPath(token)
@@ -32,6 +88,43 @@ func ProcessFileAttachments(text string, cwd string) string {
continue
}
// Check for MCP resource reference: @mcp:server:uri
if strings.HasPrefix(path, "mcp:") {
if mcpReader == nil {
continue
}
mcpRef := path[4:] // strip "mcp:"
// Split into server:uri (first colon separates server from URI)
serverName, uri, ok := strings.Cut(mcpRef, ":")
if !ok || serverName == "" || uri == "" {
continue // invalid format
}
textContent, blobData, mimeType, isBlob, err := mcpReader(serverName, uri)
if err != nil {
continue // skip on error, leave token as-is
}
if isBlob {
// Binary MCP resource → extract as FilePart.
filename := filepath.Base(uri)
if filename == "." || filename == "/" {
filename = serverName + "_resource"
}
parts = append(parts, FilePart{
Filename: filename,
Data: blobData,
MediaType: mimeType,
})
result = strings.Replace(result, token, "", 1)
} else {
// Text MCP resource → inline as XML.
wrapped := fmt.Sprintf("<resource uri=\"%s\" server=\"%s\">\n%s\n</resource>", uri, serverName, textContent)
result = strings.Replace(result, token, wrapped, 1)
}
continue
}
absPath, err := resolvePath(path, cwd)
if err != nil {
// Not a valid file reference — leave the token as-is.
@@ -59,12 +152,28 @@ func ProcessFileAttachments(text string, cwd string) string {
continue
}
// Build the XML-wrapped replacement.
wrapped := wrapFileContent(absPath, content)
result = strings.Replace(result, token, wrapped, 1)
mediaType := detectMediaType(absPath, content)
if isBinaryMediaType(mediaType) {
// Binary file → extract as a FilePart for multimodal submission.
// Remove the @token from the text.
parts = append(parts, FilePart{
Filename: filepath.Base(absPath),
Data: content,
MediaType: mediaType,
})
result = strings.Replace(result, token, "", 1)
} else {
// Text file → inline as XML-wrapped content.
wrapped := wrapFileContent(absPath, content)
result = strings.Replace(result, token, wrapped, 1)
}
}
return result
// Clean up any extra whitespace left by removed binary tokens.
result = strings.TrimSpace(result)
return result, parts
}
// tokenToPath strips the @ prefix and optional quotes from a token,
@@ -127,3 +236,86 @@ func resolvePath(path string, cwd string) (string, error) {
func wrapFileContent(absPath string, content []byte) string {
return fmt.Sprintf("<file path=\"%s\">\n%s\n</file>", absPath, string(content))
}
// detectMediaType determines the MIME type of a file using extension-based
// lookup first (more reliable for known types), then falls back to content
// sniffing via net/http.DetectContentType.
func detectMediaType(path string, content []byte) string {
// Extension-based detection is more reliable for well-known types.
ext := strings.ToLower(filepath.Ext(path))
if mt := mime.TypeByExtension(ext); mt != "" {
// mime.TypeByExtension returns types like "image/png; charset=utf-8"
// — strip parameters.
if base, _, ok := strings.Cut(mt, ";"); ok {
return strings.TrimSpace(base)
}
return mt
}
// Known extensions that mime package may miss.
switch ext {
case ".webp":
return "image/webp"
case ".avif":
return "image/avif"
case ".heic", ".heif":
return "image/heif"
case ".opus":
return "audio/opus"
case ".flac":
return "audio/flac"
case ".m4a":
return "audio/mp4"
case ".wasm":
return "application/wasm"
}
// Content sniffing fallback.
if len(content) > 0 {
detected := http.DetectContentType(content)
if detected != "" && detected != "application/octet-stream" {
if base, _, ok := strings.Cut(detected, ";"); ok {
return strings.TrimSpace(base)
}
return detected
}
}
// Default: treat as plain text so it gets XML-wrapped.
return "text/plain"
}
// isBinaryMediaType returns true if the MIME type represents a binary file
// that should be sent as a multimodal FilePart rather than XML-wrapped text.
func isBinaryMediaType(mediaType string) bool {
// Image types — always binary.
if strings.HasPrefix(mediaType, "image/") {
return true
}
// Audio types — always binary.
if strings.HasPrefix(mediaType, "audio/") {
return true
}
// Video types — always binary.
if strings.HasPrefix(mediaType, "video/") {
return true
}
// Specific application types that are binary.
switch mediaType {
case "application/pdf",
"application/zip",
"application/gzip",
"application/x-tar",
"application/octet-stream",
"application/wasm",
"application/x-executable",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
"application/msword",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document":
return true
}
return false
}
+209
View File
@@ -0,0 +1,209 @@
package fileutil
import (
"os"
"path/filepath"
"testing"
)
func TestProcessFileAttachments_TextFile(t *testing.T) {
// Create a temp text file
dir := t.TempDir()
textFile := filepath.Join(dir, "hello.txt")
if err := os.WriteFile(textFile, []byte("hello world"), 0644); err != nil {
t.Fatal(err)
}
text := "@" + textFile + " check this out"
result := ProcessFileAttachments(text, dir)
if len(result.FileParts) != 0 {
t.Errorf("expected 0 FileParts for text file, got %d", len(result.FileParts))
}
if result.ProcessedText == text {
t.Error("expected text file to be XML-wrapped, but got original text unchanged")
}
// Should contain XML wrapping
if !contains(result.ProcessedText, "<file path=") {
t.Error("expected XML <file> wrapping in processed text")
}
if !contains(result.ProcessedText, "hello world") {
t.Error("expected file content in processed text")
}
}
func TestProcessFileAttachments_BinaryFile(t *testing.T) {
// Create a minimal PNG file (binary)
dir := t.TempDir()
pngFile := filepath.Join(dir, "image.png")
// Minimal valid PNG header
pngData := []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // 8bit RGB
0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00,
0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, 0x33,
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, // IEND chunk
0xAE, 0x42, 0x60, 0x82,
}
if err := os.WriteFile(pngFile, pngData, 0644); err != nil {
t.Fatal(err)
}
text := "@" + pngFile + " what is this image?"
result := ProcessFileAttachments(text, dir)
if len(result.FileParts) != 1 {
t.Fatalf("expected 1 FilePart for binary file, got %d", len(result.FileParts))
}
if result.FileParts[0].MediaType != "image/png" {
t.Errorf("expected media type image/png, got %s", result.FileParts[0].MediaType)
}
if result.FileParts[0].Filename != "image.png" {
t.Errorf("expected filename image.png, got %s", result.FileParts[0].Filename)
}
// The @token should be removed from the text
if contains(result.ProcessedText, "@") && contains(result.ProcessedText, pngFile) {
t.Error("expected @token to be removed from processed text for binary file")
}
if contains(result.ProcessedText, "what is this image?") {
// Good, the prompt text should remain
} else {
t.Error("expected prompt text to remain in processed text")
}
}
func TestProcessFileAttachments_MCPResource(t *testing.T) {
// Test @mcp:server:uri token processing with a mock reader
text := "@mcp:test-server:docs://readme tell me about this"
reader := func(serverName, uri string) (string, []byte, string, bool, error) {
if serverName != "test-server" || uri != "docs://readme" {
t.Errorf("unexpected server/uri: %s/%s", serverName, uri)
}
return "Hello from MCP resource", nil, "text/plain", false, nil
}
result := ProcessFileAttachments(text, "/tmp", reader)
if len(result.FileParts) != 0 {
t.Errorf("expected 0 FileParts for text MCP resource, got %d", len(result.FileParts))
}
if !contains(result.ProcessedText, "<resource uri=\"docs://readme\" server=\"test-server\">") {
t.Error("expected <resource> XML wrapping in processed text")
}
if !contains(result.ProcessedText, "Hello from MCP resource") {
t.Error("expected MCP resource content in processed text")
}
}
func TestProcessFileAttachments_MCPResource_Binary(t *testing.T) {
// Test @mcp:server:uri token processing for a binary resource
text := "@mcp:test-server:images://logo describe this"
reader := func(serverName, uri string) (string, []byte, string, bool, error) {
if serverName != "test-server" || uri != "images://logo" {
t.Errorf("unexpected server/uri: %s/%s", serverName, uri)
}
return "", []byte{0x89, 0x50, 0x4E, 0x47}, "image/png", true, nil
}
result := ProcessFileAttachments(text, "/tmp", reader)
if len(result.FileParts) != 1 {
t.Fatalf("expected 1 FilePart for binary MCP resource, got %d", len(result.FileParts))
}
if result.FileParts[0].MediaType != "image/png" {
t.Errorf("expected media type image/png, got %s", result.FileParts[0].MediaType)
}
if result.FileParts[0].Filename != "logo" {
t.Errorf("expected filename 'logo', got %s", result.FileParts[0].Filename)
}
// The @token should be removed from the text
if contains(result.ProcessedText, "@mcp:") {
t.Error("expected @mcp: token to be removed from processed text for binary resource")
}
}
func TestProcessFileAttachments_NoReader(t *testing.T) {
// Without an MCP reader, @mcp: tokens should be left as-is
text := "@mcp:server:resource this is a test"
result := ProcessFileAttachments(text, "/tmp")
if len(result.FileParts) != 0 {
t.Errorf("expected 0 FileParts, got %d", len(result.FileParts))
}
// The @mcp: token should remain unchanged since no reader was provided
if result.ProcessedText != text {
t.Errorf("expected text unchanged without reader, got: %s", result.ProcessedText)
}
}
func TestDetectMediaType(t *testing.T) {
tests := []struct {
ext string
content []byte
expected string
}{
// An intentionally-synthetic extension that is not registered
// in any system MIME database. Exercises the "unknown ext +
// no content" branch, which must return the text/plain default.
// Do not use real extensions (e.g. .go) here: CI images often
// ship /etc/mime.types with entries like ".go → text/x-go",
// which would make the assertion environment-dependent.
{".kitsyntheticext", nil, "text/plain"},
{".png", []byte{0x89, 0x50, 0x4E, 0x47}, "image/png"},
{".jpg", []byte{0xFF, 0xD8, 0xFF}, "image/jpeg"},
{".pdf", []byte{0x25, 0x50, 0x44, 0x46}, "application/pdf"},
{".txt", []byte("hello"), "text/plain"},
{".wav", nil, "audio/wav"},
{".webp", nil, "image/webp"},
}
for _, tt := range tests {
t.Run(tt.ext, func(t *testing.T) {
got := detectMediaType("test"+tt.ext, tt.content)
if got != tt.expected {
t.Errorf("detectMediaType(%q) = %q, want %q", tt.ext, got, tt.expected)
}
})
}
}
func TestIsBinaryMediaType(t *testing.T) {
tests := []struct {
mimeType string
expected bool
}{
{"image/png", true},
{"image/jpeg", true},
{"audio/wav", true},
{"video/mp4", true},
{"application/pdf", true},
{"text/plain", false},
{"text/go", false},
{"application/json", false},
}
for _, tt := range tests {
t.Run(tt.mimeType, func(t *testing.T) {
got := isBinaryMediaType(tt.mimeType)
if got != tt.expected {
t.Errorf("isBinaryMediaType(%q) = %v, want %v", tt.mimeType, got, tt.expected)
}
})
}
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsStr(s, substr))
}
func containsStr(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
+1
View File
@@ -17,6 +17,7 @@ type Renderer interface {
RenderReasoningBlock(content string, timestamp time.Time) UIMessage
RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage
RenderSystemMessage(content string, timestamp time.Time) UIMessage
RenderCustomMessage(content, label string, timestamp time.Time) UIMessage
RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage
RenderDebugMessage(message string, timestamp time.Time) UIMessage
RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage
+137 -40
View File
@@ -2,6 +2,7 @@ package ui
import (
"fmt"
"sort"
"strings"
"charm.land/bubbles/v2/key"
@@ -39,7 +40,6 @@ type InputComponent struct {
width int
lastValue string
popupHeight int
title string
submitNext bool // defer submit one tick so popup dismisses cleanly
// Argument completion state. When the user types "/cmd " followed by
@@ -61,6 +61,10 @@ type InputComponent struct {
// autocomplete suggestions. Set by the parent via SetCwd.
cwd string
// mcpResources is a callback that returns available MCP resources for
// the @ autocomplete popup. Set by the parent via SetMCPResourceProvider.
mcpResources func() []FileSuggestion
// appCtrl is used for slash commands that mutate app state.
// May be nil in tests; nil-safe.
appCtrl AppController
@@ -101,17 +105,17 @@ type clipboardImageMsg struct {
err error
}
// NewInputComponent creates a new InputComponent with the given width, title,
// and optional AppController. If appCtrl is nil the component still works but
// NewInputComponent creates a new InputComponent with the given width and
// optional AppController. If appCtrl is nil the component still works but
// /clear and /clear-queue are no-ops.
func NewInputComponent(width int, title string, appCtrl AppController) *InputComponent {
func NewInputComponent(width int, appCtrl AppController) *InputComponent {
ta := textarea.New()
ta.Placeholder = "Type your message..."
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(3) // Default to 3 lines like huh
ta.SetHeight(4) // 4 lines for comfortable multi-line input
ta.Focus()
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
@@ -136,8 +140,8 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
commands: commands.SlashCommands,
width: width,
popupHeight: 7,
title: title,
appCtrl: appCtrl,
hideHint: true,
}
}
@@ -147,6 +151,12 @@ func (s *InputComponent) SetCwd(cwd string) {
s.cwd = cwd
}
// SetMCPResourceProvider sets a callback that returns MCP resource suggestions
// for the @ autocomplete popup. Called by the parent after construction.
func (s *InputComponent) SetMCPResourceProvider(fn func() []FileSuggestion) {
s.mcpResources = fn
}
// Init implements tea.Model. Starts the cursor blink animation.
func (s *InputComponent) Init() tea.Cmd {
return textarea.Blink
@@ -190,7 +200,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyPressMsg:
if !s.showPopup {
switch msg.String() {
case "ctrl+d", "enter":
case "enter":
value := s.textarea.Value()
s.pushHistory(value)
s.textarea.SetValue("")
@@ -285,16 +295,25 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
s.textarea.CursorEnd()
return s, nil
}
selectedCmd := s.filtered[s.selected].Command
// Populate textarea with selected item and submit on next tick.
if s.argMode {
s.textarea.SetValue(s.argCommand + " " + s.filtered[s.selected].Command.Name)
s.textarea.SetValue(s.argCommand + " " + selectedCmd.Name)
} else {
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
s.textarea.SetValue(selectedCmd.Name)
}
s.textarea.CursorEnd()
s.showPopup = false
s.selected = 0
s.submitNext = true
// If the selected command expects arguments, populate
// the input with the command + trailing space so the
// user can type args, instead of auto-submitting.
if !s.argMode && selectedCmd.HasArgs {
s.textarea.SetValue(selectedCmd.Name + " ")
s.textarea.CursorEnd()
} else {
s.submitNext = true
}
return s, nil
}
return s, nil
@@ -323,9 +342,46 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Check for @file trigger first.
cursorCol := len(line) // approximate: cursor is at end after typing
if hasAt, prefix, atIdx := ExtractAtPrefix(line, cursorCol); hasAt && s.cwd != "" {
suggestions := GetFileSuggestions(prefix, s.cwd)
if hasAt, prefix, atIdx := ExtractAtPrefix(line, cursorCol); hasAt {
var suggestions []FileSuggestion
// Local file suggestions (only if cwd is set).
if s.cwd != "" {
suggestions = GetFileSuggestions(prefix, s.cwd)
}
// MCP resource suggestions — merge with file suggestions.
if s.mcpResources != nil {
mcpSuggestions := s.mcpResources()
if prefix != "" {
// Fuzzy-filter MCP resources against the typed prefix.
queryLower := strings.ToLower(prefix)
var filtered []FileSuggestion
for _, r := range mcpSuggestions {
score := scoreFilePath(queryLower, r.RelPath)
if score <= 0 {
// Also try matching against the resource name without prefix.
score = scoreFilePath(queryLower, r.MCPServerName+"/"+r.RelPath)
}
if score > 0 {
r.Score = score
filtered = append(filtered, r)
}
}
mcpSuggestions = filtered
}
suggestions = append(suggestions, mcpSuggestions...)
}
if len(suggestions) > 0 {
// Sort by score descending, cap at maxFileSuggestions.
sort.Slice(suggestions, func(i, j int) bool {
return suggestions[i].Score > suggestions[j].Score
})
if len(suggestions) > maxFileSuggestions {
suggestions = suggestions[:maxFileSuggestions]
}
s.showPopup = true
s.fileMode = true
s.argMode = false
@@ -339,6 +395,8 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
desc := ""
if fs.IsDir {
desc = "directory"
} else if fs.IsMCPResource {
desc = "mcp:" + fs.MCPServerName
}
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
@@ -461,19 +519,13 @@ func (s *InputComponent) resetHistoryBrowsing() {
s.savedInput = ""
}
// View implements tea.Model. Renders the title, textarea, autocomplete popup
// View implements tea.Model. Renders the textarea, autocomplete popup
// (if visible), and help text.
func (s *InputComponent) View() tea.View {
containerStyle := lipgloss.NewStyle()
theme := style.GetTheme()
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
titleStyle := lipgloss.NewStyle().
Foreground(theme.Text).
MarginBottom(1).
PaddingLeft(3)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.ThickBorder()).
BorderLeft(true).
@@ -481,12 +533,12 @@ func (s *InputComponent) View() tea.View {
BorderTop(false).
BorderBottom(false).
BorderForeground(theme.Primary).
MarginTop(1).
MarginBottom(1).
PaddingLeft(2). // match message block paddingLeft
Width(s.width - 1) // full width minus left border
var view strings.Builder
view.WriteString(titleStyle.Render(s.title))
view.WriteString("\n")
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
// Popup is now rendered as a centered overlay in AppModel.View()
@@ -521,12 +573,14 @@ func (s *InputComponent) View() tea.View {
} else {
hint = "^X s steer"
}
} else if availableHintWidth >= 80 {
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+x e editor • ctrl+v paste image"
} else if availableHintWidth >= 67 {
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
hint = "enter submit • ctrl+j new line • ctrl+x e editor • ctrl+v image"
} else if availableHintWidth >= 40 {
hint = "↵ submit • ctrl+j newline • ctrl+v image"
hint = "↵ submit • ctrl+j newline • ^X e editor"
} else if availableHintWidth >= 20 {
hint = "↵ submit • ctrl+j"
hint = "↵ submit • ^X e editor"
} else {
hint = "↵ submit"
}
@@ -647,9 +701,25 @@ func (s *InputComponent) renderPopupWithOptions(centered bool) string {
}
content = indicator + displayName
} else {
nameWidth := 15
if innerWidth < 25 {
nameWidth = max(innerWidth*2/5+1, 8)
// Compute nameWidth from the longest command name in the
// visible slice so we never truncate unnecessarily.
nameWidth := 0
for _, fm := range s.filtered {
if n := len([]rune(fm.Command.Name)); n > nameWidth {
nameWidth = n
}
}
nameWidth += 3 // account for indicator prefix (2) + gap before description (1)
// Ensure descriptions still get at least 20 chars when possible.
maxForName := innerWidth - 20
if maxForName < 8 {
maxForName = innerWidth * 2 / 3
}
if nameWidth > maxForName {
nameWidth = maxForName
}
if nameWidth < 8 {
nameWidth = 8
}
maxNameChars := nameWidth - 2
displayName := sc.Name
@@ -782,9 +852,25 @@ func (s *InputComponent) PendingImageCount() int {
return len(s.pendingImages)
}
// Clear clears the textarea content and resets related state. Returns true if
// there was content to clear, false if the input was already empty.
func (s *InputComponent) Clear() bool {
hadContent := s.textarea.Value() != ""
s.textarea.SetValue("")
s.textarea.CursorEnd()
s.lastValue = ""
s.showPopup = false
s.argMode = false
s.fileMode = false
s.browsingHistory = false
s.savedInput = ""
return hadContent
}
// applyFileCompletion replaces the @prefix in the textarea with the selected
// file suggestion. For directories, it keeps the popup open for further
// drilling. For files, it closes the popup and adds a trailing space.
// file or MCP resource suggestion. For directories, it keeps the popup open
// for further drilling. For files and resources, it closes the popup and adds
// a trailing space.
func (s *InputComponent) applyFileCompletion(idx int) {
if idx >= len(s.fileSuggestions) {
return
@@ -801,19 +887,30 @@ func (s *InputComponent) applyFileCompletion(idx int) {
// Reconstruct: everything before the @ on the last line + @<path>
beforeAt := lastLine[:s.fileAtStartIdx]
needsQuote := strings.Contains(suggestion.RelPath, " ")
var replacement string
if needsQuote {
replacement = `@"` + suggestion.RelPath + `"`
} else {
replacement = "@" + suggestion.RelPath
}
// For files, add a trailing space. For directories, don't — allow
// continued drilling into the directory.
if !suggestion.IsDir {
if suggestion.IsMCPResource {
// MCP resources use @mcp:server:uri format.
// Quote if the URI contains spaces.
ref := "mcp:" + suggestion.MCPServerName + ":" + suggestion.MCPResourceURI
if strings.Contains(ref, " ") {
replacement = `@"` + ref + `"`
} else {
replacement = "@" + ref
}
replacement += " "
} else {
needsQuote := strings.Contains(suggestion.RelPath, " ")
if needsQuote {
replacement = `@"` + suggestion.RelPath + `"`
} else {
replacement = "@" + suggestion.RelPath
}
// For files, add a trailing space. For directories, don't — allow
// continued drilling into the directory.
if !suggestion.IsDir {
replacement += " "
}
}
newLastLine := beforeAt + replacement
@@ -825,7 +922,7 @@ func (s *InputComponent) applyFileCompletion(idx int) {
s.textarea.SetValue(newValue)
s.textarea.CursorEnd()
if suggestion.IsDir {
if suggestion.IsDir && !suggestion.IsMCPResource {
// Keep popup open — trigger a refresh for the new directory.
s.lastValue = "" // force re-evaluation on next update tick
} else {
+6 -8
View File
@@ -109,8 +109,8 @@ func (m *TextMessageItem) renderContent(width int) string {
// It accumulates content chunks and re-renders on each update for live display.
type StreamingMessageItem struct {
id string
role string // "assistant" or "reasoning"
content string // Accumulated streaming content
role string // "assistant" or "reasoning"
content strings.Builder // Accumulated streaming content
timestamp time.Time
startTime time.Time // When streaming started (for live duration counter)
modelName string
@@ -156,10 +156,10 @@ func (s *StreamingMessageItem) Render(width int) string {
durationMs = time.Since(s.startTime).Milliseconds()
}
ty := createTypography(style.GetTheme())
rendered = render.ReasoningBlock(s.content, durationMs, ty, style.GetTheme())
rendered = render.ReasoningBlock(s.content.String(), durationMs, width, ty, style.GetTheme())
} else {
// Render as assistant message
rendered = render.AssistantBlock(s.content, width, style.GetTheme())
rendered = render.AssistantBlock(s.content.String(), width, style.GetTheme())
}
// Cache and return (but reasoning is never cached due to live duration)
@@ -187,7 +187,7 @@ func (s *StreamingMessageItem) Height() int {
// AppendChunk adds a content chunk and invalidates the render cache.
func (s *StreamingMessageItem) AppendChunk(chunk string) {
s.content += chunk
s.content.WriteString(chunk)
s.cachedWidth = 0 // Invalidate cache
}
@@ -243,9 +243,7 @@ func (m *StreamingBashOutputItem) Render(width int) string {
// Header with command
if m.command != "" {
headerStyle := lipgloss.NewStyle().
Foreground(theme.Muted).
Italic(true)
headerStyle := style.GetCachedStyles().BashHeader
parts = append(parts, headerStyle.Render(fmt.Sprintf("▸ %s", m.command)))
}
+39 -12
View File
@@ -88,13 +88,9 @@ func formatToolParams(toolArgs string, maxWidth int) string {
}
bodyKeys := map[string]bool{
"content": true,
"old_text": true,
"new_text": true,
"oldText": true,
"newText": true,
"edits": true,
"todos": true,
"content": true,
"edits": true,
"todos": true,
}
var remaining []string
for key, val := range params {
@@ -150,9 +146,26 @@ func (r *MessageRenderer) SetWidth(width int) {
r.width = width
}
// RenderUserMessage renders a user's input message using herald Tip alert
// RenderUserMessage renders a user's input message with a colored left border.
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
theme := style.GetTheme()
// Highlight @file tokens with accent color.
content = render.HighlightFileTokens(content, theme)
rendered := renderContentBlock(
content,
r.width,
WithAlign(lipgloss.Left),
WithBorderColor(theme.Success),
WithPaddingTop(0),
WithPaddingBottom(0),
WithMarginBottom(1),
)
return UIMessage{
Type: UserMessage,
@@ -178,7 +191,7 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
// as live streaming: muted italic text with margin. This is used when resuming
// sessions to display saved reasoning content.
func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Time) UIMessage {
rendered := render.ReasoningBlock(content, 0, r.ty, style.GetTheme())
rendered := render.ReasoningBlock(content, 0, r.width, r.ty, style.GetTheme())
return UIMessage{
Type: AssistantMessage,
@@ -200,6 +213,19 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
}
}
// RenderCustomMessage renders a message with a custom alert label (e.g. "Help").
// Content is rendered as markdown.
func (r *MessageRenderer) RenderCustomMessage(content, label string, timestamp time.Time) UIMessage {
rendered := render.CustomBlock(content, label, r.width, style.GetTheme())
return UIMessage{
Type: SystemMessage,
Content: rendered,
Height: lipgloss.Height(rendered),
Timestamp: timestamp,
}
}
// RenderDebugMessage renders diagnostic and debugging information
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
header := r.ty.H6("🔍 Debug Output")
@@ -308,7 +334,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
// Build the content: icon + name + params on first line, then body
headerLine := styledIcon + " " + styledName
if params != "" {
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
headerLine += " " + style.GetCachedStyles().ToolMuted.Render(params)
}
// Get body content
@@ -399,7 +425,8 @@ func createTypography(theme style.Theme) *herald.Typography {
herald.WithCodeLineNumbers(true),
// Customize alert labels
herald.WithAlertLabel(herald.AlertNote, "Info"),
herald.WithAlertLabel(herald.AlertTip, "You"),
herald.WithAlertLabel(herald.AlertTip, ""),
herald.WithAlertIcon(herald.AlertTip, ""),
herald.WithAlertLabel(herald.AlertWarning, "Working"),
herald.WithAlertLabel(herald.AlertCaution, "Error"),
)
+660 -54
View File
File diff suppressed because it is too large Load Diff
+151 -9
View File
@@ -515,12 +515,12 @@ func TestWindowResize_distributeHeight(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
// With height=30, scroll height = 30 - 1 (separator) - 9 (input) - 1 (statusBar) = 19
// With height=30, scroll height = 30 - 1 (separator) - 8 (input) - 1 (statusBar) = 20
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 30})
_ = m
if m.scrollList.height != 19 {
t.Fatalf("expected scroll list height=19, got %d", m.scrollList.height)
if m.scrollList.height != 20 {
t.Fatalf("expected scroll list height=20, got %d", m.scrollList.height)
}
}
@@ -853,23 +853,165 @@ func TestSpinnerEvent_hideDoesNotTransitionState(t *testing.T) {
}
// --------------------------------------------------------------------------
// ctrl+c produces tea.Quit
// ctrl+c double-press to quit
// --------------------------------------------------------------------------
// TestCtrlC_producesQuit verifies that ctrl+c always returns a tea.Quit cmd.
// TestCtrlC_producesQuit verifies that double ctrl+c returns a tea.Quit cmd.
func TestCtrlC_producesQuit(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
// First Ctrl+C arms the quit flag.
updated, cmd := m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
m = updated.(*AppModel)
if cmd == nil {
t.Fatal("expected a command after first ctrl+c, got nil")
}
// Should be a reset timer, not quit.
msg := cmd()
if _, ok := msg.(core.CtrlCResetMsg); !ok {
t.Fatalf("expected CtrlCResetMsg after first ctrl+c, got %T", msg)
}
// Second Ctrl+C should quit.
_, cmd = m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
if cmd == nil {
t.Fatal("expected tea.Quit cmd on second ctrl+c, got nil")
}
msg = cmd()
if _, ok := msg.(tea.QuitMsg); !ok {
t.Fatalf("expected QuitMsg from second ctrl+c, got %T", msg)
}
}
// TestCtrlC_clearsInput_firstPress tests that Ctrl+C clears input on first
// press when there's content, and requires a second press to quit.
func TestCtrlC_clearsInput_firstPress(t *testing.T) {
// Create a real InputComponent to test the clear behavior
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
// Replace with real InputComponent that has content
input := NewInputComponent(80, ctrl)
input.textarea.SetValue("some text content")
m.input = input
// First Ctrl+C should clear input, not quit
_, cmd := m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
if cmd == nil {
t.Fatal("expected tea.Quit cmd on ctrl+c, got nil")
// Should have cleared the input
if input.textarea.Value() != "" {
t.Fatalf("expected input to be cleared, got %q", input.textarea.Value())
}
// Should have set ctrlCPressedOnce flag
if !m.ctrlCPressedOnce {
t.Fatal("expected ctrlCPressedOnce to be true after first Ctrl+C")
}
// The command should be a ctrlCResetCmd (not tea.Quit)
if cmd == nil {
t.Fatal("expected a command after first Ctrl+C, got nil")
}
// We verify it's a quit command by running it and checking the message type.
msg := cmd()
if _, ok := msg.(core.CtrlCResetMsg); !ok {
t.Fatalf("expected CtrlCResetMsg, got %T", msg)
}
// Second Ctrl+C should now quit
_, cmd = m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
if cmd == nil {
t.Fatal("expected tea.Quit cmd on second Ctrl+C, got nil")
}
msg = cmd()
if _, ok := msg.(tea.QuitMsg); !ok {
t.Fatalf("expected QuitMsg from ctrl+c cmd, got %T", msg)
t.Fatalf("expected QuitMsg on second Ctrl+C, got %T", msg)
}
}
// TestCtrlC_resetAfterSubmit tests that the Ctrl+C flag is reset after
// submitting a message, so the next Ctrl+C clears input again.
func TestCtrlC_resetAfterSubmit(t *testing.T) {
// Use newTestAppModel but replace the input with a real InputComponent
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
// Replace with real InputComponent
input := NewInputComponent(80, ctrl)
input.textarea.SetValue("content")
m.input = input
// First Ctrl+C clears input
updated, _ := m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
m = updated.(*AppModel)
if input.textarea.Value() != "" {
t.Fatal("expected input to be cleared")
}
// Flag should be set
if !m.ctrlCPressedOnce {
t.Fatal("expected ctrlCPressedOnce to be true after first Ctrl+C")
}
// Simulate CtrlCResetMsg being processed (timer expired)
updated, _ = m.Update(core.CtrlCResetMsg{})
m = updated.(*AppModel)
// Flag should be reset
if m.ctrlCPressedOnce {
t.Fatal("expected ctrlCPressedOnce to be false after CtrlCResetMsg")
}
// Add new content to input
input.textarea.SetValue("new content")
// Next Ctrl+C should clear again (not quit) because flag was reset
_, cmd := m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
if input.textarea.Value() != "" {
t.Fatalf("expected input to be cleared again, got %q", input.textarea.Value())
}
if cmd == nil {
t.Fatal("expected a command after Ctrl+C, got nil")
}
msg := cmd()
if _, ok := msg.(core.CtrlCResetMsg); !ok {
t.Fatalf("expected CtrlCResetMsg, got %T", msg)
}
}
// TestCtrlC_emptyInput_armsQuit tests that Ctrl+C on empty input still
// requires a second press to quit (consistent double-press behavior).
func TestCtrlC_emptyInput_armsQuit(t *testing.T) {
ctrl := &stubAppController{}
m, _, _ := newTestAppModel(ctrl)
// Replace with real InputComponent (empty by default)
input := NewInputComponent(80, ctrl)
m.input = input
// First Ctrl+C on empty input should arm the flag, not quit.
updated, cmd := m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
m = updated.(*AppModel)
if !m.ctrlCPressedOnce {
t.Fatal("expected ctrlCPressedOnce to be true after first Ctrl+C")
}
if cmd == nil {
t.Fatal("expected a command (reset timer), got nil")
}
msg := cmd()
if _, ok := msg.(core.CtrlCResetMsg); !ok {
t.Fatalf("expected CtrlCResetMsg, got %T", msg)
}
// Second Ctrl+C should quit.
_, cmd = m.Update(tea.KeyPressMsg{Code: 'c', Mod: tea.ModCtrl})
if cmd == nil {
t.Fatal("expected tea.Quit cmd on second Ctrl+C, got nil")
}
msg = cmd()
if _, ok := msg.(tea.QuitMsg); !ok {
t.Fatalf("expected QuitMsg on second Ctrl+C, got %T", msg)
}
}
+6
View File
@@ -288,3 +288,9 @@ func (pr *ProgressReader) Close() error {
return nil
}
// NewProgressReadCloser is a convenience wrapper around NewProgressReader that
// returns an io.ReadCloser, suitable for use as a ProgressReaderFunc callback.
func NewProgressReadCloser(r io.Reader) io.ReadCloser {
return NewProgressReader(r)
}
+78 -9
View File
@@ -19,9 +19,10 @@ import (
type promptMode string
const (
promptModeSelect promptMode = "select"
promptModeConfirm promptMode = "confirm"
promptModeInput promptMode = "input"
promptModeSelect promptMode = "select"
promptModeConfirm promptMode = "confirm"
promptModeInput promptMode = "input"
promptModePassword promptMode = "password"
)
// promptResult carries the synchronous outcome of a prompt overlay update.
@@ -102,10 +103,38 @@ func newInputPrompt(message, placeholder, defaultValue string, width, height int
}
}
// Init returns the initial command for the prompt overlay. For input mode
// this starts the cursor blink animation.
// newPasswordPrompt creates a prompt overlay for password input (masked).
func newPasswordPrompt(message string, width, height int) *promptOverlay {
ta := textarea.New()
ta.Placeholder = "Enter password"
ta.ShowLineNumbers = false
ta.Prompt = ""
ta.CharLimit = 0
ta.SetWidth(width - 12) // account for border + padding
ta.SetHeight(1)
ta.Focus()
// Prevent Enter from inserting a newline — we intercept it for submit.
ta.KeyMap.InsertNewline = key.NewBinding(
key.WithKeys("ctrl+j", "shift+enter"),
)
// Enable password masking - the textarea will show dots instead of characters
// Note: textarea doesn't have built-in password masking, so we handle it in View()
return &promptOverlay{
mode: promptModePassword,
message: message,
inputTA: ta,
width: width,
height: height,
}
}
// Init returns the initial command for the prompt overlay. For input/password
// modes this starts the cursor blink animation.
func (p *promptOverlay) Init() tea.Cmd {
if p.mode == promptModeInput {
if p.mode == promptModeInput || p.mode == promptModePassword {
return textarea.Blink
}
return nil
@@ -113,13 +142,13 @@ func (p *promptOverlay) Init() tea.Cmd {
// Update handles messages for the prompt overlay. It returns a non-nil
// *promptResult when the user completes or cancels the prompt. The returned
// tea.Cmd is for textarea blink ticks (input mode only).
// tea.Cmd is for textarea blink ticks (input/password modes only).
func (p *promptOverlay) Update(msg tea.Msg) (*promptResult, tea.Cmd) {
switch msg := msg.(type) {
case tea.WindowSizeMsg:
p.width = msg.Width
p.height = msg.Height
if p.mode == promptModeInput {
if p.mode == promptModeInput || p.mode == promptModePassword {
p.inputTA.SetWidth(p.width - 12)
}
return nil, nil
@@ -132,11 +161,13 @@ func (p *promptOverlay) Update(msg tea.Msg) (*promptResult, tea.Cmd) {
return p.updateConfirm(msg)
case promptModeInput:
return p.updateInput(msg)
case promptModePassword:
return p.updatePassword(msg)
}
}
// Pass non-key messages to textarea for blink animation.
if p.mode == promptModeInput {
if p.mode == promptModeInput || p.mode == promptModePassword {
var cmd tea.Cmd
p.inputTA, cmd = p.inputTA.Update(msg)
return nil, cmd
@@ -202,6 +233,20 @@ func (p *promptOverlay) updateInput(msg tea.KeyPressMsg) (*promptResult, tea.Cmd
}
}
func (p *promptOverlay) updatePassword(msg tea.KeyPressMsg) (*promptResult, tea.Cmd) {
switch msg.String() {
case "enter":
return &promptResult{completed: true, value: p.inputTA.Value()}, nil
case "esc":
return &promptResult{cancelled: true}, nil
default:
// Delegate character input, backspace, cursor movement, etc.
var cmd tea.Cmd
p.inputTA, cmd = p.inputTA.Update(msg)
return nil, cmd
}
}
// Render returns the prompt as a styled string for inline composition in the
// AppModel layout. The prompt replaces the normal input area (below the
// separator and above the status bar) rather than taking over the full screen.
@@ -216,6 +261,8 @@ func (p *promptOverlay) Render() string {
content = p.viewConfirm(theme)
case promptModeInput:
content = p.viewInput(theme)
case promptModePassword:
content = p.viewPassword(theme)
}
return renderContentBlock(content, p.width,
@@ -286,3 +333,25 @@ func (p *promptOverlay) viewInput(theme style.Theme) string {
return strings.Join(lines, "\n")
}
func (p *promptOverlay) viewPassword(theme style.Theme) string {
var lines []string
// Add 🔐 icon to message for password prompt
lines = append(lines, lipgloss.NewStyle().Bold(true).Foreground(theme.Text).Render("🔐 "+p.message))
lines = append(lines, "")
// Mask the password input with dots
passwordValue := p.inputTA.Value()
masked := strings.Repeat("•", len([]rune(passwordValue)))
// Render the masked password in a style that looks like input
maskedStyle := lipgloss.NewStyle().Foreground(theme.Text)
cursor := lipgloss.NewStyle().Foreground(theme.Accent).Render("█")
lines = append(lines, maskedStyle.Render(masked)+cursor)
lines = append(lines, "")
lines = append(lines, lipgloss.NewStyle().
Foreground(theme.Muted).
Render(" Enter submit Esc cancel (input is hidden)"))
return strings.Join(lines, "\n")
}
+77 -8
View File
@@ -5,6 +5,7 @@ package render
import (
"fmt"
"regexp"
"strings"
"charm.land/lipgloss/v2"
@@ -13,8 +14,14 @@ import (
"github.com/mark3labs/kit/internal/ui/style"
)
// fileTokenPattern matches @file references in user text. Supports:
// - @"path with spaces.txt" (quoted)
// - @path/to/file.txt (unquoted, no spaces)
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
// UserBlock renders a user message with herald Tip styling.
// The width parameter controls line wrapping so long messages don't overflow.
// Any @file tokens in the content are highlighted with the theme accent color.
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
@@ -27,10 +34,23 @@ func UserBlock(content string, width int, ty *herald.Typography, theme style.The
content = lipgloss.Wrap(content, width-4, "")
}
// Highlight @file tokens with accent color so file references are
// visually distinct from surrounding prompt text.
content = HighlightFileTokens(content, theme)
rendered := ty.Tip(content)
return styleMarginBottom(theme, rendered)
}
// HighlightFileTokens wraps @file tokens in the given text with the theme
// accent color so they stand out visually in rendered user messages.
func HighlightFileTokens(text string, theme style.Theme) string {
accentStyle := style.GetCachedStyles().FileTokenAccent
return fileTokenPattern.ReplaceAllStringFunc(text, func(token string) string {
return accentStyle.Render(token)
})
}
// AssistantBlock renders an assistant message with markdown styling.
func AssistantBlock(content string, width int, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
@@ -43,16 +63,20 @@ func AssistantBlock(content string, width int, theme style.Theme) string {
// ReasoningBlock renders a reasoning/thinking block with muted italic text.
// If duration > 0, shows "Thought for Xs" label. Otherwise shows just "Thought".
func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme style.Theme) string {
// The width parameter controls soft-wrapping so long reasoning lines don't get cut off.
func ReasoningBlock(content string, duration int64, width int, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
return ""
}
// Match live streaming styling: muted italic text
// Match live streaming styling: muted italic text.
lines := strings.Split(strings.TrimRight(content, "\n"), "\n")
contentStr := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
contentRendered := mutedStyle.Render(ty.Italic(contentStr))
if width > 4 {
contentStr = wrapText(contentStr, width-4)
}
cs := style.GetCachedStyles()
contentRendered := cs.Muted.Render(ty.Italic(contentStr))
// Build label based on duration
if duration > 0 {
@@ -62,14 +86,14 @@ func ReasoningBlock(content string, duration int64, ty *herald.Typography, theme
} else {
durationStr = fmt.Sprintf("%.1fs", float64(duration)/1000)
}
labelPart := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
durationPart := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
labelPart := cs.VeryMuted.Render("Thought for ")
durationPart := cs.Accent.Render(durationStr)
label := labelPart + durationPart
rendered := contentRendered + "\n" + label
return styleMarginBottom(theme, rendered)
}
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought")
label := cs.VeryMuted.Render("Thought")
rendered := contentRendered + "\n" + label
return styleMarginBottom(theme, rendered)
@@ -85,6 +109,45 @@ func SystemBlock(content string, ty *herald.Typography, theme style.Theme) strin
return styleMarginBottom(theme, rendered)
}
// CustomBlock renders a message with herald Note styling and a custom label.
// Content is rendered as markdown before being wrapped in the alert. This
// creates a one-off Typography instance with the given label so callers
// can use any title (e.g. "Help", "Warning") without changing the shared
// typography's default "Info" label.
func CustomBlock(content, label string, width int, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "No content available"
}
// Render markdown first — subtract 4 for the alert bar prefix ("│ ").
mdWidth := max(width-4, 10)
rendered := style.ToMarkdown(content, mdWidth)
ty := herald.New(
herald.WithPalette(herald.ColorPalette{
Primary: theme.Primary,
Secondary: theme.Secondary,
Tertiary: theme.Info,
Accent: theme.Accent,
Highlight: theme.Highlight,
Muted: theme.Muted,
Text: theme.Text,
Surface: theme.Background,
Base: theme.CodeBg,
}),
herald.WithAlertPalette(herald.AlertPalette{
Note: theme.Info,
Tip: theme.Success,
Important: theme.Accent,
Warning: theme.Warning,
Caution: theme.Error,
}),
herald.WithAlertLabel(herald.AlertNote, label),
)
alertRendered := ty.Note(rendered)
return styleMarginBottom(theme, alertRendered)
}
// ErrorBlock renders an error message with herald Caution styling.
func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) string {
rendered := ty.Caution(errorMsg)
@@ -131,5 +194,11 @@ func ToolBlock(displayName, params, body string, isError bool, width int, ty *he
// styleMarginBottom applies a 1-line margin bottom using the theme.
func styleMarginBottom(theme style.Theme, content string) string {
return lipgloss.NewStyle().MarginBottom(1).Render(content)
return style.GetCachedStyles().MarginBottom1.Render(content)
}
// wrapText soft-wraps a string to the given width using lipgloss, which is
// ANSI-aware and preserves escape sequences across line breaks.
func wrapText(s string, width int) string {
return lipgloss.NewStyle().Width(width).Render(s)
}
+111
View File
@@ -0,0 +1,111 @@
package render
import (
"strings"
"testing"
"github.com/indaco/herald"
"github.com/mark3labs/kit/internal/ui/style"
)
// testTypography creates a herald Typography for tests.
func testTypography(theme style.Theme) *herald.Typography {
return herald.New(
herald.WithPalette(herald.ColorPalette{
Primary: theme.Primary,
Secondary: theme.Secondary,
Tertiary: theme.Info,
Accent: theme.Accent,
Highlight: theme.Highlight,
Muted: theme.Muted,
Text: theme.Text,
Surface: theme.Background,
Base: theme.CodeBg,
}),
herald.WithAlertLabel(herald.AlertTip, ""),
herald.WithAlertIcon(herald.AlertTip, ""),
)
}
func TestHighlightFileTokens(t *testing.T) {
theme := style.DefaultTheme()
tests := []struct {
name string
input string
wantHas []string // substrings that must be present in the output
wantNone []string // substrings that must NOT be present as plain text
}{
{
name: "no tokens",
input: "hello world",
wantHas: []string{"hello world"},
},
{
name: "single unquoted token",
input: "refactor @main.go please",
wantHas: []string{"@main.go", "refactor", "please"},
},
{
name: "quoted token with spaces",
input: `check @"path with spaces/file.txt" out`,
wantHas: []string{`@"path with spaces/file.txt"`, "check", "out"},
},
{
name: "multiple tokens",
input: "@main.go @utils.go refactor these",
wantHas: []string{"@main.go", "@utils.go", "refactor these"},
},
{
name: "path with directory",
input: "look at @internal/ui/render/blocks.go",
wantHas: []string{"@internal/ui/render/blocks.go", "look at"},
},
{
name: "empty string",
input: "",
wantHas: []string{""},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HighlightFileTokens(tt.input, theme)
for _, want := range tt.wantHas {
if !strings.Contains(result, want) {
t.Errorf("HighlightFileTokens(%q) = %q, want substring %q", tt.input, result, want)
}
}
// If there were @tokens, the result should contain ANSI escape
// sequences (from lipgloss styling).
if fileTokenPattern.MatchString(tt.input) && !strings.Contains(result, "\x1b[") {
t.Errorf("HighlightFileTokens(%q) should contain ANSI escapes for @tokens but got %q", tt.input, result)
}
})
}
}
func TestUserBlockHighlightsFileTokens(t *testing.T) {
theme := style.DefaultTheme()
ty := testTypography(theme)
// A user message with @file tokens should contain ANSI escapes around the token.
content := "refactor @main.go and @utils.go"
result := UserBlock(content, 80, ty, theme)
// The rendered output should contain both file references.
if !strings.Contains(result, "@main.go") {
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
}
if !strings.Contains(result, "@utils.go") {
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
}
// Verify ANSI codes are present (the tokens are styled).
if !strings.Contains(result, "\x1b[") {
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
}
}
+96 -66
View File
@@ -35,6 +35,12 @@ type ScrollList struct {
autoScroll bool // Whether to auto-scroll to bottom on new content
itemGap int // Number of blank lines between items (0 = no gap)
// heightCache maps item ID → rendered line count at current width.
// Avoids redundant Render() calls in GotoBottom/clampOffset/AtBottom.
// Invalidated on width change; individual entries are refreshed in
// View() when an item is actually rendered.
heightCache map[string]int
// Character-level text selection (crush-style).
sel selection.State
}
@@ -42,13 +48,14 @@ type ScrollList struct {
// NewScrollList creates a new ScrollList with the given dimensions.
func NewScrollList(width, height int) *ScrollList {
return &ScrollList{
items: []MessageItem{},
offsetIdx: 0,
offsetLine: 0,
width: width,
height: height,
autoScroll: true,
sel: selection.NewState(),
items: []MessageItem{},
offsetIdx: 0,
offsetLine: 0,
width: width,
height: height,
autoScroll: true,
heightCache: make(map[string]int, 64),
sel: selection.NewState(),
}
}
@@ -61,6 +68,13 @@ func (s *ScrollList) SetItems(items []MessageItem) {
}
}
// InvalidateItemHeight removes the cached height for the given item ID,
// forcing a re-render on the next height query. Call this after mutating
// an item's content (e.g. AppendChunk on a streaming message).
func (s *ScrollList) InvalidateItemHeight(id string) {
delete(s.heightCache, id)
}
// SetHeight updates the viewport height. Called when the terminal is resized.
func (s *ScrollList) SetHeight(height int) {
s.height = height
@@ -68,9 +82,11 @@ func (s *ScrollList) SetHeight(height int) {
}
// SetWidth updates the viewport width. Called when the terminal is resized.
// This may invalidate cached renders in MessageItems.
// This invalidates the height cache since rendered heights are width-dependent.
func (s *ScrollList) SetWidth(width int) {
s.width = width
// Width change invalidates all cached heights.
clear(s.heightCache)
s.clampOffset()
}
@@ -338,9 +354,8 @@ func (s *ScrollList) ScrollBy(lines int) {
if s.offsetIdx >= len(s.items) {
break
}
currentItem := s.items[s.offsetIdx]
itemHeight := currentItem.Height()
remainingLines := itemHeight - s.offsetLine
ih := s.itemHeight(s.items[s.offsetIdx])
remainingLines := ih - s.offsetLine
if lines >= remainingLines {
// Move to next item
@@ -387,14 +402,13 @@ func (s *ScrollList) ScrollBy(lines int) {
// Move to previous item
s.offsetIdx--
if s.offsetIdx < len(s.items) {
currentItem := s.items[s.offsetIdx]
itemHeight := currentItem.Height()
ih := s.itemHeight(s.items[s.offsetIdx])
if lines >= itemHeight {
lines -= itemHeight
if lines >= ih {
lines -= ih
s.offsetLine = 0
} else {
s.offsetLine = itemHeight - lines
s.offsetLine = ih - lines
lines = 0
}
}
@@ -405,6 +419,8 @@ func (s *ScrollList) ScrollBy(lines int) {
}
// GotoBottom scrolls to the end of the list.
// Uses cached heights and walks backwards from the end to avoid rendering
// every item in the list.
func (s *ScrollList) GotoBottom() {
if len(s.items) == 0 {
s.offsetIdx = 0
@@ -412,42 +428,31 @@ func (s *ScrollList) GotoBottom() {
return
}
// Calculate total height including gaps
totalHeight := 0
for i, item := range s.items {
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
totalHeight += itemHeight
if s.itemGap > 0 && i < len(s.items)-1 {
totalHeight += s.itemGap
// Walk backwards from the last item, accumulating height until we
// exceed the viewport. This is O(visible) instead of O(all items).
budget := s.height
for idx := len(s.items) - 1; idx >= 0; idx-- {
ih := s.itemHeight(s.items[idx])
// Account for gap *above* this item (gap between idx-1 and idx).
gap := 0
if s.itemGap > 0 && idx < len(s.items)-1 {
gap = s.itemGap
}
}
// If content fits in viewport, start at top
if totalHeight <= s.height {
s.offsetIdx = 0
s.offsetLine = 0
return
}
// Otherwise, position viewport at bottom
remaining := totalHeight - s.height
for idx := 0; idx < len(s.items); idx++ {
rendered := s.items[idx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if remaining < itemHeight {
if ih+gap >= budget {
// This item (partially) fills the remaining budget.
// When the gap consumed part of the budget, offsetLine would go
// negative — clamp to 0 so the item is shown fully.
s.offsetIdx = idx
s.offsetLine = remaining
s.offsetLine = max(0, ih-budget)
return
}
remaining -= itemHeight
if s.itemGap > 0 && idx < len(s.items)-1 {
remaining -= s.itemGap
}
budget -= ih + gap
}
// Fallback: show last item
s.offsetIdx = max(0, len(s.items)-1)
// All content fits in viewport — start at top.
s.offsetIdx = 0
s.offsetLine = 0
}
@@ -465,14 +470,12 @@ func (s *ScrollList) AtBottom() bool {
visibleHeight := 0
for idx := s.offsetIdx; idx < len(s.items); idx++ {
item := s.items[idx]
rendered := item.Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
ih := s.itemHeight(s.items[idx])
if idx == s.offsetIdx {
visibleHeight += itemHeight - s.offsetLine
visibleHeight += ih - s.offsetLine
} else {
visibleHeight += itemHeight
visibleHeight += ih
}
if s.itemGap > 0 && idx < len(s.items)-1 {
@@ -520,6 +523,9 @@ func (s *ScrollList) View() string {
content := item.Render(s.width)
contentLines := strings.Split(content, "\n")
// Refresh height cache from the actual render (authoritative).
s.heightCache[item.ID()] = len(contentLines)
startLine := 0
if idx == s.offsetIdx {
startLine = s.offsetLine
@@ -568,7 +574,7 @@ func (s *ScrollList) ScrollPercent() float64 {
totalHeight := 0
for _, item := range s.items {
totalHeight += item.Height()
totalHeight += s.itemHeight(item)
}
if totalHeight <= s.height {
@@ -577,7 +583,7 @@ func (s *ScrollList) ScrollPercent() float64 {
linesAbove := 0
for i := 0; i < s.offsetIdx && i < len(s.items); i++ {
linesAbove += s.items[i].Height()
linesAbove += s.itemHeight(s.items[i])
}
linesAbove += s.offsetLine
@@ -597,7 +603,8 @@ func (s *ScrollList) ScrollPercent() float64 {
}
// clampOffset ensures the offset values are within valid bounds after
// resizing or scrolling operations.
// resizing or scrolling operations. Uses cached heights to avoid
// redundant Render() calls.
func (s *ScrollList) clampOffset() {
if len(s.items) == 0 {
s.offsetIdx = 0
@@ -605,6 +612,7 @@ func (s *ScrollList) clampOffset() {
return
}
// Clamp offsetIdx to valid item range.
if s.offsetIdx >= len(s.items) {
s.offsetIdx = len(s.items) - 1
}
@@ -612,37 +620,38 @@ func (s *ScrollList) clampOffset() {
s.offsetIdx = 0
}
// Clamp offsetLine within current item.
if s.offsetIdx < len(s.items) {
rendered := s.items[s.offsetIdx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
if s.offsetLine >= itemHeight {
s.offsetLine = max(0, itemHeight-1)
ih := s.itemHeight(s.items[s.offsetIdx])
if s.offsetLine >= ih {
s.offsetLine = max(0, ih-1)
}
}
if s.offsetLine < 0 {
s.offsetLine = 0
}
// Prevent scrolling past the bottom
// Prevent scrolling past the bottom — compute total height and check
// whether remaining content from the current offset fills the viewport.
totalHeight := 0
for i, item := range s.items {
rendered := item.Render(s.width)
totalHeight += strings.Count(rendered, "\n") + 1
totalHeight += s.itemHeight(item)
if s.itemGap > 0 && i < len(s.items)-1 {
totalHeight += s.itemGap
}
}
// If content fits in viewport, force start at top.
if totalHeight <= s.height {
s.offsetIdx = 0
s.offsetLine = 0
return
}
// Compute lines above the viewport.
linesAbove := 0
for i := 0; i < s.offsetIdx; i++ {
rendered := s.items[i].Render(s.width)
linesAbove += strings.Count(rendered, "\n") + 1
linesAbove += s.itemHeight(s.items[i])
if s.itemGap > 0 && i < len(s.items)-1 {
linesAbove += s.itemGap
}
@@ -651,20 +660,21 @@ func (s *ScrollList) clampOffset() {
linesFromCurrentToEnd := totalHeight - linesAbove
if linesFromCurrentToEnd < s.height {
// We've scrolled past the bottom — reposition so the last line
// of content sits at the bottom of the viewport.
targetLine := totalHeight - s.height
currentLine := 0
for idx := 0; idx < len(s.items); idx++ {
rendered := s.items[idx].Render(s.width)
itemHeight := strings.Count(rendered, "\n") + 1
ih := s.itemHeight(s.items[idx])
if currentLine+itemHeight > targetLine {
if currentLine+ih > targetLine {
s.offsetIdx = idx
s.offsetLine = targetLine - currentLine
return
}
currentLine += itemHeight
currentLine += ih
if s.itemGap > 0 && idx < len(s.items)-1 {
currentLine += s.itemGap
}
@@ -672,6 +682,26 @@ func (s *ScrollList) clampOffset() {
}
}
// itemHeight returns the cached rendered height for an item, computing and
// caching it on first access. This avoids calling Render() purely to
// count lines — the most common source of redundant work in the scroll
// list (GotoBottom, clampOffset, AtBottom, ScrollBy all need heights but
// never use the rendered content).
//
// The cache is invalidated wholesale on width changes (SetWidth) and
// individual entries are refreshed in View() after an item is actually
// rendered, so stale entries are self-correcting within one frame.
func (s *ScrollList) itemHeight(item MessageItem) int {
id := item.ID()
if h, ok := s.heightCache[id]; ok {
return h
}
// Cache miss — render to measure.
h := s.renderedHeight(item)
s.heightCache[id] = h
return h
}
// renderedHeight returns the height of a message item in lines by actually
// rendering it. This is the single source of truth for item height — it
// matches exactly what View() produces, unlike item.Height() which may
+17 -13
View File
@@ -21,12 +21,11 @@ func knightRiderFrames() []string {
const numDots = 8
const dot = "▪"
theme := style.GetTheme()
bright := lipgloss.NewStyle().Foreground(theme.Primary)
med := lipgloss.NewStyle().Foreground(theme.Muted)
dim := lipgloss.NewStyle().Foreground(theme.VeryMuted)
off := lipgloss.NewStyle().Foreground(theme.MutedBorder)
cs := style.GetCachedStyles()
bright := cs.SpinnerBright
med := cs.SpinnerMed
dim := cs.SpinnerDim
off := cs.SpinnerOff
// Scanner bounces: 0→7→0
positions := make([]int, 0, 2*numDots-2)
@@ -472,9 +471,12 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
// Main content using Italic with Muted color for visual distinction.
content := strings.TrimLeft(strings.Join(lines, "\n"), " \t\n")
theme := GetTheme()
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
parts = append(parts, mutedStyle.Render(s.ty.Italic(content)))
// Soft-wrap to the available width so long lines don't get cut off.
if s.width > 4 {
content = lipgloss.NewStyle().Width(s.width - 4).Render(content)
}
cs := style.GetCachedStyles()
parts = append(parts, cs.Muted.Render(s.ty.Italic(content)))
// Duration footer with VeryMuted label and Accent duration.
var duration time.Duration
@@ -490,8 +492,8 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
} else {
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
}
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
label := cs.VeryMuted.Render("Thought for ")
durationStyled := cs.Accent.Render(durationStr)
parts = append(parts, label+durationStyled)
}
@@ -588,8 +590,10 @@ func formatToolExecutionMessage(toolName string) string {
return toolName
}
// UpdateTheme refreshes the component's typography instance with colors from
// the current theme. This is called when the user changes themes via /theme.
// UpdateTheme refreshes the component's typography instance and spinner
// animation frames with colors from the current theme. This is called when
// the user changes themes via /theme.
func (s *StreamComponent) UpdateTheme() {
s.ty = createTypography(GetTheme())
s.spinnerFrames = knightRiderFrames()
}
+64
View File
@@ -40,6 +40,70 @@ func GetTheme() Theme {
func SetTheme(theme Theme) {
currentTheme = theme
markdownTypographyCache = nil // invalidate cached renderer; colors may have changed
styleCache = nil // invalidate cached styles; colors may have changed
}
// CachedStyles holds pre-built lipgloss styles that are reused across
// render frames. Invalidated by SetTheme, lazily rebuilt on next access.
// Only accessed from BubbleTea's single-threaded Update/View cycle.
type CachedStyles struct {
// render/blocks.go
FileTokenAccent lipgloss.Style // Foreground(Accent).Bold(true)
Muted lipgloss.Style // Foreground(Muted)
VeryMuted lipgloss.Style // Foreground(VeryMuted)
Accent lipgloss.Style // Foreground(Accent)
MarginBottom1 lipgloss.Style // MarginBottom(1)
// stream.go - spinner phases
SpinnerBright lipgloss.Style // Foreground(Primary)
SpinnerMed lipgloss.Style // Foreground(Muted)
SpinnerDim lipgloss.Style // Foreground(VeryMuted)
SpinnerOff lipgloss.Style // Foreground(MutedBorder)
// message_items.go - bash output
BashHeader lipgloss.Style // Foreground(Muted).Italic(true)
BashStderr lipgloss.Style // Foreground(Error)
// render/blocks.go - tool block
ToolSuccess lipgloss.Style // Foreground(Success)
ToolError lipgloss.Style // Foreground(Error)
ToolInfo lipgloss.Style // Foreground(Info).Bold(true)
ToolMuted lipgloss.Style // Foreground(Muted)
// common
ErrorFg lipgloss.Style // Foreground(Error)
TextBold lipgloss.Style // Foreground(Text).Bold(true)
}
var styleCache *CachedStyles
// GetCachedStyles returns the pre-built style cache, creating it lazily
// from the current theme. Invalidated by SetTheme.
func GetCachedStyles() *CachedStyles {
if styleCache != nil {
return styleCache
}
theme := GetTheme()
styleCache = &CachedStyles{
FileTokenAccent: lipgloss.NewStyle().Foreground(theme.Accent).Bold(true),
Muted: lipgloss.NewStyle().Foreground(theme.Muted),
VeryMuted: lipgloss.NewStyle().Foreground(theme.VeryMuted),
Accent: lipgloss.NewStyle().Foreground(theme.Accent),
MarginBottom1: lipgloss.NewStyle().MarginBottom(1),
SpinnerBright: lipgloss.NewStyle().Foreground(theme.Primary),
SpinnerMed: lipgloss.NewStyle().Foreground(theme.Muted),
SpinnerDim: lipgloss.NewStyle().Foreground(theme.VeryMuted),
SpinnerOff: lipgloss.NewStyle().Foreground(theme.MutedBorder),
BashHeader: lipgloss.NewStyle().Foreground(theme.Muted).Italic(true),
BashStderr: lipgloss.NewStyle().Foreground(theme.Error),
ToolSuccess: lipgloss.NewStyle().Foreground(theme.Success),
ToolError: lipgloss.NewStyle().Foreground(theme.Error),
ToolInfo: lipgloss.NewStyle().Foreground(theme.Info).Bold(true),
ToolMuted: lipgloss.NewStyle().Foreground(theme.Muted),
ErrorFg: lipgloss.NewStyle().Foreground(theme.Error),
TextBold: lipgloss.NewStyle().Foreground(theme.Text).Bold(true),
}
return styleCache
}
// MarkdownThemeColors defines colors for markdown rendering and syntax highlighting.
+18 -26
View File
@@ -79,8 +79,7 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
// Edit tool — side-by-side diff
// ---------------------------------------------------------------------------
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
// Supports both single-edit mode and multi-edit mode (edits array).
// renderEditBody renders a side-by-side diff from the edits array in toolArgs.
func renderEditBody(toolArgs, toolResult string, width int) string {
var args map[string]any
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
@@ -90,35 +89,28 @@ func renderEditBody(toolArgs, toolResult string, width int) string {
// Try to extract the starting line number from the unified diff in the result
startLine := extractDiffStartLine(toolResult)
// Check for multi-edit mode (edits array)
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
var results []string
for _, edit := range editsArr {
if e, ok := edit.(map[string]any); ok {
oldText, _ := e["old_text"].(string)
newText, _ := e["new_text"].(string)
if oldText != "" || newText != "" {
diff := renderDiffBlock(oldText, newText, startLine, width)
if diff != "" {
results = append(results, diff)
}
editsArr, ok := args["edits"].([]any)
if !ok || len(editsArr) == 0 {
return ""
}
var results []string
for _, edit := range editsArr {
if e, ok := edit.(map[string]any); ok {
oldText, _ := e["old_text"].(string)
newText, _ := e["new_text"].(string)
if oldText != "" || newText != "" {
diff := renderDiffBlock(oldText, newText, startLine, width)
if diff != "" {
results = append(results, diff)
}
}
}
if len(results) > 0 {
return strings.Join(results, "\n")
}
return ""
}
// Single-edit mode (legacy)
oldText, _ := args["old_text"].(string)
newText, _ := args["new_text"].(string)
if oldText == "" && newText == "" {
return ""
if len(results) > 0 {
return strings.Join(results, "\n")
}
return renderDiffBlock(oldText, newText, startLine, width)
return ""
}
// extractDiffStartLine parses the first @@ hunk header from a unified diff
-4
View File
@@ -200,10 +200,6 @@ func (ts *TreeSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+l"))):
ts.filter = TreeFilterLabelOnly
ts.rebuildFlatList()
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+a"))):
ts.filter = TreeFilterAll
ts.rebuildFlatList()
default:
// Typing search.
if msg.Text != "" && len(msg.Text) == 1 {
+125 -1
View File
@@ -77,6 +77,11 @@ host, err := kit.New(ctx, &kit.Options{
// Compaction
AutoCompact: true, // Auto-compact near context limit
// In-process MCP servers (map name → *kit.MCPServer)
InProcessMCPServers: map[string]*kit.MCPServer{
"docs": mcpSrv,
},
})
```
@@ -101,7 +106,7 @@ unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
})
defer unsub2()
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
unsub3 := host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
fmt.Print(e.Chunk)
})
defer unsub3()
@@ -112,6 +117,114 @@ response, err := host.Prompt(
)
```
### Dynamic MCP Server Management
Add, remove, and list MCP servers at runtime:
```go
// Add an MCP server at runtime
n, err := host.AddMCPServer(ctx, "github", kit.MCPServerConfig{
Command: "npx",
Args: []string{"-y", "@modelcontextprotocol/server-github"},
})
fmt.Printf("Loaded %d tools from MCP server\n", n)
// List connected MCP servers
for _, s := range host.ListMCPServers() {
fmt.Printf("%s: %d tools\n", s.Name, s.ToolCount)
}
// Disconnect a server and remove its tools
host.RemoveMCPServer("github")
```
### In-Process MCP Servers
Register mcp-go servers that run in the same process — no subprocess spawning,
no network I/O. This is ideal for custom tool servers implemented in Go:
```go
import (
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
// Create an mcp-go server with tools
mcpSrv := server.NewMCPServer("my-tools", "1.0.0",
server.WithToolCapabilities(true),
)
mcpSrv.AddTool(mcp.NewTool("search_docs",
mcp.WithDescription("Search documentation"),
mcp.WithString("query", mcp.Required()),
), searchHandler)
// Option 1: At init time via Options
host, _ := kit.New(ctx, &kit.Options{
InProcessMCPServers: map[string]*kit.MCPServer{
"docs": mcpSrv,
},
})
// Option 2: At runtime
n, err := host.AddInProcessMCPServer(ctx, "docs", mcpSrv)
fmt.Printf("Loaded %d tools from in-process server\n", n)
```
Kit does not take ownership of the server's lifecycle — the caller is responsible for any cleanup. In-process server tools are prefixed the same way as external MCP servers (e.g. `"docs__search_docs"`).
### MCP Prompts
MCP servers can expose prompt templates via the MCP prompts capability.
Kit exposes these through the SDK:
```go
// List prompts from all connected MCP servers
prompts := host.ListMCPPrompts()
for _, p := range prompts {
fmt.Printf("%s/%s: %s\n", p.Server, p.Name, p.Description)
}
// Get a specific prompt with arguments
msg, err := host.GetMCPPrompt(ctx, "server-name", "prompt-name", map[string]string{
"topic": "concurrency",
})
```
### MCP Tasks (long-running tools)
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
during `initialize`. Cooperating servers can respond to `tools/call` with a
`taskId` immediately; Kit then polls `tasks/get` / `tasks/result` until the
task reaches a terminal state, and best-effort `tasks/cancel`s on context
cancellation. Servers that don't advertise the capability keep their previous
synchronous behaviour.
```go
host, _ := kit.New(ctx, &kit.Options{
// Per-server mode: auto (default), never, or always.
MCPTaskMode: map[string]kit.MCPTaskMode{
"build-server": kit.MCPTaskModeAlways,
},
MCPTaskTimeout: 15 * time.Minute, // total wall-clock cap
MCPTaskProgress: func(p kit.MCPTaskProgress) {
log.Printf("%s/%s: %s", p.Server, p.TaskID, p.Status)
},
})
// Inspect / cancel in-flight tasks
tasks, _ := host.ListMCPTasks(ctx, "build-server")
t, _ := host.GetMCPTask(ctx, "build-server", tasks[0].TaskID)
if !t.Status.IsTerminal() {
_, _ = host.CancelMCPTask(ctx, "build-server", t.TaskID)
}
```
The progress handler fires once when a task is accepted and again on every
observed status transition; the final invocation always carries a terminal
status (`MCPTaskStatusCompleted`, `MCPTaskStatusFailed`, or
`MCPTaskStatusCancelled`). Don't block in the handler — dispatch long work on
a goroutine.
### Session Management
Maintain conversation context:
@@ -145,6 +258,16 @@ kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
kit.LLMResponse // {Content, FinishReason, Usage}
kit.LLMFilePart // {Filename, Data []byte, MediaType}
// MCP OAuth types
kit.MCPServer // *server.MCPServer for in-process MCP transport
kit.MCPServerConfig // Configuration for an MCP server (stdio, SSE, or in-process)
kit.MCPAuthHandler // Interface: handles user-facing OAuth authorization
kit.DefaultMCPAuthHandler // Port + callback-server mechanics; set OnAuthURL for presentation
kit.CLIMCPAuthHandler // CLI wrapper: opens browser, prints status
kit.MCPTokenStore // Persists OAuth tokens for a single MCP server
kit.MCPToken // OAuth token (access token, refresh token, expiry)
kit.MCPTokenStoreFactory // Creates an MCPTokenStore for a given server URL
// Conversion helpers
msgs := kit.ConvertToLLMMessages(&msg) // SDK Message → []LLMMessage
msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
@@ -192,6 +315,7 @@ Key `Options` fields for SDK usage:
| `NoSession` | Ephemeral mode (no session persistence) |
| `SessionPath` | Open specific session file |
| `Continue` | Resume most recent session |
| `InProcessMCPServers` | Map of name → `*kit.MCPServer` for in-process MCP servers |
| `Debug` | Enable debug logging |
## Environment Variables
+5 -6
View File
@@ -22,13 +22,13 @@ func NewTreeManagerAdapter(tm *session.TreeManager) SessionManager {
// AppendMessage implements SessionManager.
func (a *treeManagerAdapter) AppendMessage(msg LLMMessage) (string, error) {
// LLMMessage is just an alias for fantasy.Message, so no conversion needed
// LLMMessage is a type alias, so no conversion needed.
return a.inner.AppendLLMMessage(msg)
}
// GetMessages implements SessionManager.
func (a *treeManagerAdapter) GetMessages() []LLMMessage {
// LLMMessage is just an alias for fantasy.Message
// LLMMessage is a type alias, so no conversion needed.
return a.inner.GetLLMMessages()
}
@@ -223,9 +223,8 @@ func (a *treeManagerAdapter) convertEntry(entry any) *BranchEntry {
}
}
// convertKitMessagesToFantasy converts kit LLM messages to fantasy messages.
// Since LLMMessage is an alias for fantasy.Message, this is a no-op.
func convertKitMessagesToFantasy(msgs []LLMMessage) []fantasy.Message {
// LLMMessage is just an alias for fantasy.Message, so we can type convert
// convertToLLMMessages converts kit LLM messages to the underlying provider
// message type. Since LLMMessage is a type alias, this is a no-op.
func convertToLLMMessages(msgs []LLMMessage) []fantasy.Message {
return msgs
}
+3 -3
View File
@@ -58,7 +58,7 @@ func (m *Kit) ShouldCompact() bool {
// Fall back to text-based heuristic before first turn completes.
messages := m.session.GetMessages()
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
return compaction.ShouldCompact(convertToLLMMessages(messages), info.Limit.Context, reserveTokens)
}
// GetContextStats returns current context usage statistics including
@@ -203,9 +203,9 @@ func (m *Kit) compactInternal(ctx context.Context, opts *CompactionOptions, cust
// custom summary. It still determines the cut point and persists a
// CompactionEntry.
func (m *Kit) applyCustomCompaction(summary string, messages []LLMMessage, opts *CompactionOptions) (*CompactionResult, error) {
originalTokens := compaction.EstimateMessageTokens(convertKitMessagesToFantasy(messages))
originalTokens := compaction.EstimateMessageTokens(convertToLLMMessages(messages))
cutPoint := compaction.FindCutPoint(convertKitMessagesToFantasy(messages), opts.KeepRecentTokens)
cutPoint := compaction.FindCutPoint(convertToLLMMessages(messages), opts.KeepRecentTokens)
if cutPoint == 0 {
cutPoint = len(messages) - 1
if cutPoint < 1 {
+31 -10
View File
@@ -38,20 +38,37 @@ Guidelines:
- Be concise in your responses
- Show file paths clearly when working with files`
// setSDKDefaults registers the same viper defaults that the CLI sets via
// cobra flag bindings. This ensures the SDK behaves identically to the CLI
// even when cobra is not used.
// sdkDefaultMaxTokens is the last-resort ceiling applied when the SDK caller
// has not configured max-tokens via Options, env, config, or a per-model
// default. It matches the CLI's --max-tokens cobra default so SDK and CLI
// callers see the same base value before per-model right-sizing runs.
// It is intentionally applied on the *models.ProviderConfig struct
// (not via viper) so that viper.IsSet("max-tokens") remains false and the
// right-sizing + per-model-default paths continue to work.
const sdkDefaultMaxTokens = 8192
// setSDKDefaults registers viper defaults that match the CLI's cobra flag
// defaults for keys where SetDefault does not interfere with downstream
// viper.IsSet() checks.
//
// Keys that participate in "explicit vs unset" precedence downstream —
// max-tokens, temperature, top-p, top-k, frequency-penalty, presence-penalty,
// thinking-level — are deliberately NOT registered here. viper.SetDefault
// causes viper.IsSet() to return true, which would suppress per-model
// defaults (ApplyModelSettings) and automatic right-sizing (rightSizeMaxTokens)
// for every SDK-created Kit. Those defaults are instead applied:
//
// - max-tokens: as a last-resort struct-level floor (sdkDefaultMaxTokens)
// in kit.New() after BuildProviderConfig returns, when the resolved
// value is still zero.
// - thinking-level: handled implicitly by models.ParseThinkingLevel("")
// which returns models.ThinkingOff.
// - sampling params (temperature, top-p, top-k, frequency/presence-penalty):
// left as nil pointers so provider libraries apply their own defaults.
func setSDKDefaults() {
viper.SetDefault("model", "anthropic/claude-sonnet-4-5-20250929")
viper.SetDefault("system-prompt", defaultSystemPrompt)
viper.SetDefault("max-tokens", 4096)
viper.SetDefault("temperature", 0.7)
viper.SetDefault("top-p", 0.95)
viper.SetDefault("top-k", 40)
viper.SetDefault("frequency-penalty", 0.0)
viper.SetDefault("presence-penalty", 0.0)
viper.SetDefault("stream", true)
viper.SetDefault("thinking-level", "off")
viper.SetDefault("num-gpu-layers", -1)
viper.SetDefault("main-gpu", 0)
}
@@ -102,6 +119,10 @@ func InitConfig(configFile string, debug bool) error {
}
viper.SetEnvPrefix("KIT")
// Map hyphenated config keys (e.g. "max-tokens") to underscored env
// var names (e.g. KIT_MAX_TOKENS). Without this, AutomaticEnv looks
// for KIT_MAX-TOKENS and silently misses valid env overrides.
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
viper.AutomaticEnv()
return nil
}
+472 -3
View File
@@ -23,6 +23,14 @@ const (
EventMessageUpdate EventType = "message_update"
// EventMessageEnd fires when the assistant message is complete.
EventMessageEnd EventType = "message_end"
// EventToolCallStart fires when the LLM begins generating tool call arguments.
// The tool name is known but arguments are still streaming.
EventToolCallStart EventType = "tool_call_start"
// EventToolCallDelta fires for each streamed fragment of tool call arguments.
EventToolCallDelta EventType = "tool_call_delta"
// EventToolCallEnd fires when tool argument streaming is complete, before
// the tool call is parsed and execution begins.
EventToolCallEnd EventType = "tool_call_end"
// EventToolCall fires when a tool call has been parsed and is about to execute.
EventToolCall EventType = "tool_call"
// EventToolExecutionStart fires when a tool begins executing.
@@ -45,9 +53,36 @@ const (
// EventToolOutput fires when a tool produces streaming output chunks.
EventToolOutput EventType = "tool_output"
EventStepUsage EventType = "step_usage"
// EventPasswordPrompt fires when a sudo command needs a password.
EventPasswordPrompt EventType = "password_prompt"
// EventSteerConsumed fires when one or more steering messages have been
// injected into the agent turn via PrepareStep.
EventSteerConsumed EventType = "steer_consumed"
// EventStepStart fires when a new LLM call begins within a turn.
EventStepStart EventType = "step_start"
// EventStepFinish fires when a step completes, providing full step context
// including whether tool calls were made, the finish reason, and usage stats.
EventStepFinish EventType = "step_finish"
// EventTextStart fires when the LLM begins generating text content.
EventTextStart EventType = "text_start"
// EventTextEnd fires when the LLM finishes generating text content.
EventTextEnd EventType = "text_end"
// EventReasoningStart fires when the LLM begins reasoning/thinking.
EventReasoningStart EventType = "reasoning_start"
// EventWarnings fires when the LLM provider returns warnings.
EventWarnings EventType = "warnings"
// EventSource fires when the LLM references a source (e.g. from web search).
EventSource EventType = "source"
// EventStreamFinish fires when a per-step LLM stream completes with
// usage stats and a finish reason.
EventStreamFinish EventType = "stream_finish"
// EventError fires when an agent-level error occurs during streaming.
// This is distinct from TurnEndEvent.Error — it fires at the point of
// failure, before the turn ends.
EventError EventType = "error"
// EventRetry fires when the LLM provider request is retried after a
// transient error.
EventRetry EventType = "retry"
)
// ---------------------------------------------------------------------------
@@ -108,6 +143,38 @@ func parseToolArgs(toolArgs string) map[string]any {
return nil
}
// ---------------------------------------------------------------------------
// Finish reason constants
// ---------------------------------------------------------------------------
// Finish reasons reported by the LLM provider on a completed turn. These
// mirror fantasy.FinishReason string values so comparisons against
// TurnEndEvent.StopReason / TurnResult.StopReason are stable across
// providers.
const (
// FinishReasonStop: the model produced a natural stop (e.g. stop sequence
// or end-of-turn signal).
FinishReasonStop = "stop"
// FinishReasonLength: the model hit the configured max_output_tokens
// budget. The response is truncated. Surface this to the user and
// consider raising --max-tokens / KIT_MAX_TOKENS / modelSettings[...]
// .maxTokens.
FinishReasonLength = "length"
// FinishReasonToolCalls: the model stopped to emit tool calls (normal
// mid-turn state during agentic loops).
FinishReasonToolCalls = "tool-calls"
// FinishReasonContentFilter: the provider's safety filter stopped
// generation.
FinishReasonContentFilter = "content-filter"
// FinishReasonError: the model stopped because of an error.
FinishReasonError = "error"
// FinishReasonOther: provider-specific reason that doesn't map to any of
// the above.
FinishReasonOther = "other"
// FinishReasonUnknown: the provider didn't report a finish reason.
FinishReasonUnknown = "unknown"
)
// ---------------------------------------------------------------------------
// Concrete event structs
// ---------------------------------------------------------------------------
@@ -122,9 +189,13 @@ func (e TurnStartEvent) EventType() EventType { return EventTurnStart }
// TurnEndEvent fires after the agent finishes processing.
type TurnEndEvent struct {
Response string
Error error
StopReason string // "end_turn", "max_tokens", "tool_use", "error", etc.
Response string
Error error
// StopReason is the LLM provider's finish reason for the final step of
// the turn. Compare against the FinishReason* constants — in particular,
// FinishReasonLength indicates the response was truncated because the
// agent hit its max_output_tokens budget.
StopReason string
}
// EventType implements Event.
@@ -178,6 +249,40 @@ type MessageEndEvent struct {
// EventType implements Event.
func (e MessageEndEvent) EventType() EventType { return EventMessageEnd }
// ToolCallStartEvent fires when the LLM begins generating tool call arguments.
// The tool name is known at this point but the full arguments are still being
// streamed. UIs can use this to show a "running" indicator immediately instead
// of waiting for the full argument JSON to finish streaming.
type ToolCallStartEvent struct {
ToolCallID string // Stable ID for correlating tool lifecycle events
ToolName string
ToolKind string // Tool classification: "execute", "edit", "read", "search", "agent"
}
// EventType implements Event.
func (e ToolCallStartEvent) EventType() EventType { return EventToolCallStart }
// ToolCallDeltaEvent fires for each streamed fragment of tool call arguments.
// Useful for live-previewing artifact content as it's generated, or showing a
// progress indicator with byte count.
type ToolCallDeltaEvent struct {
ToolCallID string // Stable ID for correlating tool lifecycle events
Delta string // JSON fragment of tool arguments
}
// EventType implements Event.
func (e ToolCallDeltaEvent) EventType() EventType { return EventToolCallDelta }
// ToolCallEndEvent fires when tool argument streaming is complete, before
// the tool call is parsed and execution begins. UIs can use this to
// transition from an "generating args" state to an "executing" state.
type ToolCallEndEvent struct {
ToolCallID string // Stable ID for correlating tool lifecycle events
}
// EventType implements Event.
func (e ToolCallEndEvent) EventType() EventType { return EventToolCallEnd }
// ToolCallEvent fires when a tool call has been parsed.
type ToolCallEvent struct {
ToolCallID string // Stable ID for correlating tool lifecycle events
@@ -299,6 +404,120 @@ type SteerConsumedEvent struct {
// EventType implements Event.
func (e SteerConsumedEvent) EventType() EventType { return EventSteerConsumed }
// StepStartEvent fires when a new LLM call begins within a multi-step agent turn.
type StepStartEvent struct {
StepNumber int
}
// EventType implements Event.
func (e StepStartEvent) EventType() EventType { return EventStepStart }
// StepFinishEvent fires when a step completes, providing full step context.
// This is a unified event that carries the same data as the existing
// ToolCallContentEvent and StepUsageEvent, plus additional step metadata.
type StepFinishEvent struct {
StepNumber int
HasToolCalls bool
FinishReason string
Usage LLMUsage
}
// EventType implements Event.
func (e StepFinishEvent) EventType() EventType { return EventStepFinish }
// TextStartEvent fires when the LLM begins generating text content.
// Paired with MessageUpdateEvent (deltas) and TextEndEvent.
type TextStartEvent struct {
ID string
}
// EventType implements Event.
func (e TextStartEvent) EventType() EventType { return EventTextStart }
// TextEndEvent fires when the LLM finishes generating text content.
type TextEndEvent struct {
ID string
}
// EventType implements Event.
func (e TextEndEvent) EventType() EventType { return EventTextEnd }
// ReasoningStartEvent fires when the LLM begins reasoning/thinking.
// Paired with ReasoningDeltaEvent (deltas) and ReasoningCompleteEvent.
type ReasoningStartEvent struct {
ID string
}
// EventType implements Event.
func (e ReasoningStartEvent) EventType() EventType { return EventReasoningStart }
// WarningsEvent fires when the LLM provider returns warnings about the request.
type WarningsEvent struct {
Warnings []string
}
// EventType implements Event.
func (e WarningsEvent) EventType() EventType { return EventWarnings }
// SourceEvent fires when the LLM references a source (e.g. from web search tools).
type SourceEvent struct {
SourceType string
ID string
URL string
Title string
}
// EventType implements Event.
func (e SourceEvent) EventType() EventType { return EventSource }
// StreamFinishEvent fires when a per-step LLM stream completes.
// Provides per-stream usage stats and finish reason.
type StreamFinishEvent struct {
Usage LLMUsage
FinishReason string
}
// EventType implements Event.
func (e StreamFinishEvent) EventType() EventType { return EventStreamFinish }
// ErrorEvent fires when an agent-level error occurs during streaming.
// This is distinct from TurnEndEvent.Error — it fires at the point of failure.
type ErrorEvent struct {
Error error
}
// EventType implements Event.
func (e ErrorEvent) EventType() EventType { return EventError }
// RetryEvent fires when the LLM provider request is retried after a transient error.
type RetryEvent struct {
Attempt int
Error error
}
// EventType implements Event.
func (e RetryEvent) EventType() EventType { return EventRetry }
// PasswordPromptEvent fires when a sudo command needs a password.
// The TUI should display a password prompt and send the result back via ResponseCh.
type PasswordPromptEvent struct {
// Prompt is the message to display to the user.
Prompt string
// ResponseCh receives the password from the TUI.
// The TUI must send exactly one value: (password, false) for submit
// or ("", true) for cancel.
ResponseCh chan<- PasswordPromptResponse
}
// PasswordPromptResponse carries the password prompt result.
type PasswordPromptResponse struct {
Password string
Cancelled bool
}
// EventType implements Event.
func (e PasswordPromptEvent) EventType() EventType { return EventPasswordPrompt }
// ---------------------------------------------------------------------------
// EventBus
// ---------------------------------------------------------------------------
@@ -362,6 +581,39 @@ func (m *Kit) OnToolCall(handler func(ToolCallEvent)) func() {
})
}
// OnToolCallStart registers a handler that fires only for ToolCallStartEvent.
// This fires when the LLM begins generating tool call arguments — before the
// full argument JSON is available. Returns an unsubscribe function.
func (m *Kit) OnToolCallStart(handler func(ToolCallStartEvent)) func() {
return m.Subscribe(func(e Event) {
if tcs, ok := e.(ToolCallStartEvent); ok {
handler(tcs)
}
})
}
// OnToolCallDelta registers a handler that fires only for ToolCallDeltaEvent.
// Each delta contains a JSON fragment of tool call arguments as they stream in.
// Returns an unsubscribe function.
func (m *Kit) OnToolCallDelta(handler func(ToolCallDeltaEvent)) func() {
return m.Subscribe(func(e Event) {
if tcd, ok := e.(ToolCallDeltaEvent); ok {
handler(tcd)
}
})
}
// OnToolCallEnd registers a handler that fires only for ToolCallEndEvent.
// This fires when tool argument streaming is complete, before the tool call
// is parsed and execution begins. Returns an unsubscribe function.
func (m *Kit) OnToolCallEnd(handler func(ToolCallEndEvent)) func() {
return m.Subscribe(func(e Event) {
if tce, ok := e.(ToolCallEndEvent); ok {
handler(tce)
}
})
}
// OnToolResult registers a handler that fires only for ToolResultEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolResult(handler func(ToolResultEvent)) func() {
@@ -384,7 +636,16 @@ func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() {
// OnStreaming registers a handler that fires only for MessageUpdateEvent
// (streaming text chunks). Returns an unsubscribe function.
//
// Deprecated: Use OnMessageUpdate instead. OnStreaming will be removed in a
// future release.
func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() {
return m.OnMessageUpdate(handler)
}
// OnMessageUpdate registers a handler that fires only for MessageUpdateEvent
// (streaming text chunks). Returns an unsubscribe function.
func (m *Kit) OnMessageUpdate(handler func(MessageUpdateEvent)) func() {
return m.Subscribe(func(e Event) {
if mu, ok := e.(MessageUpdateEvent); ok {
handler(mu)
@@ -422,6 +683,214 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
})
}
// ---------------------------------------------------------------------------
// Typed subscribers for previously unsubscribed event types
// ---------------------------------------------------------------------------
// OnMessageStart registers a handler that fires only for MessageStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnMessageStart(handler func(MessageStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ms, ok := e.(MessageStartEvent); ok {
handler(ms)
}
})
}
// OnMessageEnd registers a handler that fires only for MessageEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnMessageEnd(handler func(MessageEndEvent)) func() {
return m.Subscribe(func(e Event) {
if me, ok := e.(MessageEndEvent); ok {
handler(me)
}
})
}
// OnReasoningDelta registers a handler that fires only for ReasoningDeltaEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningDelta(handler func(ReasoningDeltaEvent)) func() {
return m.Subscribe(func(e Event) {
if rd, ok := e.(ReasoningDeltaEvent); ok {
handler(rd)
}
})
}
// OnReasoningComplete registers a handler that fires only for ReasoningCompleteEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningComplete(handler func(ReasoningCompleteEvent)) func() {
return m.Subscribe(func(e Event) {
if rc, ok := e.(ReasoningCompleteEvent); ok {
handler(rc)
}
})
}
// OnToolExecutionStart registers a handler that fires only for ToolExecutionStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolExecutionStart(handler func(ToolExecutionStartEvent)) func() {
return m.Subscribe(func(e Event) {
if tes, ok := e.(ToolExecutionStartEvent); ok {
handler(tes)
}
})
}
// OnToolExecutionEnd registers a handler that fires only for ToolExecutionEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolExecutionEnd(handler func(ToolExecutionEndEvent)) func() {
return m.Subscribe(func(e Event) {
if tee, ok := e.(ToolExecutionEndEvent); ok {
handler(tee)
}
})
}
// OnToolCallContent registers a handler that fires only for ToolCallContentEvent.
// Returns an unsubscribe function.
func (m *Kit) OnToolCallContent(handler func(ToolCallContentEvent)) func() {
return m.Subscribe(func(e Event) {
if tcc, ok := e.(ToolCallContentEvent); ok {
handler(tcc)
}
})
}
// OnStepUsage registers a handler that fires only for StepUsageEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepUsage(handler func(StepUsageEvent)) func() {
return m.Subscribe(func(e Event) {
if su, ok := e.(StepUsageEvent); ok {
handler(su)
}
})
}
// OnCompaction registers a handler that fires only for CompactionEvent.
// Returns an unsubscribe function.
func (m *Kit) OnCompaction(handler func(CompactionEvent)) func() {
return m.Subscribe(func(e Event) {
if ce, ok := e.(CompactionEvent); ok {
handler(ce)
}
})
}
// OnSteerConsumed registers a handler that fires only for SteerConsumedEvent.
// Returns an unsubscribe function.
func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
return m.Subscribe(func(e Event) {
if sc, ok := e.(SteerConsumedEvent); ok {
handler(sc)
}
})
}
// ---------------------------------------------------------------------------
// Typed subscribers for new event types
// ---------------------------------------------------------------------------
// OnStepStart registers a handler that fires only for StepStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepStart(handler func(StepStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ss, ok := e.(StepStartEvent); ok {
handler(ss)
}
})
}
// OnStepFinish registers a handler that fires only for StepFinishEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStepFinish(handler func(StepFinishEvent)) func() {
return m.Subscribe(func(e Event) {
if sf, ok := e.(StepFinishEvent); ok {
handler(sf)
}
})
}
// OnTextStart registers a handler that fires only for TextStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTextStart(handler func(TextStartEvent)) func() {
return m.Subscribe(func(e Event) {
if ts, ok := e.(TextStartEvent); ok {
handler(ts)
}
})
}
// OnTextEnd registers a handler that fires only for TextEndEvent.
// Returns an unsubscribe function.
func (m *Kit) OnTextEnd(handler func(TextEndEvent)) func() {
return m.Subscribe(func(e Event) {
if te, ok := e.(TextEndEvent); ok {
handler(te)
}
})
}
// OnReasoningStart registers a handler that fires only for ReasoningStartEvent.
// Returns an unsubscribe function.
func (m *Kit) OnReasoningStart(handler func(ReasoningStartEvent)) func() {
return m.Subscribe(func(e Event) {
if rs, ok := e.(ReasoningStartEvent); ok {
handler(rs)
}
})
}
// OnWarnings registers a handler that fires only for WarningsEvent.
// Returns an unsubscribe function.
func (m *Kit) OnWarnings(handler func(WarningsEvent)) func() {
return m.Subscribe(func(e Event) {
if w, ok := e.(WarningsEvent); ok {
handler(w)
}
})
}
// OnSource registers a handler that fires only for SourceEvent.
// Returns an unsubscribe function.
func (m *Kit) OnSource(handler func(SourceEvent)) func() {
return m.Subscribe(func(e Event) {
if s, ok := e.(SourceEvent); ok {
handler(s)
}
})
}
// OnStreamFinish registers a handler that fires only for StreamFinishEvent.
// Returns an unsubscribe function.
func (m *Kit) OnStreamFinish(handler func(StreamFinishEvent)) func() {
return m.Subscribe(func(e Event) {
if sf, ok := e.(StreamFinishEvent); ok {
handler(sf)
}
})
}
// OnError registers a handler that fires only for ErrorEvent.
// Returns an unsubscribe function.
func (m *Kit) OnError(handler func(ErrorEvent)) func() {
return m.Subscribe(func(e Event) {
if ee, ok := e.(ErrorEvent); ok {
handler(ee)
}
})
}
// OnRetry registers a handler that fires only for RetryEvent.
// Returns an unsubscribe function.
func (m *Kit) OnRetry(handler func(RetryEvent)) func() {
return m.Subscribe(func(e Event) {
if r, ok := e.(RetryEvent); ok {
handler(r)
}
})
}
// ---------------------------------------------------------------------------
// Subagent event subscriptions
// ---------------------------------------------------------------------------
+69
View File
@@ -1,6 +1,7 @@
package kit
import (
"fmt"
"sync"
"sync/atomic"
"testing"
@@ -190,6 +191,74 @@ func TestEventTypes(t *testing.T) {
}
}
// TestNewEventTypes verifies that each new event struct returns the correct EventType.
func TestNewEventTypes(t *testing.T) {
tests := []struct {
event Event
expected EventType
}{
{StepStartEvent{StepNumber: 0}, EventStepStart},
{StepFinishEvent{StepNumber: 1, HasToolCalls: true}, EventStepFinish},
{TextStartEvent{ID: "text-1"}, EventTextStart},
{TextEndEvent{ID: "text-1"}, EventTextEnd},
{ReasoningStartEvent{ID: "reason-1"}, EventReasoningStart},
{WarningsEvent{Warnings: []string{"test"}}, EventWarnings},
{SourceEvent{URL: "https://example.com", Title: "Example"}, EventSource},
{StreamFinishEvent{FinishReason: "stop"}, EventStreamFinish},
{ErrorEvent{Error: fmt.Errorf("test error")}, EventError},
{RetryEvent{Attempt: 1, Error: fmt.Errorf("retry error")}, EventRetry},
{ToolCallStartEvent{}, EventToolCallStart},
{ToolCallDeltaEvent{}, EventToolCallDelta},
{ToolCallEndEvent{}, EventToolCallEnd},
{PasswordPromptEvent{}, EventPasswordPrompt},
}
for _, tt := range tests {
if got := tt.event.EventType(); got != tt.expected {
t.Errorf("%T.EventType() = %q, want %q", tt.event, got, tt.expected)
}
}
}
// TestNewEventEmission verifies that new event types are properly emitted and received.
func TestNewEventEmission(t *testing.T) {
bus := newEventBus()
var received []Event
bus.subscribe(func(e Event) {
received = append(received, e)
})
bus.emit(StepStartEvent{StepNumber: 0})
bus.emit(TextStartEvent{ID: "text-1"})
bus.emit(TextEndEvent{ID: "text-1"})
bus.emit(ReasoningStartEvent{ID: "reason-1"})
bus.emit(WarningsEvent{Warnings: []string{"low confidence"}})
bus.emit(SourceEvent{URL: "https://example.com", Title: "Example"})
bus.emit(StreamFinishEvent{FinishReason: "stop"})
bus.emit(StepFinishEvent{StepNumber: 0, HasToolCalls: false, FinishReason: "stop"})
bus.emit(ErrorEvent{Error: fmt.Errorf("test error")})
bus.emit(RetryEvent{Attempt: 1, Error: fmt.Errorf("retry")})
if len(received) != 10 {
t.Fatalf("expected 10 events, got %d", len(received))
}
// Verify specific event fields
if ss, ok := received[0].(StepStartEvent); !ok || ss.StepNumber != 0 {
t.Errorf("event 0: expected StepStartEvent{StepNumber:0}, got %T %+v", received[0], received[0])
}
if ts, ok := received[1].(TextStartEvent); !ok || ts.ID != "text-1" {
t.Errorf("event 1: expected TextStartEvent{ID:text-1}, got %T %+v", received[1], received[1])
}
if w, ok := received[4].(WarningsEvent); !ok || len(w.Warnings) != 1 || w.Warnings[0] != "low confidence" {
t.Errorf("event 4: expected WarningsEvent with 1 warning, got %T %+v", received[4], received[4])
}
if sf, ok := received[7].(StepFinishEvent); !ok || sf.StepNumber != 0 || sf.HasToolCalls {
t.Errorf("event 7: expected StepFinishEvent{StepNumber:0, HasToolCalls:false}, got %T %+v", received[7], received[7])
}
}
// TestEventBusListenerCanUnsubscribeInCallback verifies that a listener can
// safely call its own unsubscribe function from within the callback.
func TestEventBusListenerCanUnsubscribeInCallback(t *testing.T) {
+162
View File
@@ -100,6 +100,38 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
})
}
// Tool call input streaming events — fire as the LLM generates tool arguments.
if runner.HasHandlers(extensions.ToolCallInputStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallStartEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputStartEvent{
ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName,
ToolKind: ev.ToolKind,
})
}
})
}
if runner.HasHandlers(extensions.ToolCallInputDelta) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallDeltaEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputDeltaEvent{
ToolCallID: ev.ToolCallID,
Delta: ev.Delta,
})
}
})
}
if runner.HasHandlers(extensions.ToolCallInputEnd) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ToolCallEndEvent); ok {
_, _ = runner.Emit(extensions.ToolCallInputEndEvent{
ToolCallID: ev.ToolCallID,
})
}
})
}
if runner.HasHandlers(extensions.AgentEnd) {
m.Subscribe(func(e Event) {
if ev, ok := e.(TurnEndEvent); ok {
@@ -324,4 +356,134 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
return nil
})
}
// --- Step lifecycle observation events ---
if runner.HasHandlers(extensions.StepStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(StepStartEvent); ok {
_, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber})
}
})
}
if runner.HasHandlers(extensions.StepFinish) {
m.Subscribe(func(e Event) {
if ev, ok := e.(StepFinishEvent); ok {
_, _ = runner.Emit(extensions.StepFinishEvent{
StepNumber: ev.StepNumber,
HasToolCalls: ev.HasToolCalls,
FinishReason: ev.FinishReason,
InputTokens: ev.Usage.InputTokens,
OutputTokens: ev.Usage.OutputTokens,
CacheReadTokens: ev.Usage.CacheReadTokens,
CacheWriteTokens: ev.Usage.CacheCreationTokens,
})
}
})
}
if runner.HasHandlers(extensions.ReasoningStart) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ReasoningStartEvent); ok {
_, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID})
}
})
}
if runner.HasHandlers(extensions.Warnings) {
m.Subscribe(func(e Event) {
if ev, ok := e.(WarningsEvent); ok {
_, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings})
}
})
}
if runner.HasHandlers(extensions.Source) {
m.Subscribe(func(e Event) {
if ev, ok := e.(SourceEvent); ok {
_, _ = runner.Emit(extensions.SourceEvent{
SourceType: ev.SourceType,
ID: ev.ID,
URL: ev.URL,
Title: ev.Title,
})
}
})
}
if runner.HasHandlers(extensions.Error) {
m.Subscribe(func(e Event) {
if ev, ok := e.(ErrorEvent); ok {
_, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()})
}
})
}
if runner.HasHandlers(extensions.Retry) {
m.Subscribe(func(e Event) {
if ev, ok := e.(RetryEvent); ok {
_, _ = runner.Emit(extensions.RetryEvent{
Attempt: ev.Attempt,
Error: ev.Error.Error(),
})
}
})
}
// --- PrepareStep hook ---
// Extension PrepareStep → SDK PrepareStep hook.
// Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage.
if runner.HasHandlers(extensions.PrepareStep) {
m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
// Convert LLM message slice to extension ContextMessage slice.
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
for i, msg := range h.Messages {
var sb strings.Builder
for _, part := range msg.Content {
if tp, ok := part.(LLMTextPart); ok {
sb.WriteString(tp.Text)
}
}
extMsgs[i] = extensions.ContextMessage{
Index: i,
Role: string(msg.Role),
Content: sb.String(),
}
}
result, _ := runner.Emit(extensions.PrepareStepEvent{
StepNumber: h.StepNumber,
Messages: extMsgs,
})
r, ok := result.(extensions.PrepareStepResult)
if !ok || r.Messages == nil {
return nil
}
// Rebuild LLM message slice from extension result.
rebuilt := make([]LLMMessage, 0, len(r.Messages))
for _, cm := range r.Messages {
if cm.Index >= 0 && cm.Index < len(h.Messages) {
rebuilt = append(rebuilt, h.Messages[cm.Index])
} else {
role := LLMRoleUser
switch cm.Role {
case "assistant":
role = LLMRoleAssistant
case "system":
role = LLMRoleSystem
case "tool":
role = LLMRoleTool
}
rebuilt = append(rebuilt, LLMMessage{
Role: role,
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
})
}
}
return &PrepareStepResult{Messages: rebuilt}
})
}
}
+48 -11
View File
@@ -5,8 +5,6 @@ import (
"fmt"
"sort"
"sync"
"charm.land/fantasy"
)
// ---------------------------------------------------------------------------
@@ -121,6 +119,32 @@ type BeforeCompactResult struct {
Summary string
}
// PrepareStepHook is the input for hooks that fire between steps within a
// multi-step agent turn, with full message replacement capability. This is
// the most powerful interception point — it fires after the existing steering
// logic (if any) and before the messages are sent to the LLM.
//
// Use cases:
// - Transforming tool results (e.g. converting image tool results to FilePart
// user messages for vision models that don't support media in tool results)
// - Dynamic tool filtering per step
// - Mid-turn context injection beyond simple steering
// - Custom stop conditions that inspect message history
type PrepareStepHook struct {
// StepNumber is the zero-based step index within the current turn.
StepNumber int
// Messages is the current context window that will be sent to the LLM.
// This includes any steering messages already injected in this step.
Messages []LLMMessage
}
// PrepareStepResult can replace the context window between steps.
type PrepareStepResult struct {
// Messages replaces the entire context window for this step. If nil,
// the original messages (including any steering) are used unchanged.
Messages []LLMMessage
}
// ---------------------------------------------------------------------------
// Generic hook registry with priority ordering
// ---------------------------------------------------------------------------
@@ -248,6 +272,19 @@ func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeC
return m.beforeCompact.register(p, h)
}
// OnPrepareStep registers a hook that fires between steps within a multi-step
// agent turn, after steering messages are injected and before the messages are
// sent to the LLM. Return a non-nil PrepareStepResult with Messages to replace
// the entire context window for this step. Hooks execute in priority order;
// the first non-nil result wins. Returns an unregister function.
//
// This is the most powerful interception point in the agent lifecycle. It
// enables patterns like transforming tool results, dynamic tool filtering,
// and mid-turn context injection.
func (m *Kit) OnPrepareStep(p HookPriority, h func(PrepareStepHook) *PrepareStepResult) func() {
return m.prepareStep.register(p, h)
}
// ---------------------------------------------------------------------------
// Tool wrapping via hooks
// ---------------------------------------------------------------------------
@@ -256,16 +293,16 @@ func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeC
// AfterToolResult hooks around each execution. The registries are referenced
// by pointer so hooks added after agent creation are still invoked.
type hookedTool struct {
inner fantasy.AgentTool
inner Tool
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
}
func (h *hookedTool) Info() fantasy.ToolInfo { return h.inner.Info() }
func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions { return h.inner.ProviderOptions() }
func (h *hookedTool) SetProviderOptions(o fantasy.ProviderOptions) { h.inner.SetProviderOptions(o) }
func (h *hookedTool) Info() LLMToolInfo { return h.inner.Info() }
func (h *hookedTool) ProviderOptions() LLMProviderOptions { return h.inner.ProviderOptions() }
func (h *hookedTool) SetProviderOptions(o LLMProviderOptions) { h.inner.SetProviderOptions(o) }
func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
func (h *hookedTool) Run(ctx context.Context, call LLMToolCall) (LLMToolResponse, error) {
toolName := h.inner.Info().Name
// 1. BeforeToolCall — can block execution.
@@ -279,7 +316,7 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
if reason == "" {
reason = "blocked by hook"
}
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
return newLLMTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
fmt.Errorf("tool blocked by hook: %s", reason)
}
}
@@ -314,9 +351,9 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
func hookToolWrapper(
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult],
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult],
) func([]fantasy.AgentTool) []fantasy.AgentTool {
return func(tools []fantasy.AgentTool) []fantasy.AgentTool {
wrapped := make([]fantasy.AgentTool, len(tools))
) func([]Tool) []Tool {
return func(tools []Tool) []Tool {
wrapped := make([]Tool, len(tools))
for i, tool := range tools {
wrapped[i] = &hookedTool{
inner: tool,
+98 -28
View File
@@ -5,8 +5,6 @@ import (
"fmt"
"sync"
"testing"
"charm.land/fantasy"
)
// ---------------------------------------------------------------------------
@@ -177,20 +175,20 @@ func TestHookRegistry_ConcurrentAccess(t *testing.T) {
// mockAgentTool implements the AgentTool interface for testing.
type mockAgentTool struct {
name string
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
popts fantasy.ProviderOptions
runFn func(ctx context.Context, call LLMToolCall) (LLMToolResponse, error)
popts LLMProviderOptions
}
func (m *mockAgentTool) Info() fantasy.ToolInfo {
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
func (m *mockAgentTool) Info() LLMToolInfo {
return LLMToolInfo{Name: m.name, Description: "mock tool"}
}
func (m *mockAgentTool) ProviderOptions() fantasy.ProviderOptions { return m.popts }
func (m *mockAgentTool) SetProviderOptions(o fantasy.ProviderOptions) { m.popts = o }
func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
func (m *mockAgentTool) ProviderOptions() LLMProviderOptions { return m.popts }
func (m *mockAgentTool) SetProviderOptions(o LLMProviderOptions) { m.popts = o }
func (m *mockAgentTool) Run(ctx context.Context, call LLMToolCall) (LLMToolResponse, error) {
if m.runFn != nil {
return m.runFn(ctx, call)
}
return fantasy.NewTextResponse("default output"), nil
return newLLMTextResponse("default output"), nil
}
// newEmptyHookedTool creates a hookedTool with empty hook registries and the given mock tool.
@@ -203,14 +201,14 @@ func newEmptyHookedTool(mock *mockAgentTool) *hookedTool {
func TestHookedTool_Passthrough(t *testing.T) {
mock := &mockAgentTool{
name: "test_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("hello world"), nil
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
return newLLMTextResponse("hello world"), nil
},
}
ht := newEmptyHookedTool(mock)
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -226,9 +224,9 @@ func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
toolRan := false
mock := &mockAgentTool{
name: "dangerous_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
toolRan = true
return fantasy.NewTextResponse("should not run"), nil
return newLLMTextResponse("should not run"), nil
},
}
@@ -241,7 +239,7 @@ func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
if err == nil {
t.Fatal("expected error from blocked tool")
}
@@ -263,7 +261,7 @@ func TestHookedTool_BeforeToolCallBlockDefaultReason(t *testing.T) {
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, _ := ht.Run(context.Background(), fantasy.ToolCall{})
resp, _ := ht.Run(context.Background(), LLMToolCall{})
if resp.Content != "Error: blocked by hook" {
t.Errorf("expected default block reason, got %q", resp.Content)
}
@@ -275,8 +273,8 @@ func TestHookedTool_AfterToolResultModify(t *testing.T) {
mock := &mockAgentTool{
name: "tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("secret data"), nil
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
return newLLMTextResponse("secret data"), nil
},
}
@@ -286,7 +284,7 @@ func TestHookedTool_AfterToolResultModify(t *testing.T) {
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -301,8 +299,8 @@ func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
mock := &mockAgentTool{
name: "tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("ok"), nil
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
return newLLMTextResponse("ok"), nil
},
}
@@ -312,7 +310,7 @@ func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
resp, err := ht.Run(context.Background(), fantasy.ToolCall{})
resp, err := ht.Run(context.Background(), LLMToolCall{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -327,8 +325,8 @@ func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
mock := &mockAgentTool{
name: "my_tool",
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.NewTextResponse("result"), nil
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
return newLLMTextResponse("result"), nil
},
}
@@ -345,7 +343,7 @@ func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
})
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
_, _ = ht.Run(context.Background(), fantasy.ToolCall{Input: `{"key":"value"}`})
_, _ = ht.Run(context.Background(), LLMToolCall{Input: `{"key":"value"}`})
if capturedBefore.ToolName != "my_tool" {
t.Errorf("BeforeToolCall: expected tool name 'my_tool', got %q", capturedBefore.ToolName)
@@ -380,7 +378,7 @@ func TestHookToolWrapper(t *testing.T) {
wrapper := hookToolWrapper(before, after)
tools := []fantasy.AgentTool{
tools := []Tool{
&mockAgentTool{name: "tool_a"},
&mockAgentTool{name: "tool_b"},
}
@@ -407,7 +405,7 @@ func TestHookToolWrapper(t *testing.T) {
return &BeforeToolCallResult{Block: true, Reason: "late hook"}
})
_, err := wrapped[0].Run(context.Background(), fantasy.ToolCall{})
_, err := wrapped[0].Run(context.Background(), LLMToolCall{})
if err == nil {
t.Error("expected error from late-registered blocking hook")
}
@@ -538,3 +536,75 @@ func TestKit_HookMethodsExist(t *testing.T) {
u3()
u4()
}
// TestPrepareStepHookRegistry verifies registration and execution of PrepareStep hooks.
func TestPrepareStepHookRegistry(t *testing.T) {
hr := newHookRegistry[PrepareStepHook, PrepareStepResult]()
// Register a hook that appends a message.
hr.register(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
if h.StepNumber == 0 {
// On step 0, prepend a system message.
newMsgs := make([]LLMMessage, 0, len(h.Messages)+1)
newMsgs = append(newMsgs, NewLLMSystemMessage("injected"))
newMsgs = append(newMsgs, h.Messages...)
return &PrepareStepResult{Messages: newMsgs}
}
return nil // No modification for other steps.
})
// Test step 0 — should modify messages.
input := PrepareStepHook{
StepNumber: 0,
Messages: []LLMMessage{NewLLMUserMessage("hello")},
}
result := hr.run(input)
if result == nil {
t.Fatal("expected non-nil result for step 0")
}
if len(result.Messages) != 2 {
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
}
if result.Messages[0].Role != LLMRoleSystem {
t.Errorf("expected system message first, got role %q", result.Messages[0].Role)
}
// Test step 1 — should return nil (no modification).
input.StepNumber = 1
result = hr.run(input)
if result != nil {
t.Errorf("expected nil result for step 1, got %+v", result)
}
}
// TestPrepareStepHookPriority verifies that PrepareStep hooks respect priority ordering.
func TestPrepareStepHookPriority(t *testing.T) {
hr := newHookRegistry[PrepareStepHook, PrepareStepResult]()
var order []string
// Low priority — should run second.
hr.register(HookPriorityLow, func(_ PrepareStepHook) *PrepareStepResult {
order = append(order, "low")
return nil
})
// High priority — should run first and win.
hr.register(HookPriorityHigh, func(h PrepareStepHook) *PrepareStepResult {
order = append(order, "high")
return &PrepareStepResult{Messages: h.Messages}
})
input := PrepareStepHook{
StepNumber: 0,
Messages: []LLMMessage{NewLLMUserMessage("test")},
}
result := hr.run(input)
if result == nil {
t.Fatal("expected non-nil result")
}
if len(order) != 1 || order[0] != "high" {
t.Errorf("expected [high] (first non-nil wins), got %v", order)
}
}

Some files were not shown because too many files have changed in this diff Show More