Compare commits

...

24 Commits

Author SHA1 Message Date
Ed Zynda a322dfc59a fix(ui): eliminate mouse copy-selection drift during streaming
- Lock viewport scroll while a drag-select is active so highlighted
  content stays under the cursor (SetItems, appendStreamingChunk,
  MouseWheelDown all now honor IsMouseDown).
- HandleMouseDrag defensively clears autoScroll on every update so a
  racy re-enable can't shift the row mid-drag.
- Recompute scrollback yOffset/viewport height on each mouse event
  via currentScrollbackBounds() instead of relying on stale values
  cached during the previous View() pass.
- Account for canceling/ctrlCPressedOnce warning rows in
  distributeHeight and mark layoutDirty when those flags toggle so
  the height budget and mouse origin stay in sync.
- Add ScrollList regression tests covering the three invariants.
2026-05-15 13:30:57 +03:00
Ed Zynda b1387d837e feat(ui): add /copy slash command to copy last message
- Register /copy (alias /cp) in the System command category
- Walk the scrollback to find the last user/assistant/reasoning
  message, skipping transient system messages
- Reuse internal/ui/clipboard.CopyToClipboard for OSC 52 + native
  clipboard support (works over SSH)
- Document the command in /help
2026-05-15 13:06:35 +03:00
Ed Zynda f561f4cfd9 fix(session): order kept messages before post-compact branch in BuildContext
After /compact, BuildContext emitted [summary, post-compact, kept]
which placed an older kept user/assistant turn after the latest
post-compaction turn. This broke user/assistant alternation and caused
the model to respond as if the post-compaction turn never happened on
the next user message.

- Emit kept messages chronologically before post-compaction messages
- Mirror the same order in GetContextEntryIDs so cut-point to entry-ID
  mapping stays aligned across repeat compactions
- Update TestCompactionWithNewMessagesAfterCompaction to assert the
  correct chronological order
2026-05-14 20:42:20 +03:00
Ed Zynda 64caed57d4 fix(sdk): stop leaking fantasy types through pkg/kit.AgentConfig (#30) (#32)
* fix(sdk): stop leaking fantasy types through pkg/kit.AgentConfig (#30)

Replace the alias-based AgentConfig and handler types with SDK-owned
structs and function types. CoreTools / ExtraTools / ToolWrapper now
accept []kit.Tool, and the handler types (ToolCallHandler,
ToolExecutionHandler, ToolResultHandler, ResponseHandler,
StreamingResponseHandler, ToolCallContentHandler) plus SpinnerFunc are
declared in pkg/kit/ with signatures that reference only SDK types.

Consumers no longer need to import charm.land/fantasy to populate an
AgentConfig or assign a handler. go doc pkg/kit AgentConfig output no
longer mentions fantasy.*.

- Add unexported (*AgentConfig).toInternal() to convert at the SDK
  boundary; Tool is still an alias for the underlying tool type, so
  slice and function fields convert without allocation.
- Add agent_config_internal_test.go covering nil receiver, scalar
  fields, tool slices, ToolWrapper invocation, OnMCPServerLoaded, and
  auth/token-factory wiring.
- Add types_test.go cases that populate AgentConfig and SpinnerFunc
  without importing fantasy -- the file compiling is the regression
  proof for the leak.
- Update pkg/kit/README.md Re-exported Types section to record that
  AgentConfig and the handler types are now Kit-owned.

Fixes #30

* fix(sdk): add DebugLogger and MCPTaskConfig to kit.AgentConfig (#30)

The first revision of the SDK-owned AgentConfig dropped two fields that
internal/agent.AgentConfig carried: DebugLogger (tools.DebugLogger) and
MCPTaskConfig (tools.MCPTaskConfig). Restore them with SDK-owned
equivalents and wire them through toInternal().

- Add kit.DebugLogger interface (LogDebug / IsDebugEnabled) mirroring
  tools.DebugLogger. Interface-to-interface assignment is automatic
  because the method sets match.
- Add kit.MCPTaskConfig struct mirroring tools.MCPTaskConfig with SDK
  types (MCPTaskMode, MCPTaskProgressHandler) and a toToolsConfig()
  helper that converts at the SDK boundary.
- Wire both new fields in (*AgentConfig).toInternal().
- Extend agent_config_internal_test.go with cases for both fields.
- Document the additions in pkg/kit/README.md.
2026-05-13 21:10:28 +03:00
Ed Zynda 975c30a773 fix(mcp): surface MCP tool failures as soft errors, not critical aborts (#31)
The MCP adapter previously wrapped any error returned by MCPToolManager.ExecuteTool
into a Go error returned from the fantasy.AgentTool.Run interface. The fantasy
agent loop treats those as critical errors and aborts the entire turn —
discarding all prior reasoning, tool calls, and results.

In practice that meant a single misbehaved MCP server returning a JSON-RPC
"-32602 Invalid params" (e.g. a Zod schema mismatch on the server's input
validation) would kill an in-progress turn after the model had already done
dozens of seconds of useful work, with no way for the model to see the
validation message and self-correct.

This mismatched the contract that native Kit tools follow: native tools
return errors via kit.ErrorResult(...), which become soft tool-result errors
that the model reads and can act on (retry with corrected args, try a
different tool, give up gracefully).

Make the MCP path behave the same way:

  - JSON-RPC protocol errors, transport failures, and server-side schema
    rejections are now returned as fantasy.NewTextErrorResponse(...) with
    err == nil, so the agent loop continues and the model sees the failure
    in-band as a tool result it can reason about.
  - Context cancellation (ctx.Err() != nil) remains a critical error so
    callers can abort turns deterministically. This is the only case where
    bubbling up is correct — the caller intentionally tore the turn down
    and the agent must not keep spinning.
  - Server-side soft errors (CallToolResult{ isError: true }) and the
    happy path are unchanged.

The agent loop's MaxSteps cap already bounds the worst case for a
permanently broken MCP server, so there is no risk of unbounded retries.

Side effect: extracted a tiny mcpExecutor interface for the one method the
adapter uses (ExecuteTool), purely so the adapter is unit-testable in
isolation without standing up a full MCPToolManager + connection pool.

Behavior change note for downstream consumers: code that relied on
host.PromptResult / Stream returning a Go error containing
"mcp tool execution failed" will no longer see those errors — the
failure information is now in the assistant's final response (or in the
OnAfterToolResult / OnToolResult hooks, where IsError will be true).
Context cancellation continues to surface as an error from those calls
as before.

Co-authored-by: space_cowboy <space_cowboy@mark3labs.com>
2026-05-13 20:12:31 +03:00
Ed Zynda 35b9360d64 feat(ui): autocomplete /skill:<name> slash commands
- register loaded skills into the input autocomplete under category
  "Skills" with HasArgs so Enter populates "/skill:name " instead of
  auto-submitting, leaving room for trailing args
- prefix descriptions with [project] or [user] to disambiguate
  colliding skill names across sources
- extend refreshSkillItems to prune & re-add Skills entries on
  ContentReloadEvent, matching the pattern used for prompt templates
  and MCP prompts
- add Description field to ui.SkillItem and populate it from
  kit.Skill.Description in both initial build and hot-reload paths
2026-05-13 15:35:07 +03:00
Ed Zynda 1b8373e133 cleanup 2026-05-12 13:30:30 +03:00
Ed Zynda 1a5e4ce7c5 Merge pull request #29 from mark3labs/fix/27-queued-messages-after-compact
test(app): cover steer-drain branch of releaseBusyAfterCompact
2026-05-08 13:11:45 +03:00
Ed Zynda 8823977612 test(app): cover steer-drain branch of releaseBusyAfterCompact
- Add unexported steerDrainFn test seam on App so unit tests can
  inject fake steer items without standing up a full *kit.Kit
  (Options.Kit is a concrete struct, not an interface).
- releaseBusyAfterCompact now prefers the seam over Kit.DrainSteer
  via a small switch; production behaviour is unchanged when the
  field is nil.
- Add TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue, which
  pre-populates both fake steer items and ordinary queue prompts,
  invokes releaseBusyAfterCompact, and asserts the first dispatched
  prompt is the steer item — proving steer messages retain 'act now'
  priority and that drainQueue is actually launched (the bug from
  #27).
2026-05-08 12:18:52 +03:00
Ed Zynda 24e2ea111c Merge pull request #28 from mark3labs/fix/27-queued-messages-after-compact
fix(app): flush queued messages after /compact completes (#27)
2026-05-08 12:16:28 +03:00
Ed Zynda 31ea80ec4f fix(app): flush queued messages after /compact completes (#27)
- Add releaseBusyAfterCompact() shared deferred tail used by both
  CompactConversation and CompactAsync. It drains the SDK steer
  channel, splices steer items in front of any queued prompts, and
  hands off to drainQueue so messages received during compaction
  are dispatched automatically once compaction finishes.
- Previously, busy was simply cleared on completion and the queue
  sat idle until the user submitted another prompt, which then
  flushed everything together.
- Honor the closed flag so a teardown during compaction discards
  pending items instead of spawning drainQueue against a torn-down
  App.
- Add regression tests covering the queued-flush, idle-empty, and
  closed-during-compact paths.

Fixes #27
2026-05-08 11:30:26 +03:00
Ed Zynda 99f2680c2e Merge pull request #26 from mark3labs/fix/25-system-prompt-file-path
fix(kit): resolve system-prompt file path before PromptBuilder (#25)
2026-05-08 10:54:09 +03:00
Ed Zynda da7e05eb87 fix(cmd): nil-guard CLI when emitting system-prompt notice in quiet mode
SetupCLIForNonInteractive returns nil when --quiet is active, matching
the pre-existing nil checks elsewhere in the same block (e.g. the
buffered debug-message branch). Without this guard the new
'System Prompt loaded' notice panicked on quiet, non-interactive runs.

Discovered via tmux smoke test of the #25 fix.
2026-05-08 10:44:01 +03:00
Ed Zynda a95714a22d fix(kit): resolve system-prompt file path before PromptBuilder (#25)
When system-prompt was a file path (via --system-prompt, config entry,
or SDK Options.SystemPrompt), the path string itself was used as the
base prompt because config.LoadSystemPrompt only ran later in
BuildProviderConfig — by which point viper had been overwritten with
the path-augmented composed text. The LLM received the path instead of
the prompt contents.

- Call config.LoadSystemPrompt on the raw viper value in New() before
  PromptBuilder composes runtime context (AGENTS.md / skills / date).
- Add HasCustomSystemPrompt() and GetSystemPromptSource() so SDK callers
  can inspect prompt state without reaching into viper.
- Display 'System Prompt loaded: <source>' at startup in CLI and TUI
  modes, paralleling the per-server 'MCP server loaded' notice.
- Add regression tests covering both file-path and inline prompt paths.

Fixes #25
2026-05-08 10:39:14 +03:00
Ed Zynda c4a2b0f1a3 Merge pull request #24 from mark3labs/audit-cleanup
refactor: remove dead code and consolidate duplicated extension wiring
2026-05-07 17:46:49 +03:00
Ed Zynda 2016570e2d test: add docstrings to rewritten tests and use t.Setenv
Addresses two CodeRabbit feedback items on PR #24:

* Docstring coverage warning (was 57.14%, threshold 80%): adds godoc
  comments to the four test functions added or substantially rewritten
  in this PR — TestLoadAndSaveManifest, TestAddAndRemoveFromManifest,
  TestFindInManifest, TestHighlightFileTokensInjectsANSI.
* Quick-win nitpick: replaces the manual os.Setenv/os.Unsetenv +
  defer pattern in TestFindInManifest with t.Setenv, which restores
  the env var automatically on cleanup even on panic or t.Fatal.

go test -race ./... still passes.
2026-05-07 13:16:03 +03:00
Ed Zynda d557f4b870 fix(cmd): wrap bare fn refs in extensions.Context as closures
Per AGENTS.md 'Yaegi function field bug', named function/method
references assigned to extensions.Context fields return zero values
across the interpreter boundary. The two SetContext literals in
runNormalMode (now consolidated in buildInteractiveExtensionContext)
inherited 9 bare references that need to be anonymous closure literals:

  PrintBlock, GetChildren, GetAvailableSkills, ParseTemplate,
  RenderTemplate, ParseArguments, SimpleParseArguments,
  ResolveModelChain, CheckModelAvailable

Each is now wrapped as 'func(args) ret { return <orig>(args) }'.
Behaviour unchanged in regular Go; Yaegi extensions that consume these
fields will now see callable closures instead of zero values.

Verified with go test -race ./...
2026-05-07 13:00:06 +03:00
Ed Zynda 65054fe3db gofmt trailing-blank-line cleanup after dead-code removal 2026-05-07 12:34:29 +03:00
Ed Zynda 97d2246375 drop orphan testTypography helper from render tests
The TestUserBlockHighlightsFileTokens test was rewritten to call
HighlightFileTokens directly (UserBlock was deleted in the dead-code
sweep). That left testTypography with no callers, so staticcheck U1000
flagged it.
2026-05-07 12:31:55 +03:00
Ed Zynda 1e12505741 remove unused style.BaseStyle helper 2026-05-07 12:29:59 +03:00
Ed Zynda 6755597c9b extract buildInteractiveExtensionContext helper
The previous runNormalMode contained two nearly-identical 400-line
extensions.Context literal expressions:

  * the startup-time literal (cmd/root.go:853-1307) that buffered
    Print* calls into startupExtensionMessages
  * the runtime literal (cmd/root.go:1311-1605) that routed Print*
    through appInstance.PrintFromExtension

Every other field — Compact, SendMultimodalMessage, the four prompt
factories, all 25+ data-access fields, all four bridge phases — was
duplicated byte-for-byte. Maintainers had to remember to update both
copies whenever an extension Context field was added.

cmd/root.go is now 1463 lines (was 2225). The new helper lives in
cmd/extension_context.go (455 lines, mostly the closures verbatim) and
returns an extensions.Context with every field populated except
Print/PrintInfo/PrintError, which each call site sets afterwards to
match its phase. This preserves AGENTS.md's 'function field bug'
guarantee — all assignments remain anonymous closure literals.

Output of 'kit --version' / 'kit --help' unchanged. Full test suite
passes.
2026-05-07 12:28:18 +03:00
Ed Zynda 45689cb30d extract duplicated subagent + event conversion to internal/extbridge
The same ~40-line block — building a kit.SubagentConfig, wrapping
OnEvent through sdkEventToSubagentEvent, calling kitInstance.Subagent,
and translating the SDK result into extensions.SubagentResult — was
copy-pasted three times:

  * cmd/root.go (interactive TUI Context, line 1148)
  * cmd/root.go (post-SessionStart runtime Context, line 1446)
  * internal/acpserver/session.go (ACP server Context, line 154)

A separate sdkEventToSubagentEvent function was duplicated byte-for-byte
between cmd/root.go and internal/acpserver/session.go.

Both are now consolidated in a new internal/extbridge package which is
the only module-internal home that can legitimately import both
pkg/kit/ (the public SDK) and internal/extensions/. cmd/ and
internal/acpserver/ both import it, so SDK-event-to-extension-event
schema changes only have one site to update.

Also fixes pkg/kit/events.go godoc comment that named the underlying
LLM library, per AGENTS.md 'No Dependency Name Leakage' rule for
exported SDK symbols.

go test -race ./... passes.
2026-05-07 12:23:15 +03:00
Ed Zynda 78570d4188 remove dead code identified by audit
Removes ~600 lines of unreferenced code surfaced by deadcode + manual
audit (none of it reachable from production code paths or test setup):

- internal/models/pool.go: ProviderPool was never wired into kitsetup
  or the agent; the global pool singleton had zero callers.
- internal/ui/debug_logger.go: CLIDebugLogger was unreachable; debug
  routing goes through internal/tools/buffered_logger.go instead.
- internal/ui/tool_approval_input.go: tea.Model never instantiated;
  approvals are handled inline in model.go.
- internal/ui/cli.go: DisplayAssistantMessage / DisplayCancellation /
  GetDebugLogger had zero callers (the *WithModel variant is what
  event_handler.go uses).
- internal/ui/style/enhanced.go: Style{Card,Header,Subheader,Muted,
  Success,Error,Warning,Info} + Create{Separator,ProgressBar} — none
  used. CreateBadge stays (used by model.go).
- internal/ui/style/themes.go: RefreshThemeRegistry — never called.
- internal/ui/block_renderer.go: With{FullWidth,MarginTop,Padding{Left,
  Right},Background,Foreground,Width} — option helpers nobody calls.
- internal/ui/render/blocks.go: UserBlock, ToolBlock — replaced by
  inline rendering elsewhere; the test for UserBlock was rewritten to
  directly exercise HighlightFileTokens (which is what the test really
  cared about).
- internal/ui/commands/commands.go: GetAllCommandNames — no callers.
- internal/ui/message_items.go: NewTextMessageItem,
  NewSystemMessageItem + the entire SystemMessageItem type — model.go
  uses NewStyledMessageItem instead.
- internal/prompts/loader.go: Deduplicate — the loader does dedup
  internally; standalone helper was unused.
- internal/models/cache_options.go: mergeProviderOptions + its
  test-only consumer.
- internal/extensions/installer.go: Installer.GetInstalledPackages —
  intended for a 'kit ext list' command that was never built.
- internal/extensions/manifest.go: saveManifestToScope,
  saveManifestToPath, GetGlobalManifest, GetProjectManifest,
  addEntryToManifest, removeEntryFromManifest — package-level
  duplicates of *Installer methods. Tests rewritten to exercise the
  live Installer methods instead, which fixes a latent path-resolution
  inconsistency between manifestPathForScope and Installer.manifestPath
  (the former hard-coded paths, the latter respects projectGitRoot).
- internal/extensions/subagent.go: SpawnSubagent + helpers
  (generateSubagentID, findKitBinary, subagentJSONOutput). The
  subprocess-spawn implementation is unreachable; production code
  routes through kit.Kit.Subagent (in-process). Types
  (SubagentConfig/Result/Handle/etc.) and the SubagentHandle methods
  remain because they are exposed to extensions via Yaegi symbols and
  the Context.SpawnSubagent field.
- cmd/root.go: LoadConfigWithEnvSubstitution — one-line wrapper around
  kit.LoadConfigWithEnvSubstitution with zero callers.

go test -race ./... passes.
2026-05-07 12:20:08 +03:00
Ed Zynda 7cf38b37ee Merge pull request #23 from mark3labs/fix/18-windows-session-dir-colon
fix(session): strip illegal characters from windows session dir (#18)
2026-05-07 11:13:34 +03:00
42 changed files with 2398 additions and 2218 deletions
+146
View File
@@ -0,0 +1,146 @@
---
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
---
Perform a comprehensive **read-only** audit of this repository and report
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
hints from the user: $@
## Scope
If the user supplied focus hints above (a package path, a subsystem name, a
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
audit the whole repo, prioritising the highest-traffic packages first
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
## Steps
1. **Map the repo first**:
- `ls` / `find` the top-level layout and list every Go package
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
extension surface)
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
(e.g. "no dependency name leakage", "UI never imports extensions
directly") — these define what counts as a violation
2. **Hunt for dead code**:
- Run `go vet ./...` and capture warnings
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
with zero non-test references inside the module are suspects
- Check for unreferenced files, `// TODO: remove` markers, commented-out
blocks, and `_ = x` discard patterns
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
them and include their output verbatim
- **Do not delete anything** — list candidates with file:line and a
confidence level (high / medium / low)
3. **Find unnecessary duplication**:
- Look for near-identical function bodies, struct shapes, or switch
statements across packages — `grep` for repeated function signatures
and copy-pasted string literals / error messages is a fast first pass
- Distinguish *coincidental* duplication (two things that happen to look
alike but evolve independently) from *unnecessary* duplication (same
intent, drifting in lockstep) — only flag the latter
- For each cluster, propose where the extracted helper should live
(which package, which file) and whether it crosses a boundary
4. **Check concerns / boundary violations**:
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
in exported signatures, and for dependency-name leakage in exported
names / godoc (e.g. library jargon appearing in `LLM*` types)
- **UI ↔ extensions**: grep `internal/ui/` for any import of
`internal/extensions/` — per AGENTS.md the UI must not import
extensions directly; converters in `cmd/root.go` should bridge them
- **cmd vs internal**: business logic living in `cmd/` that should be
in `internal/` (and vice versa)
- **Cyclic risk**: packages that import each other transitively or that
reach across sibling boundaries unexpectedly
- For each violation, cite the offending import / signature with
file:line
5. **Spot refactor opportunities**:
- Long functions (>80 lines) doing multiple unrelated things
- Deeply nested conditionals that flatten well with early returns
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
that could become helpers — but only where the wrapping context is
genuinely uniform
- Structs with too many fields that hint at split responsibilities
- Exported APIs that would be cleaner with options structs / functional
options
- Tests that share setup boilerplate ripe for a helper
- Flag each with: location, current shape (1-2 lines), proposed shape
(1-2 lines), and estimated risk (low / medium / high)
6. **Cross-check against project rules**:
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
findings contradicts the documented gotchas (Yaegi interface ban,
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
"refactor" would reintroduce a known pitfall, drop it from the report
and note why
7. **Write the report** as your final message (do not write it to disk)
structured as:
```
# Code Audit Report
## Summary
- N dead-code candidates
- N duplication clusters
- N boundary violations
- N refactor opportunities
## Dead Code
### High confidence
- path/to/file.go:LINE — symbol — reason
### Medium confidence
...
## Duplication
### Cluster: <short name>
- Sites: file:line, file:line, …
- Suggested home: package/path
- Notes: …
## Boundary Violations
- Rule: <which rule from AGENTS.md / project convention>
- Offender: file:line
- Fix sketch: …
## Refactor Opportunities
- Location: file:line
- Current: …
- Proposed: …
- Risk: low/medium/high
- Why it's worth it: …
## Suggested Next Steps
1. …
2. …
```
8. **End the report with an explicit reminder** that no files were modified,
and recommend the user pick the highest-leverage items to act on
manually (or via a follow-up `/fix-issue` style prompt) rather than
running a sweeping refactor.
## Guidelines
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
straight to it
- **Be honest about confidence**: false positives in a code audit are
expensive — prefer "medium confidence, worth a look" over confidently
wrong claims
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
anything that's purely stylistic unless it directly causes one of the
four issue categories above
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
`vendor/`) and obvious third-party copies
- **Don't propose architectural rewrites** — stay within the existing
shape of the repo and recommend incremental, reviewable changes
+473
View File
@@ -0,0 +1,473 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
"github.com/spf13/viper"
"golang.org/x/term"
"github.com/mark3labs/kit/internal/app"
"github.com/mark3labs/kit/internal/auth"
"github.com/mark3labs/kit/internal/extbridge"
"github.com/mark3labs/kit/internal/extensions"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/ui"
kit "github.com/mark3labs/kit/pkg/kit"
)
// extensionContextDeps groups the runtime dependencies needed to wire up
// an extensions.Context for the interactive TUI mode.
type extensionContextDeps struct {
ctx context.Context
cwd string
modelName string
interactive bool
kitInstance *kit.Kit
appInstance *app.App
usageTracker *ui.UsageTracker
}
// buildInteractiveExtensionContext returns an extensions.Context with every
// field except Print / PrintInfo / PrintError populated. Callers must set
// the three print routes appropriately for their phase (startup buffering
// vs. live runtime routing).
//
// This consolidates two near-identical 400-line literal expressions that
// previously appeared inline in runNormalMode.
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
kitInstance := deps.kitInstance
appInstance := deps.appInstance
usageTracker := deps.usageTracker
ctx := deps.ctx
return extensions.Context{
CWD: deps.cwd,
Model: deps.modelName,
Interactive: deps.interactive,
PrintBlock: func(opts extensions.PrintBlockOpts) {
appInstance.PrintBlockFromExtension(opts)
},
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveWidget: func(id string) {
kitInstance.Extensions().RemoveWidget(id)
go appInstance.NotifyWidgetUpdate()
},
SetHeader: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetHeader(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveHeader: func() {
kitInstance.Extensions().RemoveHeader()
go appInstance.NotifyWidgetUpdate()
},
SetFooter: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetFooter(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveFooter: func() {
kitInstance.Extensions().RemoveFooter()
go appInstance.NotifyWidgetUpdate()
},
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "select",
Message: config.Message,
Options: config.Options,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptSelectResult{Cancelled: true}
}
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
},
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
ch := make(chan app.PromptResponse, 1)
def := "false"
if config.DefaultValue {
def = "true"
}
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "confirm",
Message: config.Message,
Default: def,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptConfirmResult{Cancelled: true}
}
return extensions.PromptConfirmResult{Value: resp.Confirmed}
},
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "input",
Message: config.Message,
Placeholder: config.Placeholder,
Default: config.Default,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptInputResult{Cancelled: true}
}
return extensions.PromptInputResult{Value: resp.Value}
},
SetUIVisibility: func(v extensions.UIVisibility) {
kitInstance.Extensions().SetUIVisibility(v)
go appInstance.NotifyWidgetUpdate()
},
GetContextStats: func() extensions.ContextStats {
s := kitInstance.GetContextStats()
return extensions.ContextStats{
EstimatedTokens: s.EstimatedTokens,
ContextLimit: s.ContextLimit,
UsagePercent: s.UsagePercent,
MessageCount: s.MessageCount,
}
},
SetEditor: func(config extensions.EditorConfig) {
kitInstance.Extensions().SetEditor(config)
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
// deadlocks if called synchronously from inside BubbleTea's
// Update() handler. All call sites use go-routines uniformly.
go appInstance.NotifyWidgetUpdate()
},
ResetEditor: func() {
kitInstance.Extensions().ResetEditor()
go appInstance.NotifyWidgetUpdate()
},
GetMessages: func() []extensions.SessionMessage {
return kitInstance.Extensions().GetSessionMessages()
},
GetSessionPath: func() string {
return kitInstance.GetSessionPath()
},
AppendEntry: func(entryType string, data string) (string, error) {
return kitInstance.Extensions().AppendEntry(entryType, data)
},
GetEntries: func(entryType string) []extensions.ExtensionEntry {
return kitInstance.Extensions().GetEntries(entryType)
},
SetEditorText: func(text string) {
appInstance.SetEditorTextFromExtension(text)
},
SetStatus: func(key string, text string, priority int) {
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
Key: key,
Text: text,
Priority: priority,
})
go appInstance.NotifyWidgetUpdate()
},
RemoveStatus: func(key string) {
kitInstance.Extensions().RemoveStatus(key)
go appInstance.NotifyWidgetUpdate()
},
GetOption: func(name string) string {
return kitInstance.Extensions().GetOption(name)
},
SetOption: func(name string, value string) {
kitInstance.Extensions().SetOption(name, value)
},
SetModel: func(modelString string) error {
// Capture previous model for the ModelChange event.
previousModel := kitInstance.Extensions().GetContext().Model
err := kitInstance.SetModel(context.Background(), modelString)
if err != nil {
return err
}
// Notify TUI so it updates model in status bar.
p, m, _ := models.ParseModelString(modelString)
appInstance.NotifyModelChanged(p, m)
// Update the context's Model field so handlers see it.
kitInstance.Extensions().UpdateContextModel(modelString)
// Fire OnModelChange event to extensions.
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
// Update usage tracker with new model info for correct token counting.
if usageTracker != nil {
newProvider, newModel, _ := models.ParseModelString(modelString)
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
registry := models.GetGlobalRegistry()
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
// Check OAuth status for Anthropic models
isOAuth := false
if newProvider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
}
}
}
return nil
},
GetAvailableModels: func() []extensions.ModelInfoEntry {
return kitInstance.GetAvailableModels()
},
EmitCustomEvent: func(name string, data string) {
kitInstance.Extensions().EmitCustomEvent(name, data)
},
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SuspendTUI: func(callback func()) error {
return appInstance.SuspendTUI(callback)
},
RenderMessage: func(rendererName, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
if renderer == nil || renderer.Render == nil {
appInstance.PrintFromExtension("", content)
return
}
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
if w == 0 {
w = 80
}
rendered := renderer.Render(content, w)
appInstance.PrintFromExtension("", rendered)
},
ReloadExtensions: func() error {
err := kitInstance.Extensions().Reload()
if err != nil {
return err
}
// Notify TUI that widgets/status/commands may have changed.
go appInstance.NotifyWidgetUpdate()
return nil
},
GetAllTools: func() []extensions.ToolInfo {
return kitInstance.Extensions().GetToolInfos()
},
SetActiveTools: func(names []string) {
kitInstance.Extensions().SetActiveTools(names)
},
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
ui.RegisterThemeFromConfig(name,
tc(config.Primary), tc(config.Secondary),
tc(config.Success), tc(config.Warning),
tc(config.Error), tc(config.Info),
tc(config.Text), tc(config.Muted),
tc(config.VeryMuted), tc(config.Background),
tc(config.Border), tc(config.MutedBorder),
tc(config.System), tc(config.Tool),
tc(config.Accent), tc(config.Highlight),
tc(config.MdHeading), tc(config.MdLink),
tc(config.MdKeyword), tc(config.MdString),
tc(config.MdNumber), tc(config.MdComment),
)
},
SetTheme: func(name string) error {
return ui.ApplyTheme(name)
},
ListThemes: func() []string {
return ui.ListThemes()
},
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
ch := make(chan app.OverlayResponse, 1)
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
Title: config.Title,
Content: config.Content.Text,
Markdown: config.Content.Markdown,
BorderColor: config.Style.BorderColor,
Background: config.Style.Background,
Width: config.Width,
MaxHeight: config.MaxHeight,
Anchor: string(config.Anchor),
Actions: config.Actions,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
return extensions.OverlayResult{
Action: resp.Action,
Index: resp.Index,
}
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
return extbridge.SpawnSubagent(ctx, kitInstance, config)
},
// -------------------------------------------------------------------
// Tree Navigation API
// -------------------------------------------------------------------
GetTreeNode: func(entryID string) *extensions.TreeNode {
node := kitInstance.GetTreeNode(entryID)
if node == nil {
return nil
}
return &extensions.TreeNode{
ID: node.ID,
ParentID: node.ParentID,
Type: node.Type,
Role: node.Role,
Content: node.Content,
Model: node.Model,
Provider: node.Provider,
Timestamp: node.Timestamp,
Children: node.Children,
}
},
GetCurrentBranch: func() []extensions.TreeNode {
nodes := kitInstance.GetCurrentBranch()
result := make([]extensions.TreeNode, len(nodes))
for i, n := range nodes {
result[i] = extensions.TreeNode{
ID: n.ID,
ParentID: n.ParentID,
Type: n.Type,
Role: n.Role,
Content: n.Content,
Model: n.Model,
Provider: n.Provider,
Timestamp: n.Timestamp,
Children: n.Children,
}
}
return result
},
GetChildren: func(parentID string) []string {
return kitInstance.GetChildren(parentID)
},
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
err := kitInstance.NavigateTo(entryID)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
SummarizeBranch: func(fromID, toID string) string {
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
return summary
},
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
err := kitInstance.CollapseBranch(fromID, toID, summary)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
// -------------------------------------------------------------------
// Skill Loading API
// -------------------------------------------------------------------
LoadSkill: func(path string) (*extensions.Skill, string) {
s, err := kitInstance.LoadSkillForExtension(path)
return s, err
},
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
return kitInstance.LoadSkillsFromDirForExtension(dir)
},
DiscoverSkills: func() extensions.SkillLoadResult {
skills := kitInstance.DiscoverSkillsForExtension()
return extensions.SkillLoadResult{Skills: skills}
},
InjectSkillAsContext: func(skillName string) string {
skills := kitInstance.DiscoverSkillsForExtension()
for _, s := range skills {
if s.Name == skillName {
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
}
return fmt.Sprintf("skill not found: %s", skillName)
},
InjectRawSkillAsContext: func(path string) string {
s, err := kitInstance.LoadSkillForExtension(path)
if err != "" {
return err
}
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
},
GetAvailableSkills: func() []extensions.Skill {
return kitInstance.DiscoverSkillsForExtension()
},
// -------------------------------------------------------------------
// Template Parsing API
// -------------------------------------------------------------------
ParseTemplate: func(name, content string) extensions.PromptTemplate {
return kit.ParseTemplate(name, content)
},
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
return kit.RenderTemplate(tpl, vars)
},
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
return kit.ParseArguments(input, pattern)
},
SimpleParseArguments: func(input string, count int) []string {
return kit.SimpleParseArguments(input, count)
},
EvaluateModelConditional: func(condition string) bool {
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
},
RenderWithModelConditionals: func(content string) string {
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
},
// -------------------------------------------------------------------
// Model Resolution API
// -------------------------------------------------------------------
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
return kit.ResolveModelChain(preferences)
},
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
return kit.GetModelCapabilities(model)
},
CheckModelAvailable: func(model string) bool {
return kit.CheckModelAvailable(model)
},
GetCurrentProvider: func() string {
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
},
GetCurrentModelID: func() string {
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
},
}
}
+55 -799
View File
@@ -169,12 +169,6 @@ func InitConfig() {
models.ReloadGlobalRegistry()
}
// LoadConfigWithEnvSubstitution loads a config file with environment variable
// substitution. Delegates to the SDK implementation.
func LoadConfigWithEnvSubstitution(configPath string) error {
return kit.LoadConfigWithEnvSubstitution(configPath)
}
// adaptiveOrDefault converts a config.AdaptiveColor to a resolved color.Color,
// falling back to fallback when both Light and Dark are empty.
func adaptiveOrDefault(ac config.AdaptiveColor, fallback color.Color) color.Color {
@@ -790,6 +784,16 @@ func runNormalMode(ctx context.Context) error {
}
defer func() { _ = kitInstance.Close() }()
// Build the "System Prompt loaded" notice shown at startup, paralleling the
// per-server "MCP server loaded" notifications so users can confirm that a
// configured prompt file was found and applied.
var systemPromptLoadedMsg string
if kitInstance.HasCustomSystemPrompt() {
if src := kitInstance.GetSystemPromptSource(); src != "" {
systemPromptLoadedMsg = "System Prompt loaded: " + src
}
}
// Extract metadata for display and app options.
parsedProvider, modelName, serverNames, toolNames, mcpToolCount, extensionToolCount := CollectAgentMetadata(kitInstance, mcpConfig)
@@ -807,6 +811,9 @@ func runNormalMode(ctx context.Context) error {
}
DisplayDebugConfig(cli, kitInstance, mcpConfig, parsedProvider)
if systemPromptLoadedMsg != "" && cli != nil {
cli.DisplayInfo(systemPromptLoadedMsg)
}
}
// Load existing messages from resumed/continued sessions.
@@ -846,763 +853,49 @@ func runNormalMode(ctx context.Context) error {
// Buffer for extension messages during startup (printed after startup banner).
var startupExtensionMessages []string
if systemPromptLoadedMsg != "" {
startupExtensionMessages = append(startupExtensionMessages, systemPromptLoadedMsg)
}
// Set up extension context and emit SessionStart.
if kitInstance.Extensions().HasExtensions() {
cwd, _ := os.Getwd()
kitInstance.Extensions().SetContext(extensions.Context{
CWD: cwd,
Model: modelName,
Interactive: positionalPrompt == "",
Print: func(text string) {
// Capture messages during startup, print after startup banner.
startupExtensionMessages = append(startupExtensionMessages, text)
},
PrintInfo: func(text string) {
startupExtensionMessages = append(startupExtensionMessages, text)
},
PrintError: func(text string) {
startupExtensionMessages = append(startupExtensionMessages, text)
},
PrintBlock: appInstance.PrintBlockFromExtension,
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveWidget: func(id string) {
kitInstance.Extensions().RemoveWidget(id)
go appInstance.NotifyWidgetUpdate()
},
SetHeader: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetHeader(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveHeader: func() {
kitInstance.Extensions().RemoveHeader()
go appInstance.NotifyWidgetUpdate()
},
SetFooter: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetFooter(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveFooter: func() {
kitInstance.Extensions().RemoveFooter()
go appInstance.NotifyWidgetUpdate()
},
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "select",
Message: config.Message,
Options: config.Options,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptSelectResult{Cancelled: true}
}
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
},
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
ch := make(chan app.PromptResponse, 1)
def := "false"
if config.DefaultValue {
def = "true"
}
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "confirm",
Message: config.Message,
Default: def,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptConfirmResult{Cancelled: true}
}
return extensions.PromptConfirmResult{Value: resp.Confirmed}
},
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "input",
Message: config.Message,
Placeholder: config.Placeholder,
Default: config.Default,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptInputResult{Cancelled: true}
}
return extensions.PromptInputResult{Value: resp.Value}
},
SetUIVisibility: func(v extensions.UIVisibility) {
kitInstance.Extensions().SetUIVisibility(v)
go appInstance.NotifyWidgetUpdate()
},
GetContextStats: func() extensions.ContextStats {
s := kitInstance.GetContextStats()
return extensions.ContextStats{
EstimatedTokens: s.EstimatedTokens,
ContextLimit: s.ContextLimit,
UsagePercent: s.UsagePercent,
MessageCount: s.MessageCount,
}
},
SetEditor: func(config extensions.EditorConfig) {
kitInstance.Extensions().SetEditor(config)
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
// deadlocks if called synchronously from inside BubbleTea's
// Update() handler. All call sites use go-routines uniformly.
go appInstance.NotifyWidgetUpdate()
},
ResetEditor: func() {
kitInstance.Extensions().ResetEditor()
go appInstance.NotifyWidgetUpdate()
},
GetMessages: func() []extensions.SessionMessage {
return kitInstance.Extensions().GetSessionMessages()
},
GetSessionPath: func() string {
return kitInstance.GetSessionPath()
},
AppendEntry: func(entryType string, data string) (string, error) {
return kitInstance.Extensions().AppendEntry(entryType, data)
},
GetEntries: func(entryType string) []extensions.ExtensionEntry {
return kitInstance.Extensions().GetEntries(entryType)
},
SetEditorText: func(text string) {
appInstance.SetEditorTextFromExtension(text)
},
SetStatus: func(key string, text string, priority int) {
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
Key: key,
Text: text,
Priority: priority,
})
go appInstance.NotifyWidgetUpdate()
},
RemoveStatus: func(key string) {
kitInstance.Extensions().RemoveStatus(key)
go appInstance.NotifyWidgetUpdate()
},
GetOption: func(name string) string {
return kitInstance.Extensions().GetOption(name)
},
SetOption: func(name string, value string) {
kitInstance.Extensions().SetOption(name, value)
},
SetModel: func(modelString string) error {
// Capture previous model for the ModelChange event.
previousModel := kitInstance.Extensions().GetContext().Model
err := kitInstance.SetModel(context.Background(), modelString)
if err != nil {
return err
}
// Notify TUI so it updates model in status bar.
p, m, _ := models.ParseModelString(modelString)
appInstance.NotifyModelChanged(p, m)
// Update the context's Model field so handlers see it.
kitInstance.Extensions().UpdateContextModel(modelString)
// Fire OnModelChange event to extensions.
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
// Update usage tracker with new model info for correct token counting.
if usageTracker != nil {
newProvider, newModel, _ := models.ParseModelString(modelString)
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
registry := models.GetGlobalRegistry()
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
// Check OAuth status for Anthropic models
isOAuth := false
if newProvider == "anthropic" {
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
if err == nil && strings.HasPrefix(source, "stored OAuth") {
isOAuth = true
}
}
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
}
}
}
return nil
},
GetAvailableModels: func() []extensions.ModelInfoEntry {
return kitInstance.GetAvailableModels()
},
EmitCustomEvent: func(name string, data string) {
kitInstance.Extensions().EmitCustomEvent(name, data)
},
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SuspendTUI: func(callback func()) error {
return appInstance.SuspendTUI(callback)
},
RenderMessage: func(rendererName, content string) {
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
if renderer == nil || renderer.Render == nil {
appInstance.PrintFromExtension("", content)
return
}
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
if w == 0 {
w = 80
}
rendered := renderer.Render(content, w)
appInstance.PrintFromExtension("", rendered)
},
ReloadExtensions: func() error {
err := kitInstance.Extensions().Reload()
if err != nil {
return err
}
// Notify TUI that widgets/status/commands may have changed.
go appInstance.NotifyWidgetUpdate()
return nil
},
GetAllTools: func() []extensions.ToolInfo {
return kitInstance.Extensions().GetToolInfos()
},
SetActiveTools: func(names []string) {
kitInstance.Extensions().SetActiveTools(names)
},
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
ui.RegisterThemeFromConfig(name,
tc(config.Primary), tc(config.Secondary),
tc(config.Success), tc(config.Warning),
tc(config.Error), tc(config.Info),
tc(config.Text), tc(config.Muted),
tc(config.VeryMuted), tc(config.Background),
tc(config.Border), tc(config.MutedBorder),
tc(config.System), tc(config.Tool),
tc(config.Accent), tc(config.Highlight),
tc(config.MdHeading), tc(config.MdLink),
tc(config.MdKeyword), tc(config.MdString),
tc(config.MdNumber), tc(config.MdComment),
)
},
SetTheme: func(name string) error {
return ui.ApplyTheme(name)
},
ListThemes: func() []string {
return ui.ListThemes()
},
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
ch := make(chan app.OverlayResponse, 1)
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
Title: config.Title,
Content: config.Content.Text,
Markdown: config.Content.Markdown,
BorderColor: config.Style.BorderColor,
Background: config.Style.Background,
Width: config.Width,
MaxHeight: config.MaxHeight,
Anchor: string(config.Anchor),
Actions: config.Actions,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
return extensions.OverlayResult{
Action: resp.Action,
Index: resp.Index,
}
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
// In-process subagent via SDK.
sdkCfg := kit.SubagentConfig{
Prompt: config.Prompt,
Model: config.Model,
SystemPrompt: config.SystemPrompt,
Timeout: config.Timeout,
NoSession: config.NoSession,
}
// Bridge SDK events to extension SubagentEvents.
if config.OnEvent != nil {
sdkCfg.OnEvent = func(e kit.Event) {
se := sdkEventToSubagentEvent(e)
if se.Type != "" {
config.OnEvent(se)
}
}
}
result, err := kitInstance.Subagent(ctx, sdkCfg)
if result == nil {
return nil, &extensions.SubagentResult{Error: err}, err
}
extResult := &extensions.SubagentResult{
Response: result.Response,
Error: err,
SessionID: result.SessionID,
Elapsed: result.Elapsed,
}
if result.Usage != nil {
extResult.Usage = &extensions.SubagentUsage{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
}
}
return nil, extResult, err
},
// -------------------------------------------------------------------------
// Tree Navigation API (Phase 1 Bridge)
// -------------------------------------------------------------------------
GetTreeNode: func(entryID string) *extensions.TreeNode {
node := kitInstance.GetTreeNode(entryID)
if node == nil {
return nil
}
return &extensions.TreeNode{
ID: node.ID,
ParentID: node.ParentID,
Type: node.Type,
Role: node.Role,
Content: node.Content,
Model: node.Model,
Provider: node.Provider,
Timestamp: node.Timestamp,
Children: node.Children,
}
},
GetCurrentBranch: func() []extensions.TreeNode {
nodes := kitInstance.GetCurrentBranch()
result := make([]extensions.TreeNode, len(nodes))
for i, n := range nodes {
result[i] = extensions.TreeNode{
ID: n.ID,
ParentID: n.ParentID,
Type: n.Type,
Role: n.Role,
Content: n.Content,
Model: n.Model,
Provider: n.Provider,
Timestamp: n.Timestamp,
Children: n.Children,
}
}
return result
},
GetChildren: kitInstance.GetChildren,
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
err := kitInstance.NavigateTo(entryID)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
SummarizeBranch: func(fromID, toID string) string {
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
return summary
},
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
err := kitInstance.CollapseBranch(fromID, toID, summary)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
// -------------------------------------------------------------------------
// Skill Loading API (Phase 2 Bridge)
// -------------------------------------------------------------------------
LoadSkill: func(path string) (*extensions.Skill, string) {
s, err := kitInstance.LoadSkillForExtension(path)
return s, err
},
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
return kitInstance.LoadSkillsFromDirForExtension(dir)
},
DiscoverSkills: func() extensions.SkillLoadResult {
skills := kitInstance.DiscoverSkillsForExtension()
return extensions.SkillLoadResult{Skills: skills}
},
InjectSkillAsContext: func(skillName string) string {
// Find skill by name
skills := kitInstance.DiscoverSkillsForExtension()
for _, s := range skills {
if s.Name == skillName {
// Inject via SendMessage as a system context message
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
}
return fmt.Sprintf("skill not found: %s", skillName)
},
InjectRawSkillAsContext: func(path string) string {
s, err := kitInstance.LoadSkillForExtension(path)
if err != "" {
return err
}
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
},
GetAvailableSkills: kitInstance.DiscoverSkillsForExtension,
// -------------------------------------------------------------------------
// Template Parsing API (Phase 3 Bridge)
// -------------------------------------------------------------------------
ParseTemplate: kit.ParseTemplate,
RenderTemplate: kit.RenderTemplate,
ParseArguments: kit.ParseArguments,
SimpleParseArguments: kit.SimpleParseArguments,
EvaluateModelConditional: func(condition string) bool {
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
},
RenderWithModelConditionals: func(content string) string {
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
},
// -------------------------------------------------------------------------
// Model Resolution API (Phase 4 Bridge)
// -------------------------------------------------------------------------
ResolveModelChain: kit.ResolveModelChain,
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
return kit.GetModelCapabilities(model)
},
CheckModelAvailable: kit.CheckModelAvailable,
GetCurrentProvider: func() string {
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
},
GetCurrentModelID: func() string {
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
},
extCtx := buildInteractiveExtensionContext(extensionContextDeps{
ctx: ctx,
cwd: cwd,
modelName: modelName,
interactive: positionalPrompt == "",
kitInstance: kitInstance,
appInstance: appInstance,
usageTracker: usageTracker,
})
extCtx.Print = func(text string) {
// Capture messages during startup, print after startup banner.
startupExtensionMessages = append(startupExtensionMessages, text)
}
extCtx.PrintInfo = func(text string) {
startupExtensionMessages = append(startupExtensionMessages, text)
}
extCtx.PrintError = func(text string) {
startupExtensionMessages = append(startupExtensionMessages, text)
}
kitInstance.Extensions().SetContext(extCtx)
kitInstance.Extensions().EmitSessionStart()
// Restore normal print functions for runtime use.
kitInstance.Extensions().SetContext(extensions.Context{
CWD: cwd,
Model: modelName,
Interactive: positionalPrompt == "",
Print: func(text string) { appInstance.PrintFromExtension("", text) },
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
PrintBlock: appInstance.PrintBlockFromExtension,
SendMessage: func(text string) { appInstance.Run(text) },
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
Abort: func() { appInstance.Abort() },
IsIdle: func() bool { return !appInstance.IsBusy() },
Compact: func(cfg extensions.CompactConfig) error {
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
},
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
parts := make([]kit.LLMFilePart, len(files))
for i, f := range files {
parts[i] = kit.LLMFilePart{
Filename: f.Filename,
Data: f.Data,
MediaType: f.MediaType,
}
}
appInstance.RunWithFiles(text, parts)
},
GetSessionUsage: func() extensions.SessionUsage {
if usageTracker == nil {
return extensions.SessionUsage{}
}
stats := usageTracker.GetSessionStats()
return extensions.SessionUsage{
TotalInputTokens: stats.TotalInputTokens,
TotalOutputTokens: stats.TotalOutputTokens,
TotalCacheReadTokens: stats.TotalCacheReadTokens,
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
TotalCost: stats.TotalCost,
RequestCount: stats.RequestCount,
}
},
Exit: func() { appInstance.QuitFromExtension() },
SetWidget: func(config extensions.WidgetConfig) {
kitInstance.Extensions().SetWidget(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveWidget: func(id string) {
kitInstance.Extensions().RemoveWidget(id)
go appInstance.NotifyWidgetUpdate()
},
SetHeader: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetHeader(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveHeader: func() {
kitInstance.Extensions().RemoveHeader()
go appInstance.NotifyWidgetUpdate()
},
SetFooter: func(config extensions.HeaderFooterConfig) {
kitInstance.Extensions().SetFooter(config)
go appInstance.NotifyWidgetUpdate()
},
RemoveFooter: func() {
kitInstance.Extensions().RemoveFooter()
go appInstance.NotifyWidgetUpdate()
},
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "select",
Message: config.Message,
Options: config.Options,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptSelectResult{Cancelled: true}
}
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
},
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
ch := make(chan app.PromptResponse, 1)
def := "false"
if config.DefaultValue {
def = "true"
}
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "confirm",
Message: config.Message,
Default: def,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptConfirmResult{Cancelled: true}
}
return extensions.PromptConfirmResult{Value: resp.Confirmed}
},
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
ch := make(chan app.PromptResponse, 1)
appInstance.SendPromptRequest(app.PromptRequestEvent{
PromptType: "input",
Message: config.Message,
Placeholder: config.Placeholder,
Default: config.Default,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.PromptInputResult{Cancelled: true}
}
return extensions.PromptInputResult{Value: resp.Value}
},
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
ch := make(chan app.OverlayResponse, 1)
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
Title: config.Title,
Content: config.Content.Text,
Markdown: config.Content.Markdown,
BorderColor: config.Style.BorderColor,
Background: config.Style.Background,
Width: config.Width,
MaxHeight: config.MaxHeight,
Anchor: string(config.Anchor),
Actions: config.Actions,
ResponseCh: ch,
})
resp := <-ch
if resp.Cancelled {
return extensions.OverlayResult{Cancelled: true, Index: -1}
}
return extensions.OverlayResult{
Action: resp.Action,
Index: resp.Index,
}
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
// In-process subagent via SDK.
sdkCfg := kit.SubagentConfig{
Prompt: config.Prompt,
Model: config.Model,
SystemPrompt: config.SystemPrompt,
Timeout: config.Timeout,
NoSession: config.NoSession,
}
// Bridge SDK events to extension SubagentEvents.
if config.OnEvent != nil {
sdkCfg.OnEvent = func(e kit.Event) {
se := sdkEventToSubagentEvent(e)
if se.Type != "" {
config.OnEvent(se)
}
}
}
result, err := kitInstance.Subagent(ctx, sdkCfg)
if result == nil {
return nil, &extensions.SubagentResult{Error: err}, err
}
extResult := &extensions.SubagentResult{
Response: result.Response,
Error: err,
SessionID: result.SessionID,
Elapsed: result.Elapsed,
}
if result.Usage != nil {
extResult.Usage = &extensions.SubagentUsage{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
}
}
return nil, extResult, err
},
// -------------------------------------------------------------------------
// Tree Navigation API (Phase 1 Bridge) - Second Context
// -------------------------------------------------------------------------
GetTreeNode: func(entryID string) *extensions.TreeNode {
node := kitInstance.GetTreeNode(entryID)
if node == nil {
return nil
}
return &extensions.TreeNode{
ID: node.ID,
ParentID: node.ParentID,
Type: node.Type,
Role: node.Role,
Content: node.Content,
Model: node.Model,
Provider: node.Provider,
Timestamp: node.Timestamp,
Children: node.Children,
}
},
GetCurrentBranch: func() []extensions.TreeNode {
nodes := kitInstance.GetCurrentBranch()
result := make([]extensions.TreeNode, len(nodes))
for i, n := range nodes {
result[i] = extensions.TreeNode{
ID: n.ID,
ParentID: n.ParentID,
Type: n.Type,
Role: n.Role,
Content: n.Content,
Model: n.Model,
Provider: n.Provider,
Timestamp: n.Timestamp,
Children: n.Children,
}
}
return result
},
GetChildren: kitInstance.GetChildren,
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
err := kitInstance.NavigateTo(entryID)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
SummarizeBranch: func(fromID, toID string) string {
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
return summary
},
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
err := kitInstance.CollapseBranch(fromID, toID, summary)
if err != nil {
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
}
return extensions.TreeNavigationResult{Success: true}
},
// -------------------------------------------------------------------------
// Skill Loading API (Phase 2 Bridge) - Second Context
// -------------------------------------------------------------------------
LoadSkill: func(path string) (*extensions.Skill, string) {
s, err := kitInstance.LoadSkillForExtension(path)
return s, err
},
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
return kitInstance.LoadSkillsFromDirForExtension(dir)
},
DiscoverSkills: func() extensions.SkillLoadResult {
skills := kitInstance.DiscoverSkillsForExtension()
return extensions.SkillLoadResult{Skills: skills}
},
InjectSkillAsContext: func(skillName string) string {
skills := kitInstance.DiscoverSkillsForExtension()
for _, s := range skills {
if s.Name == skillName {
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
}
}
return fmt.Sprintf("skill not found: %s", skillName)
},
InjectRawSkillAsContext: func(path string) string {
s, err := kitInstance.LoadSkillForExtension(path)
if err != "" {
return err
}
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
return ""
},
GetAvailableSkills: func() []extensions.Skill {
return kitInstance.DiscoverSkillsForExtension()
},
// -------------------------------------------------------------------------
// Template Parsing API (Phase 3 Bridge) - Second Context
// -------------------------------------------------------------------------
ParseTemplate: kit.ParseTemplate,
RenderTemplate: kit.RenderTemplate,
ParseArguments: kit.ParseArguments,
SimpleParseArguments: kit.SimpleParseArguments,
EvaluateModelConditional: func(condition string) bool {
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
},
RenderWithModelConditionals: func(content string) string {
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
},
// -------------------------------------------------------------------------
// Model Resolution API (Phase 4 Bridge) - Second Context
// -------------------------------------------------------------------------
ResolveModelChain: kit.ResolveModelChain,
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
return kit.GetModelCapabilities(model)
},
CheckModelAvailable: kit.CheckModelAvailable,
GetCurrentProvider: func() string {
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
},
GetCurrentModelID: func() string {
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
},
extCtx = buildInteractiveExtensionContext(extensionContextDeps{
ctx: ctx,
cwd: cwd,
modelName: modelName,
interactive: positionalPrompt == "",
kitInstance: kitInstance,
appInstance: appInstance,
usageTracker: usageTracker,
})
extCtx.Print = func(text string) { appInstance.PrintFromExtension("", text) }
extCtx.PrintInfo = func(text string) { appInstance.PrintFromExtension("info", text) }
extCtx.PrintError = func(text string) { appInstance.PrintFromExtension("error", text) }
kitInstance.Extensions().SetContext(extCtx)
}
// Convert extension commands to UI-layer type for the interactive TUI.
@@ -1642,9 +935,10 @@ func runNormalMode(ctx context.Context) error {
source = "project"
}
skillItems = append(skillItems, ui.SkillItem{
Name: s.Name,
Path: s.Path,
Source: source,
Name: s.Name,
Path: s.Path,
Source: source,
Description: s.Description,
})
}
@@ -1683,9 +977,10 @@ func runNormalMode(ctx context.Context) error {
source = "project"
}
items = append(items, ui.SkillItem{
Name: s.Name,
Path: s.Path,
Source: source,
Name: s.Name,
Path: s.Path,
Source: source,
Description: s.Description,
})
}
return items
@@ -2184,42 +1479,3 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
_, runErr := program.Run()
return runErr
}
// sdkEventToSubagentEvent converts an SDK event to an extension-facing
// SubagentEvent. Returns a zero-value event (Type=="") for events that
// don't map to anything useful.
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
switch ev := e.(type) {
case kit.MessageUpdateEvent:
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
case kit.ReasoningDeltaEvent:
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
case kit.ToolCallEvent:
return extensions.SubagentEvent{
Type: "tool_call", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
}
case kit.ToolExecutionStartEvent:
return extensions.SubagentEvent{
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolExecutionEndEvent:
return extensions.SubagentEvent{
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolResultEvent:
return extensions.SubagentEvent{
Type: "tool_result", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
ToolResult: ev.Result, IsError: ev.IsError,
}
case kit.TurnStartEvent:
return extensions.SubagentEvent{Type: "turn_start"}
case kit.TurnEndEvent:
return extensions.SubagentEvent{Type: "turn_end"}
default:
return extensions.SubagentEvent{}
}
}
+2 -69
View File
@@ -8,6 +8,7 @@ import (
"github.com/charmbracelet/log"
"github.com/mark3labs/kit/internal/extbridge"
"github.com/mark3labs/kit/internal/extensions"
kit "github.com/mark3labs/kit/pkg/kit"
)
@@ -152,38 +153,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
return kitInstance.ExecuteCompletion(context.Background(), req)
},
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
sdkCfg := kit.SubagentConfig{
Prompt: config.Prompt,
Model: config.Model,
SystemPrompt: config.SystemPrompt,
Timeout: config.Timeout,
NoSession: config.NoSession,
}
if config.OnEvent != nil {
sdkCfg.OnEvent = func(e kit.Event) {
se := sdkEventToSubagentEvent(e)
if se.Type != "" {
config.OnEvent(se)
}
}
}
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
if result == nil {
return nil, &extensions.SubagentResult{Error: err}, err
}
extResult := &extensions.SubagentResult{
Response: result.Response,
Error: err,
SessionID: result.SessionID,
Elapsed: result.Elapsed,
}
if result.Usage != nil {
extResult.Usage = &extensions.SubagentUsage{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
}
}
return nil, extResult, err
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
},
// Render — fall back to logging.
@@ -269,40 +239,3 @@ func (s *acpSession) clearCancel() {
defer s.cancelMu.Unlock()
s.cancelFn = nil
}
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
switch ev := e.(type) {
case kit.MessageUpdateEvent:
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
case kit.ReasoningDeltaEvent:
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
case kit.ToolCallEvent:
return extensions.SubagentEvent{
Type: "tool_call", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
}
case kit.ToolExecutionStartEvent:
return extensions.SubagentEvent{
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolExecutionEndEvent:
return extensions.SubagentEvent{
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolResultEvent:
return extensions.SubagentEvent{
Type: "tool_result", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
ToolResult: ev.Result, IsError: ev.IsError,
}
case kit.TurnStartEvent:
return extensions.SubagentEvent{Type: "turn_start"}
case kit.TurnEndEvent:
return extensions.SubagentEvent{Type: "turn_end"}
default:
return extensions.SubagentEvent{}
}
}
+28 -5
View File
@@ -9,12 +9,19 @@ import (
"github.com/mark3labs/kit/internal/tools"
)
// mcpExecutor is the subset of *tools.MCPToolManager that the adapter
// actually uses. Extracted as an interface so the adapter is unit-testable
// without constructing a full manager + connection pool.
type mcpExecutor interface {
ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error)
}
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
// This keeps the fantasy dependency confined to the agent layer — the tools
// package is a pure MCP client library with no LLM framework dependency.
type mcpAgentTool struct {
tool tools.MCPTool
manager *tools.MCPToolManager
exec mcpExecutor
providerOptions fantasy.ProviderOptions
}
@@ -29,10 +36,26 @@ func (t *mcpAgentTool) Info() fantasy.ToolInfo {
}
// Run executes the MCP tool by delegating to the MCPToolManager.
//
// MCP-side failures (JSON-RPC protocol errors, transport failures, schema
// validation rejections from the server) are surfaced to the model as soft
// tool errors rather than escalated to a critical agent error. This matches
// the contract that native Kit tools follow via kit.ErrorResult(...) and
// lets the model self-correct (e.g. retry with a fixed argument shape) or
// give up gracefully rather than aborting the turn mid-run.
//
// Context cancellation is the one exception: if the caller cancelled the
// context the turn was aborted intentionally, so we propagate the ctx error
// to let the agent loop unwind cleanly.
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
result, err := t.exec.ExecuteTool(ctx, t.tool.Name, call.Input)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
if ctxErr := ctx.Err(); ctxErr != nil {
return fantasy.ToolResponse{}, ctxErr
}
return fantasy.NewTextErrorResponse(
fmt.Sprintf("MCP tool %q failed: %s", t.tool.Name, err.Error()),
), nil
}
if result.IsError {
@@ -57,8 +80,8 @@ func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManage
agentTools := make([]fantasy.AgentTool, len(mcpTools))
for i, t := range mcpTools {
agentTools[i] = &mcpAgentTool{
tool: t,
manager: manager,
tool: t,
exec: manager,
}
}
return agentTools
+158
View File
@@ -0,0 +1,158 @@
package agent
import (
"context"
"errors"
"strings"
"testing"
"time"
"charm.land/fantasy"
"github.com/mark3labs/kit/internal/tools"
)
// stubExecutor lets each test script the (result, err) pair returned by
// ExecuteTool. The adapter holds an mcpExecutor interface, so this is the
// only seam the tests need.
type stubExecutor struct {
result *tools.MCPToolResult
err error
// called records the last invocation for assertion.
called bool
name string
input string
}
func (s *stubExecutor) ExecuteTool(_ context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error) {
s.called = true
s.name = prefixedName
s.input = inputJSON
return s.result, s.err
}
func newMCPAgentTool(exec mcpExecutor, name string) *mcpAgentTool {
return &mcpAgentTool{
tool: tools.MCPTool{Name: name},
exec: exec,
}
}
// Manager-side Go errors (JSON-RPC protocol errors, transport failures,
// schema validation rejections from the MCP server) must be surfaced to
// the model as soft tool errors so the agent loop can keep going. Aborting
// the turn would discard all prior tool results — see issue #N.
func TestMCPAgentTool_RPCErrorBecomesSoftError(t *testing.T) {
exec := &stubExecutor{
err: errors.New("MCP error -32602: Invalid params: missing field \"task\""),
}
tool := newMCPAgentTool(exec, "pubmed__search")
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
ID: "call-1",
Name: "pubmed__search",
Input: `{"query":"foo"}`,
})
if err != nil {
t.Fatalf("expected nil error (soft), got %v", err)
}
if !resp.IsError {
t.Fatalf("expected IsError=true, got false")
}
if !strings.Contains(resp.Content, "pubmed__search") {
t.Errorf("expected tool name in error content, got %q", resp.Content)
}
if !strings.Contains(resp.Content, "-32602") {
t.Errorf("expected underlying error text in content, got %q", resp.Content)
}
}
// Context cancellation is the one error that must remain critical: it
// means the caller intentionally aborted, and the agent loop needs to
// unwind cleanly rather than burning more steps.
func TestMCPAgentTool_CtxCancelStaysCritical(t *testing.T) {
exec := &stubExecutor{
// Real managers typically return ctx.Err() (or a wrapper) when the
// context is cancelled mid-call.
err: context.Canceled,
}
tool := newMCPAgentTool(exec, "slow__tool")
ctx, cancel := context.WithCancel(context.Background())
cancel()
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled, got %v", err)
}
if resp.IsError || resp.Content != "" {
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
}
}
// Deadline-exceeded behaves the same as cancellation: ctx.Err() is
// non-nil, so the adapter must propagate the critical error rather than
// converting the executor's error into a soft response.
func TestMCPAgentTool_CtxDeadlineStaysCritical(t *testing.T) {
exec := &stubExecutor{err: context.DeadlineExceeded}
tool := newMCPAgentTool(exec, "slow__tool")
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
defer cancel()
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
}
if resp.IsError || resp.Content != "" {
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
}
}
// Server-side soft errors (CallToolResult{ isError: true }) must continue
// to flow through as soft errors — this was the existing behavior and
// must not regress.
func TestMCPAgentTool_ServerIsErrorRemainsSoftError(t *testing.T) {
exec := &stubExecutor{
result: &tools.MCPToolResult{
IsError: true,
Content: "search service is rate limited; try again in 30s",
},
}
tool := newMCPAgentTool(exec, "pubmed__search")
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
if err != nil {
t.Fatalf("expected nil error, got %v", err)
}
if !resp.IsError {
t.Fatalf("expected IsError=true, got false")
}
if resp.Content != "search service is rate limited; try again in 30s" {
t.Errorf("expected pass-through content, got %q", resp.Content)
}
}
// Happy path: ordinary successful tool result is passed through unchanged.
func TestMCPAgentTool_SuccessIsPassthrough(t *testing.T) {
exec := &stubExecutor{
result: &tools.MCPToolResult{
IsError: false,
Content: `{"hits":3}`,
},
}
tool := newMCPAgentTool(exec, "pubmed__search")
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if resp.IsError {
t.Fatalf("expected IsError=false")
}
if resp.Content != `{"hits":3}` {
t.Errorf("expected pass-through content, got %q", resp.Content)
}
}
+91 -10
View File
@@ -78,6 +78,13 @@ type App struct {
// (~1 frame) so new updates are always let through once the TUI has had a
// chance to process the pending event.
widgetUpdatePending atomic.Bool
// steerDrainFn is the test seam used by releaseBusyAfterCompact to pull
// any steer messages that arrived during compaction. In production it is
// nil and the helper falls back to a.opts.Kit.DrainSteer(); tests that
// need to exercise the steer-drain path without standing up a full
// *kit.Kit can set this field directly to inject fake items.
steerDrainFn func() []queueItem
}
// New creates a new App with the provided options and pre-loaded messages.
@@ -356,6 +363,10 @@ func (a *App) AddContextMessage(text string) {
// tea.Program. customInstructions is optional text appended to the summary
// prompt (e.g. "Focus on the API design decisions").
//
// Any prompts queued via Run/RunWithFiles or steering messages injected via
// Steer/SteerWithFiles while compaction is running are flushed automatically
// once compaction completes (see releaseBusyAfterCompact).
//
// Satisfies ui.AppController.
func (a *App) CompactConversation(customInstructions string) error {
a.mu.Lock()
@@ -377,11 +388,7 @@ func (a *App) CompactConversation(customInstructions string) error {
go func() {
defer a.wg.Done()
defer func() {
a.mu.Lock()
a.busy = false
a.mu.Unlock()
}()
defer a.releaseBusyAfterCompact()
// Subscribe to SDK events for streaming compaction summary to the TUI.
sendFn := func(msg tea.Msg) {
@@ -420,6 +427,9 @@ func (a *App) CompactConversation(customInstructions string) error {
// CompactAsync is like CompactConversation but calls onComplete/onError
// callbacks instead of sending TUI events. Used by the extension API's
// ctx.Compact() which needs callback-based notification.
//
// Like CompactConversation, any prompts/steer messages received during
// compaction are flushed automatically once compaction finishes.
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
a.mu.Lock()
if a.closed {
@@ -440,11 +450,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
go func() {
defer a.wg.Done()
defer func() {
a.mu.Lock()
a.busy = false
a.mu.Unlock()
}()
defer a.releaseBusyAfterCompact()
// Subscribe to SDK events for streaming compaction summary to the TUI.
sendFn := func(msg tea.Msg) {
@@ -489,6 +495,81 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
return nil
}
// releaseBusyAfterCompact is the deferred tail that runs at the end of every
// compaction goroutine (success, error, or panic-after-recover paths). It
// flips a.busy back to false, but before doing so it checks whether any
// prompts piled up while compaction was running:
//
// - Run/RunWithFiles append to a.queue when a.busy is set.
// - Steer/SteerWithFiles deposit messages into the SDK steer channel via
// Kit.InjectSteerWithFiles when a.busy is set.
//
// Without this hand-off the queue would sit idle until the user submits
// another prompt — see issue #27. If we find anything pending we keep busy
// set, splice the steer messages to the front of the queue, and start a
// fresh drainQueue goroutine to deliver them as a single batched turn.
func (a *App) releaseBusyAfterCompact() {
// Pull steer messages outside the app mutex; DrainSteer takes its own
// internal lock and we don't want to nest the two. The test seam
// (a.steerDrainFn) takes precedence so unit tests can inject fake
// steer items without a real *kit.Kit.
var steerItems []queueItem
switch {
case a.steerDrainFn != nil:
steerItems = a.steerDrainFn()
case a.opts.Kit != nil:
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
steerItems = make([]queueItem, len(leftover))
for i, sm := range leftover {
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
}
}
}
a.mu.Lock()
// If the app was closed while compaction was running, drop everything
// and just clear busy. Run/Steer would have rejected new items already
// after Close(), but this guards against in-flight items that slipped
// in just before closed was set.
if a.closed {
a.queue = a.queue[:0]
a.busy = false
a.mu.Unlock()
return
}
// Combine steer-channel items (front) with the in-memory queue (back).
// Steer messages are placed first so they retain their "act now"
// semantics relative to ordinary queued prompts that arrived later.
pending := append(steerItems, a.queue...)
a.queue = a.queue[:0]
if len(pending) == 0 {
a.busy = false
a.mu.Unlock()
return
}
// Hand off to drainQueue: it will pick up the first item directly and
// scoop the rest from a.queue on its first iteration.
first := pending[0]
if len(pending) > 1 {
a.queue = append(a.queue, pending[1:]...)
}
// Stay busy across the goroutine swap.
a.wg.Add(1)
a.mu.Unlock()
// Notify the UI that steer-channel messages were consumed so the
// steering badge can clear; ordinary queued prompts will be reflected
// by the QueueUpdatedEvent that drainQueue emits as it picks them up.
if len(steerItems) > 0 {
a.sendEvent(SteerConsumedEvent{})
}
go a.drainQueue(first)
}
// --------------------------------------------------------------------------
// Non-interactive execution
// --------------------------------------------------------------------------
+206
View File
@@ -763,3 +763,209 @@ func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
}
}
}
// --------------------------------------------------------------------------
// releaseBusyAfterCompact (issue #27)
// --------------------------------------------------------------------------
// TestReleaseBusyAfterCompact_flushesQueuedMessages is a regression test for
// issue #27: messages queued via Run() while /compact is running used to sit
// in a.queue indefinitely until the user typed another prompt. After the fix
// the deferred releaseBusyAfterCompact tail picks up any pending items and
// dispatches drainQueue automatically.
//
// We simulate the compaction completion path directly (bypassing the SDK)
// by toggling busy=true, populating the queue exactly as Run() would have
// during compaction, and then invoking releaseBusyAfterCompact.
func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
return turnResult("compacted then drained"), nil
},
)
app := newTestApp(stub)
defer app.Close()
// Simulate the state at the start of the compaction tail: busy is set
// and a couple of prompts have piled up in the queue while we were
// summarising. (Run() would have appended them and returned a queue
// length > 0 to the caller.)
app.mu.Lock()
app.busy = true
app.queue = append(app.queue,
queueItem{Prompt: "queued during compact #1"},
queueItem{Prompt: "queued during compact #2"},
)
app.mu.Unlock()
// Invoke the deferred tail directly. It should kick off drainQueue.
app.releaseBusyAfterCompact()
// drainQueue runs in a goroutine. Wait for the app to come back to idle.
ok := waitForCondition(2*time.Second, func() bool {
app.mu.Lock()
defer app.mu.Unlock()
return !app.busy
})
if !ok {
t.Fatal("app did not become idle after releaseBusyAfterCompact: queue not drained")
}
// Wait for any in-flight goroutine to finish before reading state.
app.wg.Wait()
if got := app.QueueLength(); got != 0 {
t.Fatalf("expected empty queue after drain, got %d", got)
}
if n := stub.callCount(); n == 0 {
t.Fatalf("expected stub PromptFunc to fire at least once after compact, got %d calls", n)
}
}
// TestReleaseBusyAfterCompact_idleWhenQueueEmpty verifies that with no
// pending messages the helper just clears busy and does NOT spawn a
// drainQueue goroutine (no spurious agent turn).
func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
stub := newStub()
app := newTestApp(stub)
defer app.Close()
app.mu.Lock()
app.busy = true
app.mu.Unlock()
app.releaseBusyAfterCompact()
app.mu.Lock()
busy := app.busy
app.mu.Unlock()
if busy {
t.Fatal("expected busy=false after releaseBusyAfterCompact with empty queue")
}
// Give any rogue goroutine a moment to (incorrectly) call PromptFunc.
time.Sleep(50 * time.Millisecond)
if n := stub.callCount(); n != 0 {
t.Fatalf("expected 0 PromptFunc calls when queue empty, got %d", n)
}
}
// TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue exercises the SDK
// steer-drain branch of releaseBusyAfterCompact (issue #27 follow-up).
//
// Production wires a.opts.Kit.DrainSteer() to pull messages that arrived via
// Steer/SteerWithFiles during compaction, but Options.Kit is *kit.Kit (a
// concrete struct) so unit tests cannot stand up a real instance without a
// full LLM backend. The test uses the unexported steerDrainFn seam to inject
// fake steer items, then asserts that:
//
// - Steer items are dispatched ahead of any prompts that piled up in
// a.queue (steer retains "act now" priority over ordinary queued
// prompts), and
// - the helper still hands off to drainQueue so the steer item actually
// fires (the previous behaviour left them stranded — see #27).
func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
var pmu sync.Mutex
var firstPrompt string
stub := newStubWithFuncs(
func(ctx context.Context) (*kit.TurnResult, error) {
return turnResult("steer dispatched"), nil
},
)
// Wrap PromptFunc so we can capture the prompt text the stub receives
// (newStubWithFuncs's fns ignore prompt; we need it to verify ordering).
capturingPrompt := func(ctx context.Context, prompt string) (*kit.TurnResult, error) {
pmu.Lock()
if firstPrompt == "" {
firstPrompt = prompt
}
pmu.Unlock()
return stub.fn(ctx, prompt)
}
app := New(Options{PromptFunc: capturingPrompt}, nil)
defer app.Close()
// Inject fake steer items via the test seam. In production the same
// items would have been delivered through Kit.InjectSteerWithFiles
// during /compact and pulled by DrainSteer here.
app.steerDrainFn = func() []queueItem {
return []queueItem{
{Prompt: "steer-1"},
{Prompt: "steer-2"},
}
}
// Simulate the state at the end of compaction: busy is set and a couple
// of regular Run() prompts have piled up after the steer messages.
app.mu.Lock()
app.busy = true
app.queue = append(app.queue,
queueItem{Prompt: "queued-1"},
queueItem{Prompt: "queued-2"},
)
app.mu.Unlock()
app.releaseBusyAfterCompact()
// Wait for the dispatched batch to complete.
ok := waitForCondition(2*time.Second, func() bool {
app.mu.Lock()
defer app.mu.Unlock()
return !app.busy
})
if !ok {
t.Fatal("app did not become idle after steer-spliced releaseBusyAfterCompact")
}
app.wg.Wait()
// drainQueue picks up `first` directly and batches the rest. With
// PromptFunc set, executeBatch invokes us with items[0] only — that
// item must be the first steer message, proving steer items were
// spliced ahead of the previously queued prompts.
pmu.Lock()
got := firstPrompt
pmu.Unlock()
if got != "steer-1" {
t.Fatalf("expected first dispatched prompt to be steer item %q (steer items must come before queued prompts), got %q",
"steer-1", got)
}
// Queue should be fully drained and PromptFunc must have actually fired.
if n := app.QueueLength(); n != 0 {
t.Fatalf("expected empty queue after drain, got %d entries", n)
}
if n := stub.callCount(); n == 0 {
t.Fatal("expected stub PromptFunc to fire at least once after splice")
}
}
// TestReleaseBusyAfterCompact_dropsQueueWhenClosed verifies that if the app
// was closed during compaction the helper discards any pending items rather
// than spawning drainQueue against a torn-down App.
func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
stub := newStub()
app := newTestApp(stub)
app.mu.Lock()
app.busy = true
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
app.closed = true
app.mu.Unlock()
app.releaseBusyAfterCompact()
app.mu.Lock()
busy := app.busy
qLen := len(app.queue)
app.mu.Unlock()
if busy {
t.Fatal("expected busy=false even when closed")
}
if qLen != 0 {
t.Fatalf("expected queue cleared on closed app, got %d entries", qLen)
}
time.Sleep(20 * time.Millisecond)
if n := stub.callCount(); n != 0 {
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
}
}
+97
View File
@@ -0,0 +1,97 @@
// Package extbridge wires the public Kit SDK to the internal extensions
// package. It exists so that cmd/ and internal/acpserver/ don't both
// reimplement the same SDK→extension event/subagent conversions.
package extbridge
import (
"context"
"github.com/mark3labs/kit/internal/extensions"
kit "github.com/mark3labs/kit/pkg/kit"
)
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
// (Type=="") for events that don't map to anything useful — callers should
// drop those.
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
switch ev := e.(type) {
case kit.MessageUpdateEvent:
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
case kit.ReasoningDeltaEvent:
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
case kit.ToolCallEvent:
return extensions.SubagentEvent{
Type: "tool_call", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
}
case kit.ToolExecutionStartEvent:
return extensions.SubagentEvent{
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolExecutionEndEvent:
return extensions.SubagentEvent{
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
}
case kit.ToolResultEvent:
return extensions.SubagentEvent{
Type: "tool_result", ToolCallID: ev.ToolCallID,
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
ToolResult: ev.Result, IsError: ev.IsError,
}
case kit.TurnStartEvent:
return extensions.SubagentEvent{Type: "turn_start"}
case kit.TurnEndEvent:
return extensions.SubagentEvent{Type: "turn_end"}
default:
return extensions.SubagentEvent{}
}
}
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
// the result/events back into the extension-facing types. The returned
// handle is always nil — the SDK path runs synchronously and does not
// expose a separate process handle. Callers that need non-blocking
// behaviour should run this in their own goroutine.
//
// This function consolidates the previously-duplicated wiring in
// cmd/root.go (interactive + runtime contexts) and
// internal/acpserver/session.go.
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
sdkCfg := kit.SubagentConfig{
Prompt: cfg.Prompt,
Model: cfg.Model,
SystemPrompt: cfg.SystemPrompt,
Timeout: cfg.Timeout,
NoSession: cfg.NoSession,
}
if cfg.OnEvent != nil {
sdkCfg.OnEvent = func(e kit.Event) {
se := SDKEventToSubagentEvent(e)
if se.Type != "" {
cfg.OnEvent(se)
}
}
}
result, err := k.Subagent(ctx, sdkCfg)
if result == nil {
return nil, &extensions.SubagentResult{Error: err}, err
}
extResult := &extensions.SubagentResult{
Response: result.Response,
Error: err,
SessionID: result.SessionID,
Elapsed: result.Elapsed,
}
if result.Usage != nil {
extResult.Usage = &extensions.SubagentUsage{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
}
}
return nil, extResult, err
}
-19
View File
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
return filepath.Join(base, "kit", "git")
}
// GetInstalledPackages returns all installed packages from both scopes.
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
var all []ManifestEntry
global, err := i.loadManifest(ScopeGlobal)
if err != nil {
return nil, fmt.Errorf("loading global manifest: %w", err)
}
all = append(all, global.Packages...)
project, err := i.loadManifest(ScopeProject)
if err != nil {
return nil, fmt.Errorf("loading project manifest: %w", err)
}
all = append(all, project.Packages...)
return all, nil
}
// IsInstalled checks if a package is installed in either scope.
// Returns (scope, true) if installed, ("", false) otherwise.
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
+45 -46
View File
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
}
}
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
// saveManifest round-trip against a temp directory, ensuring an absent
// manifest loads as empty and a saved manifest reads back identically.
func TestLoadAndSaveManifest(t *testing.T) {
tempDir := t.TempDir()
installer := &Installer{
projectGitRoot: tempDir,
globalGitRoot: tempDir,
}
manifestPath := filepath.Join(tempDir, "packages.json")
// Test loading non-existent manifest
manifest, err := loadManifestFromPath(manifestPath)
manifest, err := installer.loadManifest(ScopeGlobal)
if err != nil {
t.Fatalf("loadManifestFromPath() error = %v", err)
t.Fatalf("loadManifest() error = %v", err)
}
if len(manifest.Packages) != 0 {
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
}
// Save it
err = saveManifestToPath(manifest, manifestPath)
err = installer.saveManifest(manifest, ScopeGlobal)
if err != nil {
t.Fatalf("saveManifestToPath() error = %v", err)
t.Fatalf("saveManifest() error = %v", err)
}
// Verify it was written to expected path
if _, err := os.Stat(manifestPath); err != nil {
t.Fatalf("manifest file not created: %v", err)
}
// Load it back
loaded, err := loadManifestFromPath(manifestPath)
loaded, err := installer.loadManifest(ScopeGlobal)
if err != nil {
t.Fatalf("loadManifestFromPath() error = %v", err)
t.Fatalf("loadManifest() error = %v", err)
}
if len(loaded.Packages) != 1 {
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
}
}
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
// followed by removeFromManifest leaves the manifest in its original
// (empty) state, using a temp-directory installer scope.
func TestAddAndRemoveFromManifest(t *testing.T) {
tempDir := t.TempDir()
// Set up environment for manifest path
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
t.Fatalf("Setenv() error = %v", err)
installer := &Installer{
projectGitRoot: tempDir,
globalGitRoot: tempDir,
}
defer func() {
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
t.Logf("Unsetenv() error = %v", err)
}
}()
// The manifest path when XDG_DATA_HOME is set
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
// Add an entry
entry := ManifestEntry{
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
Scope: ScopeGlobal,
}
err := addEntryToManifest(entry, ScopeGlobal)
if err != nil {
t.Fatalf("addEntryToManifest() error = %v", err)
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
t.Fatalf("addToManifest() error = %v", err)
}
// Verify it was added
manifest, err := loadManifestFromPath(manifestPath)
manifest, err := installer.loadManifest(ScopeGlobal)
if err != nil {
t.Fatalf("loadManifestFromPath() error = %v", err)
t.Fatalf("loadManifest() error = %v", err)
}
if len(manifest.Packages) != 1 {
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
}
// Remove it
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
if err != nil {
t.Fatalf("removeEntryFromManifest() error = %v", err)
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
t.Fatalf("removeFromManifest() error = %v", err)
}
// Verify it was removed
manifest, err = loadManifestFromPath(manifestPath)
manifest, err = installer.loadManifest(ScopeGlobal)
if err != nil {
t.Fatalf("loadManifestFromPath() error = %v", err)
t.Fatalf("loadManifest() error = %v", err)
}
if len(manifest.Packages) != 0 {
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
}
}
// TestFindInManifest writes a manifest file directly to the path
// resolved by the package-level manifestPathForScope helper and then
// confirms FindInManifest locates the entry by identity (and returns
// nil for a non-existent identity).
func TestFindInManifest(t *testing.T) {
tempDir := t.TempDir()
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
t.Fatalf("Setenv() error = %v", err)
}
defer func() {
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
t.Logf("Unsetenv() error = %v", err)
}
}()
t.Setenv("XDG_DATA_HOME", tempDir)
// Add an entry to global manifest
entry := ManifestEntry{
Source: "git:github.com/user/repo",
Host: "github.com",
Path: "user/repo",
Scope: ScopeGlobal,
// Write a manifest entry directly via the package-level path resolver
// so FindInManifest (which uses manifestPathForScope) can read it back.
manifestPath := manifestPathForScope(ScopeGlobal)
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
err := addEntryToManifest(entry, ScopeGlobal)
if err != nil {
t.Fatalf("addEntryToManifest() error = %v", err)
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
t.Fatalf("WriteFile() error = %v", err)
}
// Find it
-73
View File
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
return &manifest, nil
}
// saveManifestToScope saves the manifest to the given scope.
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
path := manifestPathForScope(scope)
return saveManifestToPath(manifest, path)
}
// saveManifestToPath saves a manifest to a specific file path.
func saveManifestToPath(manifest *Manifest, path string) error {
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return fmt.Errorf("creating manifest directory: %w", err)
}
data, err := json.MarshalIndent(manifest, "", " ")
if err != nil {
return fmt.Errorf("encoding manifest: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("writing manifest: %w", err)
}
return nil
}
// manifestPathForScope returns the manifest file path for a scope.
func manifestPathForScope(scope InstallScope) string {
if scope == ScopeProject {
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
return filepath.Join(base, "kit", "git", "packages.json")
}
// GetGlobalManifest returns the global manifest.
func GetGlobalManifest() (*Manifest, error) {
return loadManifestFromScope(ScopeGlobal)
}
// GetProjectManifest returns the project manifest.
func GetProjectManifest() (*Manifest, error) {
return loadManifestFromScope(ScopeProject)
}
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
manifest, err := loadManifestFromScope(scope)
if err != nil {
return err
}
// Remove any existing entry with same identity
identity := entry.Identity()
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
for _, p := range manifest.Packages {
if p.Identity() != identity {
filtered = append(filtered, p)
}
}
filtered = append(filtered, entry)
manifest.Packages = filtered
return saveManifestToScope(manifest, scope)
}
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
func removeEntryFromManifest(identity string, scope InstallScope) error {
manifest, err := loadManifestFromScope(scope)
if err != nil {
return err
}
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
for _, p := range manifest.Packages {
if p.Identity() != identity {
filtered = append(filtered, p)
}
}
manifest.Packages = filtered
return saveManifestToScope(manifest, scope)
}
// FindInManifest finds an entry by identity in either global or project manifest.
// Returns the entry and its scope, or nil if not found.
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
-225
View File
@@ -2,22 +2,15 @@
package extensions
import (
"bufio"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
"time"
)
// ---------------------------------------------------------------------------
// Subagent types
// ---------------------------------------------------------------------------
// SubagentConfig configures a subagent spawn.
type SubagentConfig struct {
// Prompt is the task/instruction for the subagent (required).
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
func (h *SubagentHandle) Done() <-chan struct{} {
return h.done
}
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
type subagentJSONOutput struct {
Response string `json:"response"`
StopReason string `json:"stop_reason,omitempty"`
SessionID string `json:"session_id,omitempty"`
Usage *struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
} `json:"usage,omitempty"`
}
var subagentCounter atomic.Uint64
func generateSubagentID() string {
n := subagentCounter.Add(1)
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
}
func findKitBinary() string {
// Try the current process executable first.
if exe, err := os.Executable(); err == nil {
if _, err := os.Stat(exe); err == nil {
return exe
}
}
// Fall back to PATH lookup.
if p, err := exec.LookPath("kit"); err == nil {
return p
}
return "kit"
}
// ---------------------------------------------------------------------------
// SpawnSubagent implementation
// ---------------------------------------------------------------------------
// SpawnSubagent spawns a child Kit instance to perform a task.
//
// When config.Blocking is true, blocks until completion and returns the result
// directly (handle is nil). When false, returns immediately with a handle for
// monitoring/cancellation.
//
// The subagent runs with --json --no-session --no-extensions flags by default,
// ensuring isolation from the parent's extensions and session state.
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
if cfg.Prompt == "" {
return nil, nil, fmt.Errorf("prompt is required")
}
timeout := cfg.Timeout
if timeout == 0 {
timeout = 5 * time.Minute
}
kitBinary := findKitBinary()
// Build subprocess arguments.
args := []string{
"--json",
"--no-extensions",
}
if cfg.NoSession {
args = append(args, "--no-session")
}
if cfg.Model != "" {
args = append(args, "--model", cfg.Model)
}
// Handle system prompt - write to temp file if provided.
var tmpFile *os.File
if cfg.SystemPrompt != "" {
var err error
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
if err != nil {
return nil, nil, fmt.Errorf("create temp file: %w", err)
}
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpFile.Name())
return nil, nil, fmt.Errorf("write system prompt: %w", err)
}
_ = tmpFile.Close()
args = append(args, "--system-prompt", tmpFile.Name())
}
// Add the prompt as a positional argument.
args = append(args, cfg.Prompt)
// Create command with timeout context.
ctx, cancel := context.WithTimeout(context.Background(), timeout)
cmd := exec.CommandContext(ctx, kitBinary, args...)
cmd.Env = os.Environ()
stdout, err := cmd.StdoutPipe()
if err != nil {
cancel()
if tmpFile != nil {
_ = os.Remove(tmpFile.Name())
}
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
cancel()
if tmpFile != nil {
_ = os.Remove(tmpFile.Name())
}
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
}
handle := &SubagentHandle{
ID: generateSubagentID(),
done: make(chan struct{}),
}
// Start the subprocess.
start := time.Now()
if err := cmd.Start(); err != nil {
cancel()
if tmpFile != nil {
_ = os.Remove(tmpFile.Name())
}
return nil, nil, fmt.Errorf("start subprocess: %w", err)
}
handle.mu.Lock()
handle.proc = cmd.Process
handle.mu.Unlock()
// Run the subprocess monitoring in a goroutine.
go func() {
defer close(handle.done)
defer cancel()
if tmpFile != nil {
defer func() { _ = os.Remove(tmpFile.Name()) }()
}
var wg sync.WaitGroup
var stdoutBuf strings.Builder
// Read stderr (live output).
wg.Go(func() {
scanner := bufio.NewScanner(stderr)
scanner.Buffer(make([]byte, 256*1024), 256*1024)
for scanner.Scan() {
line := scanner.Text()
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
cfg.OnOutput(line + "\n")
}
}
})
// Read stdout (JSON output).
scanner := bufio.NewScanner(stdout)
scanner.Buffer(make([]byte, 256*1024), 256*1024)
for scanner.Scan() {
stdoutBuf.WriteString(scanner.Text() + "\n")
}
wg.Wait()
waitErr := cmd.Wait()
elapsed := time.Since(start)
// Build result.
result := SubagentResult{Elapsed: elapsed}
if waitErr != nil {
result.Error = waitErr
if exitErr, ok := waitErr.(*exec.ExitError); ok {
result.ExitCode = exitErr.ExitCode()
} else {
result.ExitCode = 1
}
}
// Parse JSON output.
raw := strings.TrimSpace(stdoutBuf.String())
var parsed subagentJSONOutput
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
result.Response = parsed.Response
result.SessionID = parsed.SessionID
if parsed.Usage != nil {
result.Usage = &SubagentUsage{
InputTokens: parsed.Usage.InputTokens,
OutputTokens: parsed.Usage.OutputTokens,
}
}
} else {
// Fallback: use raw stdout.
result.Response = raw
}
handle.mu.Lock()
handle.result = &result
handle.proc = nil
handle.mu.Unlock()
if cfg.OnComplete != nil {
cfg.OnComplete(result)
}
}()
if cfg.Blocking {
// Wait for completion and return result directly.
<-handle.done
handle.mu.Lock()
r := handle.result
handle.mu.Unlock()
return nil, r, nil
}
return handle, nil, nil
}
-17
View File
@@ -3,7 +3,6 @@ package models
import (
"crypto/sha256"
"encoding/hex"
"maps"
"os"
"charm.land/fantasy"
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
// Prefix with "kit-" to identify KIT-generated cache keys
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
}
// mergeProviderOptions merges multiple ProviderOptions maps.
// Later maps take precedence over earlier ones.
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
result := make(fantasy.ProviderOptions)
for _, opt := range opts {
maps.Copy(result, opt)
}
if len(result) == 0 {
return nil
}
return result
}
-56
View File
@@ -3,8 +3,6 @@ package models
import (
"os"
"testing"
"charm.land/fantasy"
)
func TestModelInfo_SupportsCaching(t *testing.T) {
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
t.Errorf("OpenAI caching should work when thinking is OFF")
}
}
func TestMergeProviderOptions(t *testing.T) {
opts1 := fantasy.ProviderOptions{
"provider1": &testProviderData{value: "value1"},
}
opts2 := fantasy.ProviderOptions{
"provider2": &testProviderData{value: "value2"},
}
merged := mergeProviderOptions(opts1, opts2)
if len(merged) != 2 {
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
}
if _, ok := merged["provider1"]; !ok {
t.Errorf("merged options should contain 'provider1' key")
}
if _, ok := merged["provider2"]; !ok {
t.Errorf("merged options should contain 'provider2' key")
}
// Later options should override earlier ones
opts3 := fantasy.ProviderOptions{
"provider1": &testProviderData{value: "overridden"},
}
merged2 := mergeProviderOptions(opts1, opts3)
if data, ok := merged2["provider1"].(*testProviderData); ok {
if data.value != "overridden" {
t.Errorf("later options should override earlier ones, got %q", data.value)
}
}
if mergeProviderOptions() != nil {
t.Errorf("mergeProviderOptions with no args should return nil")
}
}
// testProviderData is a simple implementation of ProviderOptionsData for testing
type testProviderData struct {
value string
}
func (t *testProviderData) Options() {}
func (t *testProviderData) MarshalJSON() ([]byte, error) {
return []byte(`"` + t.value + `"`), nil
}
func (t *testProviderData) UnmarshalJSON(data []byte) error {
return nil
}
-168
View File
@@ -1,168 +0,0 @@
package models
import (
"context"
"sync"
"time"
"charm.land/fantasy"
)
// ProviderPool manages reusable LLM provider instances to reduce overhead
// when spawning multiple subagents or making repeated completion calls.
type ProviderPool struct {
mu sync.RWMutex
providers map[string]*pooledProvider
ttl time.Duration
closed bool
closeCh chan struct{}
}
type pooledProvider struct {
model fantasy.LanguageModel
closer func() error
providerOpts fantasy.ProviderOptions
created time.Time
lastUsed time.Time
refs int32
}
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
const DefaultPoolTTL = 5 * time.Minute
// globalPool is the singleton provider pool instance.
var globalPool *ProviderPool
var poolOnce sync.Once
// GetGlobalPool returns the singleton provider pool instance.
func GetGlobalPool() *ProviderPool {
poolOnce.Do(func() {
globalPool = NewProviderPool(DefaultPoolTTL)
})
return globalPool
}
// NewProviderPool creates a provider pool with the given TTL for idle providers.
func NewProviderPool(ttl time.Duration) *ProviderPool {
p := &ProviderPool{
providers: make(map[string]*pooledProvider),
ttl: ttl,
closeCh: make(chan struct{}),
}
go p.cleanupLoop()
return p
}
// Get returns a provider for the model string, creating one if needed.
// The returned release function must be called when the provider is no longer
// needed. The provider may be reused by subsequent Get calls.
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
p.mu.Lock()
// Check if we have an existing provider.
if pp, ok := p.providers[modelString]; ok {
pp.refs++
pp.lastUsed = time.Now()
p.mu.Unlock()
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
}
p.mu.Unlock()
// Create a new provider outside the lock.
config := &ProviderConfig{ModelString: modelString}
result, err := CreateProvider(ctx, config)
if err != nil {
return nil, nil, nil, err
}
p.mu.Lock()
defer p.mu.Unlock()
// Double-check: another goroutine may have created one while we were unlocked.
if pp, ok := p.providers[modelString]; ok {
// Close the one we just created and use the existing one.
if result.Closer != nil {
_ = result.Closer.Close()
}
pp.refs++
pp.lastUsed = time.Now()
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
}
var closerFn func() error
if result.Closer != nil {
closerFn = result.Closer.Close
}
pp := &pooledProvider{
model: result.Model,
closer: closerFn,
providerOpts: result.ProviderOptions,
created: time.Now(),
lastUsed: time.Now(),
refs: 1,
}
p.providers[modelString] = pp
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
}
func (p *ProviderPool) release(modelString string) {
p.mu.Lock()
defer p.mu.Unlock()
if pp, ok := p.providers[modelString]; ok {
pp.refs--
pp.lastUsed = time.Now()
}
}
func (p *ProviderPool) cleanupLoop() {
ticker := time.NewTicker(p.ttl / 2)
defer ticker.Stop()
for {
select {
case <-p.closeCh:
return
case <-ticker.C:
p.cleanup()
}
}
}
func (p *ProviderPool) cleanup() {
p.mu.Lock()
defer p.mu.Unlock()
now := time.Now()
for key, pp := range p.providers {
// Only clean up providers with no active references and past TTL.
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
if pp.closer != nil {
_ = pp.closer()
}
delete(p.providers, key)
}
}
}
// Close shuts down the pool and releases all providers.
func (p *ProviderPool) Close() {
p.mu.Lock()
if p.closed {
p.mu.Unlock()
return
}
p.closed = true
close(p.closeCh)
for key, pp := range p.providers {
if pp.closer != nil {
_ = pp.closer()
}
delete(p.providers, key)
}
p.mu.Unlock()
}
-25
View File
@@ -179,31 +179,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
return templates, nil
}
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
// It returns the deduplicated list and diagnostics for any collisions.
// This is a standalone function for when you need to deduplicate an existing list.
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
seen := make(map[string]*PromptTemplate)
var result []*PromptTemplate
var diagnostics []Diagnostic
for _, tpl := range templates {
if existing, ok := seen[tpl.Name]; ok {
diagnostics = append(diagnostics, Diagnostic{
Name: tpl.Name,
KeptPath: existing.FilePath,
DroppedPath: tpl.FilePath,
Reason: "duplicate template name (first-match-wins)",
})
} else {
seen[tpl.Name] = tpl
result = append(result, tpl)
}
}
return result, diagnostics
}
// loadDefaultTemplates returns the built-in default templates.
// These are embedded templates that ship with Kit.
func loadDefaultTemplates() []*PromptTemplate {
+18 -9
View File
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
_, _ = tm.AppendMessage(msg4)
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
// Kept messages must appear BEFORE post-compaction messages so the LLM
// sees the conversation in chronological order. Otherwise the latest
// post-compaction user message would be followed by an older kept user
// message, breaking user/assistant alternation and causing the model to
// respond as if the post-compaction turn never happened.
messages, _, _ := tm.BuildContext()
if len(messages) != 3 {
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
}
// Verify order: summary, M4 (new), M3 (kept)
// Verify order: summary, M3 (kept), M4 (new)
if messages[0].Role != fantasy.MessageRoleSystem {
t.Errorf("first message should be summary, got %s", messages[0].Role)
}
if messages[1].Role != fantasy.MessageRoleAssistant {
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
if messages[1].Role != fantasy.MessageRoleUser {
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
}
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
if m3Text != "Message 3 - kept" {
t.Errorf("unexpected M3 text: %s", m3Text)
}
if messages[2].Role != fantasy.MessageRoleAssistant {
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
}
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
if m4Text != "Message 4 - after compaction" {
t.Errorf("unexpected M4 text: %s", m4Text)
}
if messages[2].Role != fantasy.MessageRoleUser {
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
}
// Verify that M1 is NOT in the context
for i, msg := range messages {
+82 -79
View File
@@ -755,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// If there is a compaction, inject the summary first and collect
// the kept messages starting from FirstKeptEntryID (since the
// compaction entry's parent chain doesn't include them).
// If there is a compaction, inject the summary first, then the
// preserved "kept" messages (chronologically before the compaction),
// then the post-compaction messages (chronologically after).
//
// Order matters: the kept messages must come BEFORE the post-compaction
// branch so the LLM sees the conversation in chronological order. If the
// kept messages were appended last, the latest user message in the
// current branch would be followed by an older kept user message,
// breaking the strict user/assistant alternation that providers expect
// and causing the model to respond as if the previous turn never
// happened.
if lastCompaction != nil {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleSystem,
@@ -768,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
},
})
// Collect entries from the compaction entry itself (at compactionIndex)
// and any entries before it in the branch (newer messages).
for i := compactionIndex; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue // skip malformed entries
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
// Convert branch summary to a user message for context.
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Already handled above (summary injected).
continue
}
}
// Now collect the kept messages starting from FirstKeptEntryID.
// These are not in the current branch because the compaction entry
// is parented to the first kept entry's parent, not the first kept entry.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting
// messages that were added after the compaction.
// Step 1: collect the kept messages starting from FirstKeptEntryID.
// These are not on the current branch (the compaction entry is a
// new root with no parent), so we iterate tm.entries in append order
// and stop when we reach the compaction entry itself.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
@@ -825,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// Stop when we reach the compaction entry itself.
// Messages after the compaction are collected from the branch walk above.
// Stop when we reach the compaction entry itself; messages
// after it are collected from the branch walk below.
if entryID == lastCompaction.ID {
break
}
// Process this kept entry.
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
@@ -860,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
}
}
// Step 2: collect entries on the current branch after the compaction
// entry (these are post-compaction messages). The compaction entry
// itself is skipped — its summary was already injected above.
for i := compactionIndex; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
messages = append(messages, msgs...)
case *BranchSummaryEntry:
if e.Summary != "" {
messages = append(messages, fantasy.Message{
Role: fantasy.MessageRoleUser,
Content: []fantasy.MessagePart{
fantasy.TextPart{
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
},
},
})
}
case *ModelChangeEntry:
provider = e.Provider
modelID = e.ModelID
case *CompactionEntry:
// Summary already injected above.
continue
}
}
return messages, provider, modelID
}
@@ -1030,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
var ids []string
// If there's a compaction, we need to collect IDs from:
// 1. Entries after the compaction entry in the branch (newer messages)
// 2. Entries from FirstKeptEntryID onwards (kept messages)
// If there's a compaction, we collect IDs in the same order as
// BuildContext: [summary placeholder, kept messages, post-compaction
// messages]. This ordering must stay in sync with BuildContext so a
// cut-point index can be mapped back to the correct entry ID.
if lastCompaction != nil {
// Placeholder for the summary system message (no entry ID).
ids = append(ids, "")
// Collect IDs from entries after the compaction entry (newer messages).
for i := compactionIndex + 1; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
// Collect IDs from the kept messages starting at FirstKeptEntryID.
// We iterate through entries in order (not using getBranchLocked) to avoid
// walking back to old compacted messages.
// We stop when we reach the compaction entry to avoid double-counting.
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
// Iterate tm.entries in append order and stop at the compaction
// entry to avoid double-counting post-compaction messages.
if lastCompaction.FirstKeptEntryID != "" {
found := false
for _, entry := range tm.entries {
entryID := tm.EntryID(entry)
// Skip entries until we reach the first kept entry.
if !found {
if entryID == lastCompaction.FirstKeptEntryID {
found = true
@@ -1076,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
}
}
// Stop when we reach the compaction entry itself.
if entryID == lastCompaction.ID {
break
}
@@ -1100,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
}
}
// Step 2: IDs of entries after the compaction entry on the current
// branch (post-compaction messages).
for i := compactionIndex + 1; i < len(branch); i++ {
entry := branch[i]
switch e := entry.(type) {
case *MessageEntry:
msg, err := e.ToMessage()
if err != nil {
continue
}
msgs := msg.ToLLMMessages()
for range msgs {
ids = append(ids, e.ID)
}
case *BranchSummaryEntry:
if e.Summary != "" {
ids = append(ids, e.ID)
}
}
}
return ids
}
-63
View File
@@ -28,15 +28,6 @@ type blockRenderer struct {
// renderingOption configures block rendering
type renderingOption func(*blockRenderer)
// WithFullWidth returns a renderingOption that configures the block renderer
// to expand to the full available width of its container. When enabled, the
// block will fill the entire horizontal space rather than sizing to its content.
func WithFullWidth() renderingOption {
return func(c *blockRenderer) {
c.fullWidth = true
}
}
// WithNoBorder returns a renderingOption that disables all borders on the
// block, rendering content with only padding.
func WithNoBorder() renderingOption {
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
}
}
// WithMarginTop returns a renderingOption that sets the top margin
// for the block. The margin is specified in number of lines and adds
// vertical space above the block.
func WithMarginTop(margin int) renderingOption {
return func(c *blockRenderer) {
c.marginTop = margin
}
}
// WithMarginBottom returns a renderingOption that sets the bottom margin
// for the block. The margin is specified in number of lines and adds
// vertical space below the block.
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
}
}
// WithPaddingLeft returns a renderingOption that sets the left padding
// for the block content. The padding is specified in number of characters
// and adds horizontal space between the left border and the content.
func WithPaddingLeft(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingLeft = padding
}
}
// WithPaddingRight returns a renderingOption that sets the right padding
// for the block content. The padding is specified in number of characters
// and adds horizontal space between the content and the right border.
func WithPaddingRight(padding int) renderingOption {
return func(c *blockRenderer) {
c.paddingRight = padding
}
}
// WithPaddingTop returns a renderingOption that sets the top padding
// for the block content. The padding is specified in number of lines
// and adds vertical space between the top border and the content.
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
}
}
// WithBackground returns a renderingOption that sets the background color
// for the entire block. The color parameter accepts any color.Color value,
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
func WithBackground(c color.Color) renderingOption {
return func(br *blockRenderer) {
br.background = &c
}
}
// WithForeground returns a renderingOption that overrides the default text
// foreground color (theme.Text) for the block. Useful for muted or
// de-emphasized content blocks.
func WithForeground(c color.Color) renderingOption {
return func(br *blockRenderer) {
br.foreground = &c
}
}
// WithWidth returns a renderingOption that sets a specific width for the block
// in characters. This overrides the default container width and allows precise
// control over the block's horizontal dimensions.
func WithWidth(width int) renderingOption {
return func(c *blockRenderer) {
c.width = width
}
}
// renderContentBlock renders content with configurable styling options
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
renderer := &blockRenderer{
-19
View File
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
return c.usageTracker
}
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
// through the CLI's rendering system for consistent message formatting and display.
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
return NewCLIDebugLogger(c)
}
// SetModelName updates the current AI model name being used in the conversation.
// This name is displayed in message headers to indicate which model is responding.
func (c *CLI) SetModelName(modelName string) {
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
}
// DisplayAssistantMessage renders and displays an AI assistant's response message
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
// with an empty model name for backward compatibility.
func (c *CLI) DisplayAssistantMessage(message string) error {
return c.DisplayAssistantMessageWithModel(message, "")
}
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
// with the specified model name shown in the message header. The message is
// formatted according to the current display mode and includes timestamp information.
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
fmt.Println(rendered)
}
// DisplayCancellation displays a system message indicating that the current
// AI generation has been cancelled by the user (typically via ESC key).
func (c *CLI) DisplayCancellation() {
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
}
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
// Debug messages are formatted distinctively and only shown when the CLI is
// initialized with debug=true.
+6 -12
View File
@@ -161,6 +161,12 @@ var SlashCommands = []SlashCommand{
Category: "Navigation",
Aliases: []string{"/r"},
},
{
Name: "/copy",
Description: "Copy the last message to the system clipboard",
Category: "System",
Aliases: []string{"/cp"},
},
{
Name: "/export",
Description: "Export session (JSONL by default, or /export path.jsonl)",
@@ -199,18 +205,6 @@ func GetCommandByName(name string) *SlashCommand {
return nil
}
// GetAllCommandNames returns a complete list of all command names and their aliases.
// This is useful for command completion, validation, and help display. The returned
// slice contains both primary command names and all alternative aliases.
func GetAllCommandNames() []string {
var names []string
for _, cmd := range SlashCommands {
names = append(names, cmd.Name)
names = append(names, cmd.Aliases...)
}
return names
}
// ExtensionCommand is a slash command registered by an extension. Unlike
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
// extension commands carry their own Execute callback.
-79
View File
@@ -1,79 +0,0 @@
package ui
import (
"fmt"
"strings"
"time"
)
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
// It provides debug logging functionality that integrates with the CLI's display
// system, ensuring debug messages are properly formatted and displayed alongside
// other conversation content.
type CLIDebugLogger struct {
cli *CLI
}
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
// debug output through the provided CLI instance. The logger will respect the CLI's
// debug mode setting and display format preferences.
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
return &CLIDebugLogger{cli: cli}
}
// LogDebug processes and displays a debug message through the CLI's rendering system.
// Messages are formatted with appropriate emojis and tags based on their content type
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
// multi-line debug output and connection pool status messages with context-aware formatting.
func (l *CLIDebugLogger) LogDebug(message string) {
if l.cli == nil || !l.cli.debug {
return
}
// Format the message to include all the debug info in a structured way
var formattedMessage string
// Check if this is a multi-line debug output (like connection info)
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
// Extract the tag and content
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
content := after
content = strings.TrimSpace(content)
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
content := after
content = strings.TrimSpace(content)
// Add appropriate emoji based on the message content
if strings.Contains(content, "Creating new connection") {
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
} else if strings.Contains(content, "Reusing") {
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
} else if strings.Contains(content, "closed") {
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
} else if strings.Contains(content, "Failed to close") {
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
} else {
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
}
} else {
formattedMessage = message
}
} else {
formattedMessage = message
}
// Use the CLI's debug message rendering
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
}
// IsDebugEnabled checks whether debug logging is currently active. Returns true
// if the CLI instance exists and has debug mode enabled, allowing callers to
// conditionally perform expensive debug operations only when necessary.
func (l *CLIDebugLogger) IsDebugEnabled() bool {
return l.cli != nil && l.cli.debug
}
-62
View File
@@ -25,17 +25,6 @@ type TextMessageItem struct {
timestamp time.Time
}
// NewTextMessageItem creates a new text message for the scrollback.
// The content should be pre-rendered using MessageRenderer for proper styling.
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
return &TextMessageItem{
id: id,
role: role,
content: content,
timestamp: time.Now(),
}
}
// NewStyledMessageItem creates a message item with pre-rendered styled content.
// This is the preferred way to create messages when you have styled content from MessageRenderer.
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
@@ -316,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
}
// --------------------------------------------------------------------------
// SystemMessageItem - System messages (commands, info, errors)
// --------------------------------------------------------------------------
// SystemMessageItem represents a system message (commands, info, errors).
type SystemMessageItem struct {
id string
content string
timestamp time.Time
cachedRender string
cachedWidth int
}
// NewSystemMessageItem creates a new system message for the scrollback.
func NewSystemMessageItem(id, content string) *SystemMessageItem {
return &SystemMessageItem{
id: id,
content: content,
timestamp: time.Now(),
}
}
func (m *SystemMessageItem) ID() string {
return m.id
}
func (m *SystemMessageItem) Render(width int) string {
// Return cached render if width matches
if m.cachedWidth == width && m.cachedRender != "" {
return m.cachedRender
}
// Simple system message formatting
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
// Cache and return
m.cachedRender = rendered
m.cachedWidth = width
return rendered
}
func (m *SystemMessageItem) Height() int {
if m.cachedRender != "" {
return strings.Count(m.cachedRender, "\n") + 1
}
// Estimate
if m.cachedWidth > 0 {
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
}
return 3
}
// --------------------------------------------------------------------------
// Helper: generateMessageID
// --------------------------------------------------------------------------
+171 -14
View File
@@ -129,9 +129,10 @@ type AppController interface {
// SkillItem holds display metadata about a loaded skill for the startup
// [Skills] section. Built by the CLI layer from the SDK's []*kit.Skill.
type SkillItem struct {
Name string // Skill name (e.g. "btca-cli").
Path string // Absolute path to the skill file.
Source string // "project" or "user" (global).
Name string // Skill name (e.g. "btca-cli").
Path string // Absolute path to the skill file.
Source string // "project" or "user" (global).
Description string // Short summary used in autocomplete and help.
}
// MCPPromptInfo describes an MCP prompt for display in the TUI (autocomplete,
@@ -912,6 +913,20 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
}
}
// Merge skills into autocomplete as /skill:<name> commands. Skills accept
// optional trailing args, so HasArgs is true — Enter populates the input
// with "/skill:name " rather than auto-submitting.
if ic, ok := m.input.(*InputComponent); ok && len(opts.SkillItems) > 0 {
for _, s := range opts.SkillItems {
ic.commands = append(ic.commands, commands.SlashCommand{
Name: "/skill:" + s.Name,
Description: formatSkillDescription(s),
Category: "Skills",
HasArgs: true,
})
}
}
// Merge MCP prompts into autocomplete as /<server>:<prompt> commands.
if ic, ok := m.input.(*InputComponent); ok && len(opts.MCPPrompts) > 0 {
for _, p := range opts.MCPPrompts {
@@ -1251,7 +1266,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.scrollList.autoScroll = false
case tea.MouseWheelDown:
m.scrollList.ScrollBy(scrollLines)
if m.scrollList.AtBottom() {
// Only re-enable auto-scroll when the user is not actively
// selecting text. Otherwise a wheel-down during a drag-select
// would re-arm GotoBottom on the next stream chunk, shifting
// the highlighted row out from under the cursor.
if m.scrollList.AtBottom() && !m.scrollList.IsMouseDown() {
m.scrollList.autoScroll = true
}
}
@@ -1259,9 +1278,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Mouse click selection (crush-style character-level) ──────────────────
case tea.MouseClickMsg:
if msg.Button == tea.MouseLeft {
// Calculate viewport-relative coordinates.
viewY := msg.Y - m.scrollbackYOffset
if viewY >= 0 && viewY < m.scrollList.height {
// Compute the scrollback origin from the current frame's layout
// rather than the stale cached value from the previous View().
// scrollbackYOffset/scrollList.height are only refreshed inside
// View() and lag behind any state change that resized the header
// (extension widgets, warning rows, etc.) since the last render.
yOff, vpHeight := m.currentScrollbackBounds()
viewY := msg.Y - yOff
if viewY >= 0 && viewY < vpHeight {
// Clear any previous selection on a new click.
// HandleMouseDown will set up new selection state.
if m.scrollList.HandleMouseDown(msg.X, viewY) {
@@ -1272,8 +1296,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Mouse motion/drag for character-level selection ──────────────────────
case tea.MouseMotionMsg:
viewY := msg.Y - m.scrollbackYOffset
if viewY >= 0 && viewY < m.scrollList.height {
yOff, vpHeight := m.currentScrollbackBounds()
viewY := msg.Y - yOff
if viewY >= 0 && viewY < vpHeight {
m.scrollList.HandleMouseDrag(msg.X, viewY)
}
@@ -1603,10 +1628,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// ── Cancel timer expired ─────────────────────────────────────────────────
case uicore.CancelTimerExpiredMsg:
if m.canceling {
m.layoutDirty = true
}
m.canceling = false
// ── Ctrl+C reset timer expired ────────────────────────────────────────────
case uicore.CtrlCResetMsg:
if m.ctrlCPressedOnce {
m.layoutDirty = true
}
m.ctrlCPressedOnce = false
// ── Input submitted ──────────────────────────────────────────────────────
@@ -3095,6 +3126,8 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
return m.handleResumeCommand()
case "/export":
return m.handleExportCommand(args)
case "/copy":
return m.handleCopyCommand()
case "/share":
return m.handleShareCommand()
case "/import":
@@ -3395,13 +3428,46 @@ func (m *AppModel) refreshPromptTemplates() {
}
}
// refreshSkillItems reloads skill items from the provider callback.
// Called on ContentReloadEvent.
// refreshSkillItems reloads skill items from the provider callback and
// updates the autocomplete entries. Called on ContentReloadEvent.
func (m *AppModel) refreshSkillItems() {
if m.getSkillItems == nil {
return
}
m.skillItems = m.getSkillItems()
newItems := m.getSkillItems()
m.skillItems = newItems
if ic, ok := m.input.(*InputComponent); ok {
// Remove old Skills commands and add fresh ones.
var kept []commands.SlashCommand
for _, sc := range ic.commands {
if sc.Category != "Skills" {
kept = append(kept, sc)
}
}
for _, s := range newItems {
kept = append(kept, commands.SlashCommand{
Name: "/skill:" + s.Name,
Description: formatSkillDescription(s),
Category: "Skills",
HasArgs: true,
})
}
ic.commands = kept
}
}
// formatSkillDescription returns the autocomplete description for a skill,
// prefixed with [project] or [user] so users can tell colliding names apart.
func formatSkillDescription(s SkillItem) string {
prefix := "[user]"
if s.Source == "project" {
prefix = "[project]"
}
if s.Description == "" {
return prefix
}
return prefix + " " + s.Description
}
// refreshMCPPrompts reloads MCP prompts from the provider callback and
@@ -3476,6 +3542,7 @@ func (m *AppModel) printHelpMessage() {
"**System:**\n" +
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
"- `/clear`: Clear message history\n" +
"- `/copy`: Copy the last message to the system clipboard\n" +
"- `/export [path]`: Export session as JSONL\n" +
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
"- `/reset-usage`: Reset usage statistics\n" +
@@ -3712,7 +3779,12 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
}
// Auto-scroll to bottom if enabled (iteratr pattern)
// Don't call SetItems() - the slice reference hasn't changed
if m.scrollList != nil {
//
// CRITICAL: never scroll the viewport while the user is actively
// selecting text (mouse button held). Doing so shifts the
// highlighted content out from under the cursor and produces the
// off-by-N-row drift users see when copy-selecting during streaming.
if m.scrollList != nil && !m.scrollList.IsMouseDown() {
if m.scrollList.autoScroll {
m.scrollList.GotoBottom()
} else if m.scrollList.AtBottom() {
@@ -3740,6 +3812,36 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
m.refreshContent()
}
// currentScrollbackBounds returns the live (yOffset, viewportHeight) for the
// scrollback region, computed from the current state — not from the cached
// values populated inside View().
//
// scrollbackYOffset and scrollList.height are refreshed once per render, so
// any state change that resizes the header (extension widget toggles,
// warning rows, queued messages, etc.) leaves the cached values one frame
// stale. Mouse click handlers in Update() can then place the cursor on the
// wrong line, producing the off-by-N-row drift seen during copy-selection.
//
// This recomputes the header height by rendering it (cheap — the renderer
// returns "" when no extension header is set) and recomputes the viewport
// height the same way distributeHeight() does, so both inputs to the
// y → (item, line) mapping are always current.
func (m *AppModel) currentScrollbackBounds() (yOffset, viewportHeight int) {
// Force a fresh layout if anything in Update() marked the state dirty;
// otherwise scrollList.height still reflects the previous frame.
if m.layoutDirty {
m.distributeHeight()
m.layoutDirty = false
}
if headerView := m.renderHeaderFooter(m.getHeader); headerView != "" {
yOffset = lipgloss.Height(headerView)
}
if m.scrollList != nil {
viewportHeight = m.scrollList.height
}
return yOffset, viewportHeight
}
// distributeHeight recalculates child component heights after a window resize,
// queue change, widget update, or state transition, and propagates the computed
// stream height to the StreamComponent.
@@ -3812,7 +3914,20 @@ func (m *AppModel) distributeHeight() {
headerFooterLines += lipgloss.Height(footerView)
}
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines, 0)
// Account for transient warning rows that View() injects between the
// scrollback and the separator. These flags are toggled by ESC/Ctrl+C
// handlers; without subtracting them here the joined view exceeds
// m.height by one line per active warning and the bottom of the screen
// gets silently clipped — which in turn invalidates scrollbackYOffset.
var warningLines int
if m.canceling {
warningLines++
}
if m.ctrlCPressedOnce {
warningLines++
}
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines-warningLines, 0)
// In alt screen mode, give the calculated height to ScrollList instead of stream.
// The stream component still exists but is embedded as the last item in scrollList.
@@ -4236,6 +4351,48 @@ func (m *AppModel) handleNameCommand(args string) tea.Cmd {
return nil
}
// handleCopyCommand copies the last user or assistant message to the system
// clipboard. Skips transient system messages (e.g. /help output) so the user
// gets the actual last conversational message.
func (m *AppModel) handleCopyCommand() tea.Cmd {
if len(m.messages) == 0 {
m.printSystemMessage("No messages to copy.")
return nil
}
var (
text string
role string
)
for i := len(m.messages) - 1; i >= 0; i-- {
switch msg := m.messages[i].(type) {
case *TextMessageItem:
if msg.role == "user" || msg.role == "assistant" {
text = msg.content
role = msg.role
}
case *StreamingMessageItem:
if msg.role == "assistant" || msg.role == "reasoning" {
text = msg.content.String()
role = msg.role
}
}
if text != "" {
break
}
}
if strings.TrimSpace(text) == "" {
m.printSystemMessage("No copyable message found.")
return nil
}
m.printSystemMessage(fmt.Sprintf(
"Copied last %s message to clipboard (%d chars).", role, len(text),
))
return clipboard.CopyToClipboard(text)
}
// handleExportCommand exports the current session to a file.
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
//
+1 -60
View File
@@ -19,28 +19,7 @@ import (
// - @path/to/file.txt (unquoted, no spaces)
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
// UserBlock renders a user message with herald Tip styling.
// The width parameter controls line wrapping so long messages don't overflow.
// Any @file tokens in the content are highlighted with the theme accent color.
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
if strings.TrimSpace(content) == "" {
content = "(empty message)"
}
// Wrap content before passing to herald Alert so long lines break
// inside the alert box. Subtract 4 to account for the alert bar
// prefix ("│ ") and a small margin.
if width > 4 {
content = lipgloss.Wrap(content, width-4, "")
}
// Highlight @file tokens with accent color so file references are
// visually distinct from surrounding prompt text.
content = HighlightFileTokens(content, theme)
rendered := ty.Tip(content)
return styleMarginBottom(theme, rendered)
}
// UserBlock-related rendering helpers and herald typography.
// HighlightFileTokens wraps @file tokens in the given text with the theme
// accent color so they stand out visually in rendered user messages.
@@ -154,44 +133,6 @@ func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) strin
return styleMarginBottom(theme, rendered)
}
// ToolBlock renders a tool execution result with header and body.
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
var icon string
iconColor := theme.Success
if isError {
icon = "×"
iconColor = theme.Error
} else {
icon = "✓"
}
// Style the tool name with color
nameColor := theme.Info
if isError {
nameColor = theme.Error
}
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
// Build the content: icon + name + params on first line, then body
headerLine := styledIcon + " " + styledName
if params != "" {
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
}
if strings.TrimSpace(body) == "" {
body = ty.Italic("(no output)")
}
// Compose: icon + name + params, then body
fullContent := ty.Compose(
headerLine,
"",
body,
)
return styleMarginBottom(theme, fullContent)
}
// styleMarginBottom applies a 1-line margin bottom using the theme.
func styleMarginBottom(theme style.Theme, content string) string {
return style.GetCachedStyles().MarginBottom1.Render(content)
+9 -29
View File
@@ -4,30 +4,9 @@ import (
"strings"
"testing"
"github.com/indaco/herald"
"github.com/mark3labs/kit/internal/ui/style"
)
// testTypography creates a herald Typography for tests.
func testTypography(theme style.Theme) *herald.Typography {
return herald.New(
herald.WithPalette(herald.ColorPalette{
Primary: theme.Primary,
Secondary: theme.Secondary,
Tertiary: theme.Info,
Accent: theme.Accent,
Highlight: theme.Highlight,
Muted: theme.Muted,
Text: theme.Text,
Surface: theme.Background,
Base: theme.CodeBg,
}),
herald.WithAlertLabel(herald.AlertTip, ""),
herald.WithAlertIcon(herald.AlertTip, ""),
)
}
func TestHighlightFileTokens(t *testing.T) {
theme := style.DefaultTheme()
@@ -88,24 +67,25 @@ func TestHighlightFileTokens(t *testing.T) {
}
}
func TestUserBlockHighlightsFileTokens(t *testing.T) {
// TestHighlightFileTokensInjectsANSI verifies that HighlightFileTokens
// preserves the original @file references in the output and wraps each
// token with ANSI escape codes for the theme accent color.
func TestHighlightFileTokensInjectsANSI(t *testing.T) {
theme := style.DefaultTheme()
ty := testTypography(theme)
// A user message with @file tokens should contain ANSI escapes around the token.
content := "refactor @main.go and @utils.go"
result := UserBlock(content, 80, ty, theme)
result := HighlightFileTokens(content, theme)
// The rendered output should contain both file references.
// The output should still contain both file references.
if !strings.Contains(result, "@main.go") {
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
t.Errorf("HighlightFileTokens output should contain @main.go, got:\n%s", result)
}
if !strings.Contains(result, "@utils.go") {
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
t.Errorf("HighlightFileTokens output should contain @utils.go, got:\n%s", result)
}
// Verify ANSI codes are present (the tokens are styled).
if !strings.Contains(result, "\x1b[") {
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
t.Errorf("HighlightFileTokens output should contain ANSI escape codes for styled @file tokens")
}
}
+19 -2
View File
@@ -60,10 +60,13 @@ func NewScrollList(width, height int) *ScrollList {
}
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
// the viewport will scroll to the bottom to show the latest content.
// the viewport will scroll to the bottom to show the latest content — EXCEPT
// when the user is actively selecting text (mouse button held), in which case
// the scroll position is locked so the highlighted content stays under the
// cursor. The pending bottom-scroll is deferred to MouseUp.
func (s *ScrollList) SetItems(items []MessageItem) {
s.items = items
if s.autoScroll {
if s.autoScroll && !s.sel.MouseDown {
s.GotoBottom()
}
}
@@ -157,6 +160,10 @@ func (s *ScrollList) HandleMouseDown(x, y int) bool {
// HandleMouseDrag handles mouse motion while button is held.
// Updates the selection endpoint for character-level precision.
// Returns true if selection was updated.
//
// Defensively disables auto-scroll on every drag update — even if the
// MouseDown handler missed (e.g. click landed in viewport padding), any
// active drag means the user is selecting and the viewport must not jump.
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
if !s.sel.MouseDown {
return false
@@ -171,6 +178,9 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
return false
}
// Hard-lock the viewport while dragging.
s.autoScroll = false
s.sel.DragItemIdx = itemIdx
s.sel.DragLineIdx = lineIdx
s.sel.DragCol = x
@@ -178,6 +188,13 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
return true
}
// IsMouseDown reports whether the user currently has the mouse button held
// (i.e. a selection drag is in progress). Used by the parent model to avoid
// re-enabling auto-scroll during streaming while the user is selecting.
func (s *ScrollList) IsMouseDown() bool {
return s.sel.MouseDown
}
// HandleMouseUp handles mouse button release.
// Returns true if there was an active selection.
func (s *ScrollList) HandleMouseUp() bool {
+132
View File
@@ -0,0 +1,132 @@
package ui
import (
"fmt"
"strings"
"testing"
)
// fakeItem is a deterministic MessageItem for ScrollList tests.
type fakeItem struct {
id string
lines int
}
func (f *fakeItem) ID() string { return f.id }
func (f *fakeItem) Render(_ int) string {
if f.lines <= 0 {
return ""
}
parts := make([]string, f.lines)
for i := range parts {
parts[i] = fmt.Sprintf("%s-line-%d", f.id, i)
}
return strings.Join(parts, "\n")
}
func (f *fakeItem) Height() int { return f.lines }
// makeItems builds n fake items of `lines` height each.
func makeItems(n, lines int) []MessageItem {
out := make([]MessageItem, n)
for i := range out {
out[i] = &fakeItem{id: fmt.Sprintf("item-%d", i), lines: lines}
}
return out
}
// TestScrollList_MouseDownPreventsAutoScroll verifies the core fix for the
// copy-selection drift bug: while the user has the mouse button held
// (drag-selecting), incoming content updates must NOT shift the viewport,
// because doing so moves the highlighted content out from under the cursor.
func TestScrollList_MouseDownPreventsAutoScroll(t *testing.T) {
sl := NewScrollList(80, 10)
sl.SetItems(makeItems(20, 2)) // 40 lines of content into a 10-line viewport
// Capture the auto-scrolled-to-bottom position.
startOffsetIdx := sl.offsetIdx
startOffsetLine := sl.offsetLine
// User clicks somewhere in the visible area, starting a drag-select.
if !sl.HandleMouseDown(5, 3) {
t.Fatalf("HandleMouseDown should accept a click inside the viewport")
}
if !sl.IsMouseDown() {
t.Fatalf("IsMouseDown should be true after HandleMouseDown")
}
// New content arrives. With autoScroll still true, SetItems would
// normally call GotoBottom() and shift the viewport. The fix should
// suppress that while MouseDown is held.
sl.SetItems(makeItems(30, 2)) // 60 lines now
if sl.offsetIdx != startOffsetIdx || sl.offsetLine != startOffsetLine {
t.Errorf("viewport scrolled during active drag: was (%d,%d), now (%d,%d)",
startOffsetIdx, startOffsetLine, sl.offsetIdx, sl.offsetLine)
}
// User releases the mouse — drag is over.
sl.HandleMouseUp()
if sl.IsMouseDown() {
t.Fatalf("IsMouseDown should be false after HandleMouseUp")
}
// After release, a fresh content update should resume auto-scrolling
// (move the offset to track the new bottom).
afterReleaseIdx := sl.offsetIdx
afterReleaseLine := sl.offsetLine
sl.SetItems(makeItems(50, 2))
if sl.offsetIdx == afterReleaseIdx && sl.offsetLine == afterReleaseLine {
t.Errorf("autoscroll did not resume after MouseUp: offset stuck at (%d,%d)",
afterReleaseIdx, afterReleaseLine)
}
}
// TestScrollList_DragDisablesAutoScroll verifies that any successful
// HandleMouseDrag call clears autoScroll, even when HandleMouseDown didn't
// observe it (e.g. a stale wheel-down event set it back to true mid-stream).
func TestScrollList_DragDisablesAutoScroll(t *testing.T) {
sl := NewScrollList(80, 10)
sl.SetItems(makeItems(20, 2))
// Begin a selection.
if !sl.HandleMouseDown(5, 3) {
t.Fatalf("HandleMouseDown failed")
}
// Simulate an external code path that re-enabled autoScroll while
// MouseDown is still held (the precise condition that caused drift).
sl.autoScroll = true
// Drag motion should hard-lock the viewport again.
if !sl.HandleMouseDrag(10, 4) {
t.Fatalf("HandleMouseDrag failed")
}
if sl.autoScroll {
t.Errorf("HandleMouseDrag must clear autoScroll to prevent mid-drag jumps")
}
}
// TestScrollList_SetItemsRespectsMouseDown is the most direct regression
// test: even with autoScroll enabled and new content appended at the
// bottom, SetItems must not move the viewport while a mouse drag is in
// progress. This is what caused the "highlighting shifts by 1+ rows
// during streaming" symptom reported by the user.
func TestScrollList_SetItemsRespectsMouseDown(t *testing.T) {
sl := NewScrollList(80, 5)
sl.SetItems(makeItems(10, 2)) // 20 lines into a 5-line viewport
// At bottom.
preIdx, preLine := sl.offsetIdx, sl.offsetLine
// Hold mouse down (no actual drag needed).
if !sl.HandleMouseDown(0, 0) {
t.Fatalf("HandleMouseDown failed")
}
// Append several more items as if streaming. With the bug, each
// SetItems would call GotoBottom and shift the offset.
for n := 11; n <= 15; n++ {
sl.SetItems(makeItems(n, 2))
if sl.offsetIdx != preIdx || sl.offsetLine != preLine {
t.Fatalf("viewport drifted during streaming with mouse held: "+
"start=(%d,%d) now=(%d,%d) after adding item %d",
preIdx, preLine, sl.offsetIdx, sl.offsetLine, n)
}
}
}
-95
View File
@@ -211,106 +211,11 @@ func DefaultTheme() Theme {
}
}
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
// padding, and appropriate width. Used for grouping related content in a visually
// distinct box.
func StyleCard(width int, theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Width(width).
Border(lipgloss.RoundedBorder()).
BorderForeground(theme.Border).
Padding(1, 2).
MarginBottom(1)
}
// IsDarkBackground returns the cached terminal background detection result.
func IsDarkBackground() bool {
return isDarkBg
}
// StyleHeader creates a lipgloss style for primary headers using the theme's
// primary color with bold text for emphasis and hierarchy.
func StyleHeader(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Primary).
Bold(true)
}
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
// secondary color with bold text, providing visual hierarchy below primary headers.
func StyleSubheader(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Secondary).
Bold(true)
}
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
// and italic formatting, suitable for supplementary or less important information.
func StyleMuted(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Muted).
Italic(true)
}
// StyleSuccess creates a lipgloss style for success messages using green colors
// with bold text to indicate successful operations or positive outcomes.
func StyleSuccess(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Success).
Bold(true)
}
// StyleError creates a lipgloss style for error messages using red colors
// with bold text to ensure visibility of problems or failures.
func StyleError(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Error).
Bold(true)
}
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
// colors with bold text to draw attention to potential issues or cautions.
func StyleWarning(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Warning).
Bold(true)
}
// StyleInfo creates a lipgloss style for informational messages using blue colors
// with bold text for general notifications and status updates.
func StyleInfo(theme Theme) lipgloss.Style {
return lipgloss.NewStyle().
Foreground(theme.Info).
Bold(true)
}
// CreateSeparator generates a horizontal separator line with the specified width,
// character, and color. Useful for visually dividing sections of content in the UI.
func CreateSeparator(width int, char string, c color.Color) string {
return lipgloss.NewStyle().
Foreground(c).
Width(width).
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
}
// CreateProgressBar generates a visual progress bar with filled and empty segments
// based on the percentage complete. The bar uses Unicode block characters for smooth
// appearance and theme colors to indicate progress.
func CreateProgressBar(width int, percentage float64, theme Theme) string {
filled := int(float64(width) * percentage / 100)
empty := width - filled
filledBar := lipgloss.NewStyle().
Foreground(theme.Success).
Render(lipgloss.PlaceHorizontal(filled, lipgloss.Left, "█"))
emptyBar := lipgloss.NewStyle().
Foreground(theme.Muted).
Render(lipgloss.PlaceHorizontal(empty, lipgloss.Left, "░"))
return filledBar + emptyBar
}
// CreateBadge generates a styled badge or label with inverted colors (text on
// colored background) for highlighting important tags, statuses, or categories.
func CreateBadge(text string, c color.Color) string {
-7
View File
@@ -6,13 +6,6 @@ import (
heraldmd "github.com/indaco/herald-md"
)
// BaseStyle returns a new, empty lipgloss style that can be customized with
// additional styling methods. This serves as the foundation for building more
// complex styled components.
func BaseStyle() lipgloss.Style {
return lipgloss.NewStyle()
}
// markdownTypographyCache holds the last-created Typography instance for
// herald-md rendering. It is cached to avoid re-initialization on every
// streaming flush tick. The cache is invalidated by SetTheme when the
-6
View File
@@ -543,12 +543,6 @@ func ApplyThemeWithoutSave(name string) error {
return nil
}
// RefreshThemeRegistry re-scans the themes directory. Call after the user
// drops a new file into ~/.config/kit/themes/.
func RefreshThemeRegistry() {
initThemeRegistry()
}
// RegisterThemeFromConfig adds a theme to the runtime registry from an
// extension's ThemeColorConfig (string hex pairs). Replaces any existing
// entry with the same name. The theme is immediately available via
-140
View File
@@ -1,140 +0,0 @@
package ui
import (
"fmt"
"strings"
"charm.land/bubbles/v2/textarea"
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
)
type ToolApprovalInput struct {
textarea textarea.Model
toolName string
toolArgs string
width int
selected bool // true when "yes" is highlighted and false when "no" is
approved bool
done bool
}
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
ta := textarea.New()
ta.Placeholder = ""
ta.ShowLineNumbers = false
ta.CharLimit = 0
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
ta.SetHeight(4) // Default to 3 lines like huh
ta.Focus()
// Style the textarea using theme colors.
theme := GetTheme()
styles := ta.Styles()
styles.Focused.Base = lipgloss.NewStyle()
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
styles.Focused.Prompt = lipgloss.NewStyle()
styles.Focused.CursorLine = lipgloss.NewStyle()
ta.SetStyles(styles)
return &ToolApprovalInput{
textarea: ta,
toolName: toolName,
toolArgs: toolArgs,
width: width,
selected: true,
}
}
func (t *ToolApprovalInput) Init() tea.Cmd {
return textarea.Blink
}
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tea.KeyPressMsg:
switch msg.String() {
case "y", "Y":
t.approved = true
t.done = true
return t, tea.Quit
case "n", "N":
t.approved = false
t.done = true
return t, tea.Quit
case "left":
t.selected = true
return t, nil
case "right":
t.selected = false
return t, nil
case "enter":
t.approved = t.selected
t.done = true
return t, tea.Quit
case "esc", "ctrl+c":
t.approved = false
t.done = true
return t, tea.Quit
}
}
return t, nil
}
func (t *ToolApprovalInput) View() tea.View {
if t.done {
return tea.NewView("we are done")
}
containerStyle := lipgloss.NewStyle()
theme := GetTheme()
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
titleStyle := lipgloss.NewStyle().
Foreground(theme.Text).
MarginBottom(1).
PaddingLeft(3)
// Input box with huh-like styling
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.ThickBorder()).
BorderLeft(true).
BorderRight(false).
BorderTop(false).
BorderBottom(false).
BorderForeground(theme.Primary).
PaddingLeft(2). // match message block paddingLeft
Width(t.width - 1) // full width minus left border
// Style for the currently selected/highlighted option
selectedStyle := lipgloss.NewStyle().
Foreground(theme.Success).
Bold(true).
Underline(true)
// Style for the unselected/unhighlighted option
unselectedStyle := lipgloss.NewStyle().
Foreground(theme.VeryMuted)
// Build the view
var view strings.Builder
view.WriteString(titleStyle.Render("Allow tool execution"))
view.WriteString("\n")
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
view.WriteString(details)
view.WriteString("Allow tool execution: ")
var yesText, noText string
if t.selected {
yesText = selectedStyle.Render("[y]es")
noText = unselectedStyle.Render("[n]o")
} else {
yesText = unselectedStyle.Render("[y]es")
noText = selectedStyle.Render("[n]o")
}
view.WriteString(yesText + "/" + noText + "\n")
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
}
+17 -2
View File
@@ -243,7 +243,7 @@ host.ClearSession()
## Re-exported Types
The SDK re-exports types so you don't need direct internal imports:
The SDK re-exports message/session/MCP types so you don't need direct internal imports. Agent-configuration types are Kit-owned (not aliases) and use only SDK types in their signatures, so consumers never need to import the underlying LLM-provider package.
```go
// Message types
@@ -251,13 +251,28 @@ kit.Message, kit.MessageRole, kit.ContentPart
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
// LLM types — concrete Kit-owned structs, no external library dependency
// LLM types — Kit-owned `LLM*` aliases over the underlying provider types,
// so consumers never import the provider package directly
kit.LLMMessage // {Role LLMMessageRole, Content string}
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
kit.LLMResponse // {Content, FinishReason, Usage}
kit.LLMFilePart // {Filename, Data []byte, MediaType}
// Agent configuration — concrete Kit-owned structs and function types.
// All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct
// these without importing any LLM-provider package.
kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control
kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool
kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress)
kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string)
kit.ToolExecutionHandler // func(toolCallID, toolName, toolArgs string, isStarting bool)
kit.ToolResultHandler // func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
kit.ResponseHandler // func(content string)
kit.StreamingResponseHandler // func(content string)
kit.ToolCallContentHandler // func(content string)
kit.SpinnerFunc // func(fn func() error) error
// MCP OAuth types
kit.MCPServer // *server.MCPServer for in-process MCP transport
kit.MCPServerConfig // Configuration for an MCP server (stdio, SSE, or in-process)
+208
View File
@@ -0,0 +1,208 @@
package kit
import (
"context"
"errors"
"testing"
"time"
"github.com/mark3labs/kit/internal/agent"
)
// TestAgentConfigToInternal verifies that the SDK-side AgentConfig converts
// faithfully to the internal agent.AgentConfig representation, preserving
// every field consumed by the internal agent layer.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigToInternal(t *testing.T) {
t.Run("nil receiver returns nil", func(t *testing.T) {
var c *AgentConfig
if got := c.toInternal(); got != nil {
t.Errorf("nil.toInternal() = %v, want nil", got)
}
})
t.Run("scalar fields round-trip", func(t *testing.T) {
c := &AgentConfig{
SystemPrompt: "sys",
MaxSteps: 7,
StreamingEnabled: true,
DisableCoreTools: true,
}
got := c.toInternal()
if got == nil {
t.Fatal("toInternal() = nil")
}
if got.SystemPrompt != "sys" {
t.Errorf("SystemPrompt = %q, want %q", got.SystemPrompt, "sys")
}
if got.MaxSteps != 7 {
t.Errorf("MaxSteps = %d, want 7", got.MaxSteps)
}
if !got.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
}
if !got.DisableCoreTools {
t.Error("DisableCoreTools = false, want true")
}
})
t.Run("tool slices propagate without conversion", func(t *testing.T) {
// Tool is a type alias for the underlying LLM-tool type, so the
// SDK []Tool and internal []fantasy.AgentTool slices share the
// same backing array after conversion.
tool := NewTool[struct{}]("noop", "noop", nil)
c := &AgentConfig{
CoreTools: []Tool{tool},
ExtraTools: []Tool{tool, tool},
}
got := c.toInternal()
if len(got.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(got.CoreTools))
}
if len(got.ExtraTools) != 2 {
t.Errorf("ExtraTools len = %d, want 2", len(got.ExtraTools))
}
})
t.Run("tool wrapper is invoked through internal config", func(t *testing.T) {
called := false
c := &AgentConfig{
ToolWrapper: func(in []Tool) []Tool {
called = true
return in
},
}
got := c.toInternal()
if got.ToolWrapper == nil {
t.Fatal("internal ToolWrapper is nil")
}
_ = got.ToolWrapper(nil)
if !called {
t.Error("SDK ToolWrapper was not invoked through the internal config")
}
})
t.Run("OnMCPServerLoaded propagates", func(t *testing.T) {
var captured string
wantErr := errors.New("boom")
c := &AgentConfig{
OnMCPServerLoaded: func(name string, _ int, _ error) {
captured = name
},
}
got := c.toInternal()
got.OnMCPServerLoaded("svr", 3, wantErr)
if captured != "svr" {
t.Errorf("OnMCPServerLoaded captured = %q, want %q", captured, "svr")
}
})
t.Run("DebugLogger propagates", func(t *testing.T) {
dl := &fakeDebugLogger{enabled: true}
c := &AgentConfig{DebugLogger: dl}
got := c.toInternal()
if got.DebugLogger == nil {
t.Fatal("internal DebugLogger is nil")
}
if !got.DebugLogger.IsDebugEnabled() {
t.Error("IsDebugEnabled = false, want true")
}
got.DebugLogger.LogDebug("hello")
if len(dl.messages) != 1 || dl.messages[0] != "hello" {
t.Errorf("messages = %v, want [hello]", dl.messages)
}
})
t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) {
c := &AgentConfig{
MCPTaskConfig: MCPTaskConfig{
PerServerMode: map[string]MCPTaskMode{
"build-svr": MCPTaskModeAlways,
},
DefaultTTL: 30 * time.Second,
PollInterval: 250 * time.Millisecond,
MaxPollInterval: 2 * time.Second,
Timeout: 5 * time.Minute,
Progress: func(_ MCPTaskProgress) {},
},
}
got := c.toInternal()
if got.MCPTaskConfig.DefaultTTL != 30*time.Second {
t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL)
}
if got.MCPTaskConfig.PollInterval != 250*time.Millisecond {
t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval)
}
if got.MCPTaskConfig.MaxPollInterval != 2*time.Second {
t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval)
}
if got.MCPTaskConfig.Timeout != 5*time.Minute {
t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout)
}
mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"]
if !ok {
t.Fatal("PerServerMode missing 'build-svr'")
}
if string(mode) != string(MCPTaskModeAlways) {
t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways)
}
if got.MCPTaskConfig.Progress == nil {
t.Fatal("internal Progress handler is nil")
}
})
t.Run("auth and token store factories are wired", func(t *testing.T) {
auth := &fakeAuthHandler{}
tokenCalls := 0
var tokenServer string
factory := MCPTokenStoreFactory(func(server string) (MCPTokenStore, error) {
tokenCalls++
tokenServer = server
return nil, nil
})
c := &AgentConfig{
AuthHandler: auth,
TokenStoreFactory: factory,
}
got := c.toInternal()
if got.AuthHandler == nil {
t.Fatal("internal AuthHandler is nil")
}
if got.TokenStoreFactory == nil {
t.Fatal("internal TokenStoreFactory is nil")
}
_, _ = got.TokenStoreFactory("https://example.test")
if tokenCalls != 1 {
t.Errorf("token factory call count = %d, want 1", tokenCalls)
}
if tokenServer != "https://example.test" {
t.Errorf("token factory server arg = %q", tokenServer)
}
if got.AuthHandler.RedirectURI() != "redirect" {
t.Errorf("RedirectURI = %q, want %q", got.AuthHandler.RedirectURI(), "redirect")
}
})
// Compile-time check that the internal type is what we expect.
//nolint:staticcheck // QF1011: explicit type asserts the conversion target.
var _ *agent.AgentConfig = (&AgentConfig{}).toInternal()
}
// fakeAuthHandler implements both kit.MCPAuthHandler and the structurally
// identical tools.MCPAuthHandler used by the internal layer.
type fakeAuthHandler struct{}
func (f *fakeAuthHandler) RedirectURI() string { return "redirect" }
func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
return "", nil
}
// fakeDebugLogger implements kit.DebugLogger for tests.
type fakeDebugLogger struct {
enabled bool
messages []string
}
func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) }
func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled }
+3 -3
View File
@@ -148,9 +148,9 @@ func parseToolArgs(toolArgs string) map[string]any {
// ---------------------------------------------------------------------------
// Finish reasons reported by the LLM provider on a completed turn. These
// mirror fantasy.FinishReason string values so comparisons against
// TurnEndEvent.StopReason / TurnResult.StopReason are stable across
// providers.
// mirror the underlying provider's finish reason string values so
// comparisons against TurnEndEvent.StopReason / TurnResult.StopReason are
// stable across providers.
const (
// FinishReasonStop: the model produced a natural stop (e.g. stop sequence
// or end-of-turn signal).
+36 -2
View File
@@ -58,6 +58,9 @@ type Kit struct {
// When false, per-model system prompts from modelSettings/customModels
// can replace the default prompt on model switch.
hasCustomSystemPrompt bool
// systemPromptSource holds the raw configured value (file path or text)
// when hasCustomSystemPrompt is true; empty when the built-in default is in use.
systemPromptSource string
// Hook registries — interception layer (see hooks.go).
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
@@ -632,6 +635,21 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
return nil
}
// HasCustomSystemPrompt reports whether the user explicitly configured a system
// prompt via --system-prompt, a config file entry, or SDK Options.SystemPrompt.
// When false, the built-in default (or a per-model override) is in use and can
// be replaced transparently on model switch.
func (m *Kit) HasCustomSystemPrompt() bool {
return m.hasCustomSystemPrompt
}
// GetSystemPromptSource returns the raw configured value — a file path or
// inline text — when HasCustomSystemPrompt is true; returns an empty string
// when the built-in default prompt is active.
func (m *Kit) GetSystemPromptSource() string {
return m.systemPromptSource
}
// composeSystemPrompt takes a base system prompt and composes it with the
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
// This mirrors the composition done during Kit.New() initialization.
@@ -1179,6 +1197,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
maxSteps int
streaming bool
hasCustomSystemPrompt bool
systemPromptSource string
)
if err := func() error {
@@ -1285,13 +1304,27 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
// explicitly set system-prompt, use the per-model prompt as the
// base instead of the global default.
{
basePrompt := viper.GetString("system-prompt")
rawPromptInput := viper.GetString("system-prompt")
// Resolve a file path to its content so PromptBuilder receives the
// actual prompt text rather than a literal path string. Without this,
// when system-prompt is set to a file path in the config file or via
// --system-prompt, the path itself becomes the effective system prompt
// sent to the model (LoadSystemPrompt only ran later, after viper had
// been overwritten with the augmented base text).
basePrompt, _ := config.LoadSystemPrompt(rawPromptInput)
if basePrompt == "" {
basePrompt = rawPromptInput
}
// Track whether the user explicitly configured a custom system
// prompt. When they haven't (basePrompt is the built-in default
// or empty), per-model system prompts can replace it on switch.
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
hasCustomSystemPrompt = userSetSystemPrompt
if hasCustomSystemPrompt {
systemPromptSource = rawPromptInput
}
// Check for per-model system prompt override when no explicit
// global system-prompt was configured by the user.
@@ -1456,7 +1489,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
if opts.CLI != nil {
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
setupOpts.SpinnerFunc = agent.SpinnerFunc(opts.CLI.SpinnerFunc)
setupOpts.UseBufferedLogger = opts.CLI.UseBufferedLogger
if opts.CLI.ProgressReaderFunc != nil {
providerConfig.ProgressReaderFunc = opts.CLI.ProgressReaderFunc
@@ -1500,6 +1533,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
opts: opts,
mcpConfig: mcpConfig,
hasCustomSystemPrompt: hasCustomSystemPrompt,
systemPromptSource: systemPromptSource,
beforeToolCall: beforeToolCall,
afterToolResult: afterToolResult,
beforeTurn: beforeTurn,
+90
View File
@@ -3,6 +3,7 @@ package kit_test
import (
"context"
"os"
"strings"
"testing"
"github.com/spf13/viper"
@@ -306,3 +307,92 @@ func TestSessionManagement(t *testing.T) {
// resetViper wipes viper's global state so a test case doesn't leak
// viper.Set() calls into the next one. Used via defer in subtests.
func resetViper() { viper.Reset() }
// TestNewSystemPromptFilePath is a regression test for issue #25.
//
// When Options.SystemPrompt (or the --system-prompt flag / config entry) is a
// file path, Kit must resolve the path to its file contents *before* the
// PromptBuilder composes the runtime context. Previously the path string
// itself was used verbatim as the base prompt, so the LLM received the path —
// not the prompt — as its system message.
func TestNewSystemPromptFilePath(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
}
defer resetViper()
const promptContent = "You are a strict regression-test persona. Marker: KIT-25-OK"
tmpFile, err := os.CreateTemp(t.TempDir(), "kit-system-prompt-*.md")
if err != nil {
t.Fatalf("failed to create temp prompt file: %v", err)
}
if _, err := tmpFile.WriteString(promptContent); err != nil {
t.Fatalf("failed to write temp prompt file: %v", err)
}
if err := tmpFile.Close(); err != nil {
t.Fatalf("failed to close temp prompt file: %v", err)
}
ctx := context.Background()
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
SystemPrompt: tmpFile.Name(),
Quiet: true,
NoSession: true,
})
if err != nil {
t.Fatalf("Failed to create Kit with system-prompt file: %v", err)
}
defer func() { _ = host.Close() }()
if !host.HasCustomSystemPrompt() {
t.Error("HasCustomSystemPrompt() = false; want true when --system-prompt is set")
}
if got, want := host.GetSystemPromptSource(), tmpFile.Name(); got != want {
t.Errorf("GetSystemPromptSource() = %q; want %q", got, want)
}
// The composed system prompt is written back to viper after PromptBuilder
// runs. It must contain the file's contents, not the file path.
composed := viper.GetString("system-prompt")
if !strings.Contains(composed, promptContent) {
t.Errorf("composed system-prompt does not contain file contents\n composed = %q\n want substring = %q", composed, promptContent)
}
if strings.TrimSpace(composed) == tmpFile.Name() {
t.Errorf("composed system-prompt is the file path verbatim (%q); LoadSystemPrompt was not applied before PromptBuilder", composed)
}
}
// TestNewSystemPromptInline confirms that inline system-prompt strings still
// flow through unchanged after the file-path resolution change.
func TestNewSystemPromptInline(t *testing.T) {
if os.Getenv("ANTHROPIC_API_KEY") == "" {
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
}
defer resetViper()
const inline = "You are a concise inline-prompt persona."
ctx := context.Background()
host, err := kit.New(ctx, &kit.Options{
Model: "anthropic/claude-sonnet-4-5-20250929",
SystemPrompt: inline,
Quiet: true,
NoSession: true,
})
if err != nil {
t.Fatalf("Failed to create Kit with inline system-prompt: %v", err)
}
defer func() { _ = host.Close() }()
if !host.HasCustomSystemPrompt() {
t.Error("HasCustomSystemPrompt() = false; want true for inline prompt")
}
if got := host.GetSystemPromptSource(); got != inline {
t.Errorf("GetSystemPromptSource() = %q; want %q", got, inline)
}
if composed := viper.GetString("system-prompt"); !strings.Contains(composed, inline) {
t.Errorf("composed system-prompt missing inline content; got %q", composed)
}
}
+64
View File
@@ -98,6 +98,70 @@ type MCPTaskProgress struct {
// dispatched on a goroutine.
type MCPTaskProgressHandler func(MCPTaskProgress)
// MCPTaskConfig configures task-aware MCP tools/call execution. All fields
// are optional; the zero value disables progress callbacks and applies
// sensible polling defaults inside the engine.
//
// For most consumers, the flat [Options] fields (`MCPTaskMode`,
// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`,
// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point.
// MCPTaskConfig is exposed for the low-level [AgentConfig] path.
type MCPTaskConfig struct {
// PerServerMode overrides the per-server task mode resolved from
// [MCPServerConfig]. Keys are server names. Missing entries fall back
// to the configured value.
PerServerMode map[string]MCPTaskMode
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
// tools/call. Zero means omit the TTL — let the server pick its own.
DefaultTTL time.Duration
// PollInterval is the fallback interval between tasks/get requests
// when the server does not suggest one. Zero defaults to 1 second.
PollInterval time.Duration
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
MaxPollInterval time.Duration
// Timeout is the maximum wall-clock duration to wait for a task to
// reach a terminal state. Zero defaults to 15 minutes. Independent
// of the per-call context deadline; whichever fires first wins.
Timeout time.Duration
// Progress, if non-nil, receives every status transition observed by
// the polling loop.
Progress MCPTaskProgressHandler
}
// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal
// tools-package representation. Keeps the dependency arrow internal-only.
func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig {
cfg := tools.MCPTaskConfig{
DefaultTTL: c.DefaultTTL,
PollInterval: c.PollInterval,
MaxPollInterval: c.MaxPollInterval,
Timeout: c.Timeout,
}
if len(c.PerServerMode) > 0 {
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode))
for k, v := range c.PerServerMode {
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
}
}
if c.Progress != nil {
h := c.Progress
cfg.Progress = func(p tools.MCPTaskProgress) {
h(MCPTaskProgress{
Server: p.Server,
TaskID: p.TaskID,
Status: MCPTaskStatus(p.Status),
Message: p.Message,
})
}
}
return cfg
}
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
// Stored on Options as a single value so the public surface stays compact;
// individual fields are exposed via WithMCP* builder functions.
+145 -18
View File
@@ -11,6 +11,7 @@ import (
"github.com/mark3labs/kit/internal/message"
"github.com/mark3labs/kit/internal/models"
"github.com/mark3labs/kit/internal/session"
"github.com/mark3labs/kit/internal/tools"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/server"
)
@@ -75,25 +76,151 @@ type Config = config.Config
// local (stdio) and remote (StreamableHTTP/SSE) server types.
type MCPServerConfig = config.MCPServerConfig
// ==== Agent Types (internal/agent/) ====
// ==== Agent Types ====
// AgentConfig holds configuration options for creating a new Agent.
type AgentConfig = agent.AgentConfig
// DebugLogger is an SDK-owned interface for low-level debug logging from
// the engine and MCP tool plumbing. Implementations must be safe for
// concurrent use.
//
// Most consumers do not need to provide one; pass [Options.Debug] = true
// to use the default logger. DebugLogger is exposed for the low-level
// [AgentConfig] path and for embedders that want to route debug output
// into their own logging system.
type DebugLogger interface {
// LogDebug records a single debug message. Implementations may drop,
// buffer, or render the message however they choose.
LogDebug(message string)
// IsDebugEnabled reports whether debug logging is active. Callers may
// check this before doing expensive formatting work.
IsDebugEnabled() bool
}
type (
// ToolCallHandler is a function type for handling tool calls as they happen.
ToolCallHandler = agent.ToolCallHandler
// ToolExecutionHandler is a function type for handling tool execution start/end events.
ToolExecutionHandler = agent.ToolExecutionHandler
// ToolResultHandler is a function type for handling tool results.
ToolResultHandler = agent.ToolResultHandler
// ResponseHandler is a function type for handling LLM responses.
ResponseHandler = agent.ResponseHandler
// StreamingResponseHandler is a function type for handling streaming LLM responses.
StreamingResponseHandler = agent.StreamingResponseHandler
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
ToolCallContentHandler = agent.ToolCallContentHandler
)
// AgentConfig holds configuration options for constructing an agent at the
// SDK boundary. All fields use SDK-owned types, so consumers can populate
// this struct without importing any underlying LLM-provider package.
//
// For most use cases, prefer the high-level [New] entry point with
// [Options]. AgentConfig is exposed for advanced consumers that need
// direct access to the lower-level agent configuration shape.
type AgentConfig struct {
// ModelConfig holds the LLM provider configuration. A nil value means
// that the default provider/model resolution will be used.
ModelConfig *ProviderConfig
// MCPConfig describes any MCP servers whose tools should be loaded
// alongside core tools.
MCPConfig *Config
// SystemPrompt is the system prompt sent to the LLM.
SystemPrompt string
// MaxSteps caps the number of LLM iterations per turn. A value of
// zero means no cap is applied at this layer.
MaxSteps int
// StreamingEnabled controls whether the agent streams responses.
StreamingEnabled bool
// AuthHandler handles OAuth authorization for remote MCP servers.
// When nil, remote MCP servers requiring OAuth will fail to connect.
AuthHandler MCPAuthHandler
// TokenStoreFactory, if non-nil, creates a custom token store for each
// remote MCP server's OAuth tokens. When nil, the default file-based
// token store is used.
TokenStoreFactory MCPTokenStoreFactory
// CoreTools overrides the default core tool set. If empty, [AllTools]
// is used. Provide a custom tool set (e.g. [CodingTools] or tools
// built with a custom WorkDir) to scope agent capabilities.
CoreTools []Tool
// DisableCoreTools, when true, prevents loading any core tools.
// Combined with empty CoreTools this yields a chat-only agent with
// no built-in tools.
DisableCoreTools bool
// ExtraTools are additional tools loaded alongside core and MCP tools.
ExtraTools []Tool
// ToolWrapper, if non-nil, wraps the combined tool list before it is
// handed to the LLM. Used to intercept tool calls or results.
ToolWrapper func([]Tool) []Tool
// OnMCPServerLoaded, if non-nil, is invoked once for each MCP server
// when its tools have finished loading (or failed). Called from a
// background goroutine.
OnMCPServerLoaded func(serverName string, toolCount int, err error)
// DebugLogger receives low-level debug output from the engine and the
// MCP tool plumbing. Nil means no debug output is emitted at this
// layer (regardless of [Options.Debug], which feeds the higher-level
// [New] entry point). Pass an implementation here when wiring a custom
// logger through the lower-level AgentConfig path.
DebugLogger DebugLogger
// MCPTaskConfig configures task-aware MCP tools/call execution — mode
// overrides, polling intervals, timeouts, and the progress handler.
// The zero value preserves historical synchronous-only behaviour for
// any server that didn't advertise task support during initialize.
MCPTaskConfig MCPTaskConfig
}
// toInternal converts an AgentConfig to its internal representation.
// Slice and function fields convert without allocation because [Tool]
// is a type alias for the underlying LLM-tool type.
func (c *AgentConfig) toInternal() *agent.AgentConfig {
if c == nil {
return nil
}
out := &agent.AgentConfig{
ModelConfig: c.ModelConfig,
MCPConfig: c.MCPConfig,
SystemPrompt: c.SystemPrompt,
MaxSteps: c.MaxSteps,
StreamingEnabled: c.StreamingEnabled,
CoreTools: c.CoreTools,
DisableCoreTools: c.DisableCoreTools,
ExtraTools: c.ExtraTools,
ToolWrapper: c.ToolWrapper,
OnMCPServerLoaded: c.OnMCPServerLoaded,
}
if c.AuthHandler != nil {
out.AuthHandler = c.AuthHandler
}
if c.TokenStoreFactory != nil {
out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory)
}
if c.DebugLogger != nil {
out.DebugLogger = c.DebugLogger
}
out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig()
return out
}
// ToolCallHandler is invoked when the LLM produces a tool call. It receives
// the call ID, tool name, and the JSON-encoded input arguments.
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
// ToolExecutionHandler is invoked at the start and end of tool execution.
// The isStarting flag distinguishes the two phases.
type ToolExecutionHandler func(toolCallID, toolName, toolArgs string, isStarting bool)
// ToolResultHandler is invoked after a tool finishes executing. The metadata
// parameter carries optional structured data (e.g. file-diff info) from the
// tool execution, JSON-encoded; it may be empty.
type ToolResultHandler func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
// ResponseHandler is invoked with the final assistant text for each turn.
type ResponseHandler func(content string)
// StreamingResponseHandler is invoked with each streamed text delta as it
// arrives from the LLM.
type StreamingResponseHandler func(content string)
// ToolCallContentHandler is invoked with any assistant text that accompanies
// a tool call within the same step.
type ToolCallContentHandler func(content string)
// ==== Provider & Model Types (internal/models/) ====
@@ -126,7 +253,7 @@ type ModelsRegistry = models.ModelsRegistry
// SpinnerFunc wraps a function in a loading spinner animation. Used for
// Ollama model loading. Signature: func(fn func() error) error.
type SpinnerFunc = agent.SpinnerFunc
type SpinnerFunc func(fn func() error) error
// ==== LLM Types ====
//
+96
View File
@@ -1,6 +1,7 @@
package kit_test
import (
"context"
"encoding/json"
"testing"
@@ -263,6 +264,101 @@ func TestConvertFromLLMMessage(t *testing.T) {
}
}
// TestAgentConfigNoFantasyImport verifies AgentConfig can be populated with
// every field — including CoreTools, ExtraTools, and ToolWrapper — using
// only SDK-owned types. This test deliberately does not import
// "charm.land/fantasy"; the package compiling at all is the proof that the
// SDK no longer leaks the dependency name through AgentConfig.
//
// Regression test for https://github.com/mark3labs/kit/issues/30.
func TestAgentConfigNoFantasyImport(t *testing.T) {
myTool := kit.NewTool[struct{}]("noop", "does nothing", func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
return kit.TextResult("ok"), nil
})
wrapperCalled := false
cfg := kit.AgentConfig{
SystemPrompt: "you are a tester",
MaxSteps: 5,
StreamingEnabled: true,
CoreTools: []kit.Tool{myTool},
ExtraTools: []kit.Tool{myTool},
DisableCoreTools: false,
ToolWrapper: func(in []kit.Tool) []kit.Tool {
wrapperCalled = true
return in
},
OnMCPServerLoaded: func(_ string, _ int, _ error) {},
}
if cfg.SystemPrompt != "you are a tester" {
t.Errorf("SystemPrompt = %q, want %q", cfg.SystemPrompt, "you are a tester")
}
if cfg.MaxSteps != 5 {
t.Errorf("MaxSteps = %d, want 5", cfg.MaxSteps)
}
if !cfg.StreamingEnabled {
t.Error("StreamingEnabled = false, want true")
}
if len(cfg.CoreTools) != 1 {
t.Errorf("CoreTools len = %d, want 1", len(cfg.CoreTools))
}
if len(cfg.ExtraTools) != 1 {
t.Errorf("ExtraTools len = %d, want 1", len(cfg.ExtraTools))
}
// Exercise the wrapper to confirm the func type is usable.
out := cfg.ToolWrapper(cfg.CoreTools)
if !wrapperCalled {
t.Error("ToolWrapper was not invoked")
}
if len(out) != 1 {
t.Errorf("wrapped tool list len = %d, want 1", len(out))
}
}
// TestAgentConfigToolWrapperSignature documents that AgentConfig.ToolWrapper
// uses kit.Tool (not the underlying provider type) in its signature.
func TestAgentConfigToolWrapperSignature(t *testing.T) {
//nolint:staticcheck // QF1011: explicit type asserts the SDK-side func signature.
var _ func([]kit.Tool) []kit.Tool = func(in []kit.Tool) []kit.Tool { return in }
cfg := kit.AgentConfig{
ToolWrapper: func(in []kit.Tool) []kit.Tool { return in },
}
if cfg.ToolWrapper == nil {
t.Fatal("ToolWrapper assignment failed")
}
}
// TestSpinnerFuncSignature verifies SpinnerFunc has the documented signature
// and can be constructed without importing any provider package.
func TestSpinnerFuncSignature(t *testing.T) {
called := false
var sp kit.SpinnerFunc = func(fn func() error) error {
called = true
return fn()
}
err := sp(func() error { return nil })
if err != nil {
t.Errorf("SpinnerFunc returned err: %v", err)
}
if !called {
t.Error("SpinnerFunc did not invoke fn")
}
}
// TestHandlerTypesSignatures verifies the SDK-owned handler function types
// can be assigned from plain function literals using only standard library
// types in their signatures (no provider-package import required).
func TestHandlerTypesSignatures(t *testing.T) {
var _ kit.ToolCallHandler = func(_, _, _ string) {}
var _ kit.ToolExecutionHandler = func(_, _, _ string, _ bool) {}
var _ kit.ToolResultHandler = func(_, _, _, _, _ string, _ bool) {}
var _ kit.ResponseHandler = func(_ string) {}
var _ kit.StreamingResponseHandler = func(_ string) {}
var _ kit.ToolCallContentHandler = func(_ string) {}
}
// containsStr is a tiny helper to avoid importing strings in test.
func containsStr(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > 0 && indexStr(s, substr) >= 0)
-5
View File
@@ -1,5 +0,0 @@
# Specs
| Spec | Status | Description |
|------|--------|-------------|
| [unified-bubbletea-architecture](unified-bubbletea-architecture.md) | Draft | Replace micro-program pattern with single Bubble Tea program + thick app layer |