mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
29 Commits
v0.58.2
...
audit-cleanup
| Author | SHA1 | Date | |
|---|---|---|---|
| 2016570e2d | |||
| d557f4b870 | |||
| 65054fe3db | |||
| 97d2246375 | |||
| 1e12505741 | |||
| 6755597c9b | |||
| 45689cb30d | |||
| 78570d4188 | |||
| 7cf38b37ee | |||
| 4ef57eec4e | |||
| cbd828e190 | |||
| d304805106 | |||
| 6e36053856 | |||
| 92eaaf6a59 | |||
| e6084b7bd0 | |||
| 34d5abff9c | |||
| fc0ddd5f4f | |||
| 7aa6160c75 | |||
| e830bf87ca | |||
| 3881d1c28f | |||
| 53f6682bd0 | |||
| 996b15c9b9 | |||
| aeb704367c | |||
| d2e23295b6 | |||
| e5a13e2e12 | |||
| 558fb5214f | |||
| 61408ed490 | |||
| 3cfb6437f9 | |||
| d33ad4028b |
@@ -13,8 +13,6 @@
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
// - The extension runner serializes handler calls per-extension, so
|
||||
// concurrent subagent events cannot race on this shared state.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -45,8 +43,7 @@ const (
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state — safe because the runner serializes all handler
|
||||
// invocations for the same extension (per-extension reentrant mutex).
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
@@ -285,8 +282,8 @@ func Init(api ext.API) {
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry — build a new slice to avoid aliasing bugs
|
||||
newEntries := make([]*submonEntry, 0, len(submonEntries))
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
---
|
||||
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
|
||||
---
|
||||
|
||||
Perform a comprehensive **read-only** audit of this repository and report
|
||||
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
|
||||
hints from the user: $@
|
||||
|
||||
## Scope
|
||||
|
||||
If the user supplied focus hints above (a package path, a subsystem name, a
|
||||
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
|
||||
audit the whole repo, prioritising the highest-traffic packages first
|
||||
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Map the repo first**:
|
||||
- `ls` / `find` the top-level layout and list every Go package
|
||||
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
|
||||
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
|
||||
extension surface)
|
||||
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
|
||||
(e.g. "no dependency name leakage", "UI never imports extensions
|
||||
directly") — these define what counts as a violation
|
||||
|
||||
2. **Hunt for dead code**:
|
||||
- Run `go vet ./...` and capture warnings
|
||||
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
|
||||
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
|
||||
with zero non-test references inside the module are suspects
|
||||
- Check for unreferenced files, `// TODO: remove` markers, commented-out
|
||||
blocks, and `_ = x` discard patterns
|
||||
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
|
||||
them and include their output verbatim
|
||||
- **Do not delete anything** — list candidates with file:line and a
|
||||
confidence level (high / medium / low)
|
||||
|
||||
3. **Find unnecessary duplication**:
|
||||
- Look for near-identical function bodies, struct shapes, or switch
|
||||
statements across packages — `grep` for repeated function signatures
|
||||
and copy-pasted string literals / error messages is a fast first pass
|
||||
- Distinguish *coincidental* duplication (two things that happen to look
|
||||
alike but evolve independently) from *unnecessary* duplication (same
|
||||
intent, drifting in lockstep) — only flag the latter
|
||||
- For each cluster, propose where the extracted helper should live
|
||||
(which package, which file) and whether it crosses a boundary
|
||||
|
||||
4. **Check concerns / boundary violations**:
|
||||
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
|
||||
in exported signatures, and for dependency-name leakage in exported
|
||||
names / godoc (e.g. library jargon appearing in `LLM*` types)
|
||||
- **UI ↔ extensions**: grep `internal/ui/` for any import of
|
||||
`internal/extensions/` — per AGENTS.md the UI must not import
|
||||
extensions directly; converters in `cmd/root.go` should bridge them
|
||||
- **cmd vs internal**: business logic living in `cmd/` that should be
|
||||
in `internal/` (and vice versa)
|
||||
- **Cyclic risk**: packages that import each other transitively or that
|
||||
reach across sibling boundaries unexpectedly
|
||||
- For each violation, cite the offending import / signature with
|
||||
file:line
|
||||
|
||||
5. **Spot refactor opportunities**:
|
||||
- Long functions (>80 lines) doing multiple unrelated things
|
||||
- Deeply nested conditionals that flatten well with early returns
|
||||
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
|
||||
that could become helpers — but only where the wrapping context is
|
||||
genuinely uniform
|
||||
- Structs with too many fields that hint at split responsibilities
|
||||
- Exported APIs that would be cleaner with options structs / functional
|
||||
options
|
||||
- Tests that share setup boilerplate ripe for a helper
|
||||
- Flag each with: location, current shape (1-2 lines), proposed shape
|
||||
(1-2 lines), and estimated risk (low / medium / high)
|
||||
|
||||
6. **Cross-check against project rules**:
|
||||
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
|
||||
findings contradicts the documented gotchas (Yaegi interface ban,
|
||||
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
|
||||
"refactor" would reintroduce a known pitfall, drop it from the report
|
||||
and note why
|
||||
|
||||
7. **Write the report** as your final message (do not write it to disk)
|
||||
structured as:
|
||||
|
||||
```
|
||||
# Code Audit Report
|
||||
|
||||
## Summary
|
||||
- N dead-code candidates
|
||||
- N duplication clusters
|
||||
- N boundary violations
|
||||
- N refactor opportunities
|
||||
|
||||
## Dead Code
|
||||
### High confidence
|
||||
- path/to/file.go:LINE — symbol — reason
|
||||
|
||||
### Medium confidence
|
||||
...
|
||||
|
||||
## Duplication
|
||||
### Cluster: <short name>
|
||||
- Sites: file:line, file:line, …
|
||||
- Suggested home: package/path
|
||||
- Notes: …
|
||||
|
||||
## Boundary Violations
|
||||
- Rule: <which rule from AGENTS.md / project convention>
|
||||
- Offender: file:line
|
||||
- Fix sketch: …
|
||||
|
||||
## Refactor Opportunities
|
||||
- Location: file:line
|
||||
- Current: …
|
||||
- Proposed: …
|
||||
- Risk: low/medium/high
|
||||
- Why it's worth it: …
|
||||
|
||||
## Suggested Next Steps
|
||||
1. …
|
||||
2. …
|
||||
```
|
||||
|
||||
8. **End the report with an explicit reminder** that no files were modified,
|
||||
and recommend the user pick the highest-leverage items to act on
|
||||
manually (or via a follow-up `/fix-issue` style prompt) rather than
|
||||
running a sweeping refactor.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
|
||||
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
|
||||
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
|
||||
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
|
||||
straight to it
|
||||
- **Be honest about confidence**: false positives in a code audit are
|
||||
expensive — prefer "medium confidence, worth a look" over confidently
|
||||
wrong claims
|
||||
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
|
||||
anything that's purely stylistic unless it directly causes one of the
|
||||
four issue categories above
|
||||
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
|
||||
`vendor/`) and obvious third-party copies
|
||||
- **Don't propose architectural rewrites** — stay within the existing
|
||||
shape of the repo and recommend incremental, reviewable changes
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
description: Open a GitHub PR for the current branch using the repo's PR template
|
||||
---
|
||||
|
||||
Open a GitHub pull request for the current branch, filling out the repository's PR template with a description grounded in the actual commits and diff.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Verify the branch is pushed**:
|
||||
- `git status -sb` and `git log @{u}..HEAD --oneline 2>/dev/null` — if there is no upstream or unpushed commits, run `git push -u origin "$(git branch --show-current)"` first
|
||||
- If the working tree is dirty, stop and tell the user to commit first (suggest `/commit-push`)
|
||||
2. **Gather context**:
|
||||
- `git log origin/main..HEAD --oneline` — list of commits going into the PR
|
||||
- `git diff origin/main...HEAD --stat` then `git diff origin/main...HEAD` — read the actual changes
|
||||
- Identify the linked issue (from commit messages, branch name, or extra user input: $@) — capture as `Fixes #N` if applicable
|
||||
3. **Locate the PR template**:
|
||||
- Check `.github/pull_request_template.md`, `.github/PULL_REQUEST_TEMPLATE.md`, or `docs/pull_request_template.md`
|
||||
- If none exists, use a minimal `## Description` / `## Type of Change` / `## Checklist` structure
|
||||
4. **Draft the PR body** by filling out the template:
|
||||
- **Description**: 1–3 short paragraphs explaining *what* changed and *why*, grounded in the diff. Include a brief before/after example for new APIs when useful.
|
||||
- **Fixes #N**: only if there is a real linked issue
|
||||
- **Type of Change**: tick the single most accurate box with `[x]` (leave others as `[ ]`)
|
||||
- **Checklist**: tick items that are genuinely true (style, self-review, tests added, docs updated)
|
||||
- **Additional Information**: bullet list of added / modified files and any backward-compatibility notes
|
||||
- Remove template sections explicitly marked "remove if not applicable" (e.g. MCP Spec Compliance) when they don't apply
|
||||
5. **Write the body to a temp file**: `/tmp/pr-body-<branch-or-issue>.md` — never inline a long body via `--body`, always use `--body-file`
|
||||
6. **Choose the title**: prefer the subject of the primary commit if it already follows Conventional Commits; otherwise craft one in the same style (`<type>(<scope>): <imperative summary>`, ≤72 chars)
|
||||
7. **Create the PR**:
|
||||
```
|
||||
gh pr create \
|
||||
--title "<title>" \
|
||||
--body-file /tmp/pr-body-<...>.md \
|
||||
--base main \
|
||||
--head "$(git branch --show-current)"
|
||||
```
|
||||
Use the repo's actual default branch if it isn't `main` (`gh repo view --json defaultBranchRef -q .defaultBranchRef.name`)
|
||||
8. **Report the PR URL** returned by `gh` and stop
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff and commit messages — do **not** invent features that aren't in the code
|
||||
- One PR per logical change; if the branch contains unrelated commits, surface that and ask before continuing
|
||||
- Keep the description focused on reviewer-relevant information (what / why), not a replay of the diff
|
||||
- Only check checklist boxes that are actually satisfied; leave the rest unchecked rather than lying
|
||||
- If `gh` is not authenticated (`gh auth status` fails), stop and tell the user
|
||||
|
||||
$@
|
||||
@@ -2,7 +2,7 @@
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
Create a feature request for the Kit repository. The user wants to request: $@
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
@@ -16,7 +16,7 @@ This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$+`
|
||||
1. **Understand the request** from the user input: $@
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
@@ -16,7 +16,7 @@ This repository has structured issue templates. You MUST use the appropriate tem
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$+`:
|
||||
1. **Determine the issue type** from the user input: $@
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
---
|
||||
description: Implement the fix/feature/docs change requested by a GitHub issue
|
||||
---
|
||||
|
||||
Resolve GitHub issue #$1 by reading it, classifying it, and producing the appropriate code or doc change. **Stop once the working tree contains the change** — committing, pushing, and opening a PR are handled by `/commit-push` and `/create-pr`.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch the issue**:
|
||||
- Run: gh issue view $1 --json number,title,body,labels,state,author,comments
|
||||
- If the issue is closed, stop and ask the user whether to proceed
|
||||
- Read the **entire** thread including comments — the latest comment often refines the ask
|
||||
|
||||
2. **Classify the issue** from labels, title prefix, and body content:
|
||||
- `bug` / `fix:` → reproduce, then fix
|
||||
- `enhancement` / `feature` / `feat:` → design, then implement
|
||||
- `documentation` / `docs:` → locate and update docs
|
||||
- `question` / `discussion` → answer in a comment, do **not** write code
|
||||
- Anything else → ask the user how to proceed
|
||||
|
||||
3. **Create a working branch** off the default branch:
|
||||
- `git checkout main && git pull --ff-only`
|
||||
- Branch name: <type>/$1-<slug> (e.g. `fix/42-borderColor-ignored`, `feat/57-keyboard-clear`, `docs/63-widget-lifecycle`)
|
||||
|
||||
4. **Do the work** based on type:
|
||||
|
||||
### Bug (`bug` label / `fix:` title)
|
||||
- Reproduce the failure first (write a failing test if feasible) — if you cannot reproduce, comment on the issue asking for clarification and stop
|
||||
- Locate the root cause; do not patch symptoms
|
||||
- Add or extend a regression test that fails before and passes after the fix
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Feature (`enhancement` / `feature` label / `feat:` title)
|
||||
- Re-read the motivation and proposed implementation in the issue body
|
||||
- For large, ambiguous, or breaking changes, sketch the design in a comment on the issue and wait for sign-off before writing code
|
||||
- Implement behind sensible defaults; add godoc on every exported symbol
|
||||
- Add unit tests covering the new behaviour and edge cases
|
||||
- Update `README.md` / `docs/` if the public surface changed
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Documentation (`documentation` label / `docs:` title)
|
||||
- Open the file/URL referenced in the issue's "Documentation Location"
|
||||
- Apply the suggested improvement; verify code samples compile (`go build ./...`)
|
||||
- No tests required, but run `golangci-lint run` if Go files were touched
|
||||
|
||||
5. **Report**:
|
||||
- Branch name (`git branch --show-current`)
|
||||
- Summary of files changed (`git status -s`) and the diff highlights
|
||||
- Test/lint results (pass/fail with key output)
|
||||
- Suggest the next step explicitly:
|
||||
- `/commit-push` to commit with a Conventional Commit subject (the message should reference `(#$1)` and include `Fixes #$1` so merge auto-closes)
|
||||
- then `/create-pr $1` to open the pull request
|
||||
|
||||
## Guidelines
|
||||
|
||||
- This prompt **stops at a clean working tree with the change applied** — do not run `git commit`, `git push`, or `gh pr create`
|
||||
- If the issue is unclear, post a clarifying comment on the issue and stop; do not guess
|
||||
- Keep the change scoped to the issue; surface unrelated cleanups separately
|
||||
- For breaking changes or architecture shifts, propose the design on the issue first and wait for maintainer sign-off
|
||||
- If the issue is a duplicate or already fixed on `main`, comment with the reference and stop
|
||||
- Do not close the issue manually — the eventual PR's `Fixes #$1` handles that on merge
|
||||
+48
-13
@@ -2,7 +2,7 @@
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
@@ -16,30 +16,64 @@ It becomes a `/slug` slash command in the kit input box — typed as `/filename`
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
Body text of the prompt. Reference user-supplied arguments
|
||||
with positional placeholders (see "Argument placeholders" below).
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
- **Required args**: kit infers required positional args from the highest `$N` it finds *outside* backtick/tilde code fences — a stray `$2` in active prose means kit will refuse to run without 2 arguments
|
||||
|
||||
## Argument placeholders
|
||||
|
||||
kit performs shell-style substitution before sending the prompt to the model:
|
||||
|
||||
- `$1`, `$2`, … — positional arguments (1-indexed)
|
||||
- `${1}`, `${2}`, … — same, brace form (use when followed by digits/letters: `${1}_suffix`)
|
||||
- `$@` — all arguments joined by spaces (zero or more, optional)
|
||||
- `$+` — all arguments, **at least one required**
|
||||
- `$ARGUMENTS` / `${ARGUMENTS}` — alias for `$@`
|
||||
- `${@:N}` — args from the Nth onwards (1-indexed, bash-style)
|
||||
- `${@:N:L}` — `L` args starting from the Nth
|
||||
|
||||
### ⚠️ Critical: code fences and inline code preserve placeholders verbatim
|
||||
|
||||
Anything inside triple-backtick fences, `~~~` fences, or single-backtick `inline` code spans is **left untouched** so example code samples don't get corrupted. That means:
|
||||
|
||||
- An inline-coded `gh issue view $1` stays literal `$1` in the model's input ❌
|
||||
- The same command without backticks: gh issue view $1 → expands to `gh issue view 42` ✓
|
||||
|
||||
**Rule of thumb:** if you want a placeholder to substitute, keep it outside backticks and fences. If you want a literal `$1` in the output (e.g. teaching the user shell syntax), put it inside backticks.
|
||||
|
||||
### Workarounds for "I want it to look like code AND substitute"
|
||||
|
||||
1. **Drop the backticks** around just the placeholder portion — the rest can still read as a command line in prose
|
||||
2. **Use a 4-space-indented code block** instead of a triple-backtick fence — kit only skips backtick/tilde fences, so indentation-style code blocks still get substitution:
|
||||
|
||||
git push -u origin "$(git branch --show-current)"
|
||||
gh pr create --title "fix: ... (#$1)" --base main
|
||||
|
||||
3. **Bind once, reference loosely**: put `Issue: $1` at the top in prose, then leave the backticked examples literal — the model will substitute mentally
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
1. **Understand the workflow** the user described in $@ — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
4. **Decide on arguments**:
|
||||
- No args needed → omit placeholders entirely
|
||||
- One required value (issue number, PR url, file path) → use `$1`
|
||||
- Free-form trailing context → end with a single `$@` line
|
||||
- Multiple distinct values → use `$1`, `$2`, … and document each at the top
|
||||
5. **Draft the body**:
|
||||
- Open with a single sentence stating the goal, weaving in `$1`/`$@` where the value belongs
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
- **Audit every backtick and code fence**: any `$N` or `$@` inside them will not expand — was that intentional? If not, apply one of the workarounds above
|
||||
6. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
7. **Verify substitution** by mentally (or actually) replacing `$1`/`$@` with a sample value and confirming every reference resolves — and that the prompt's *own* example snippets don't accidentally bump the required-arg count (wrap illustrative `$N` examples in triple-backtick fences, not 4-space indentation, so `RequiredArgs()` ignores them)
|
||||
8. **Confirm** by showing the final file content and the slash command that activates it (e.g. `/code-review 42`)
|
||||
|
||||
## Guidelines
|
||||
|
||||
@@ -47,3 +81,4 @@ $1 $2 etc. for positional arguments.
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
- When in doubt about substitution behaviour, write the file and run `/<slug> testvalue` once to confirm — wrong placement of backticks is the #1 failure mode
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
---
|
||||
description: Audit and update project documentation (README and docs site) for a recent change
|
||||
---
|
||||
|
||||
Review recent code changes, identify all documentation surfaces that should
|
||||
mention them, and update each one — grounded in the actual diff, not guesses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Identify the change**:
|
||||
- If the user input ($@) names a commit / PR / branch / topic, use that as the focus
|
||||
- Otherwise inspect `git log origin/main..HEAD --oneline` and `git diff origin/main...HEAD --stat` to discover what shipped on the current branch
|
||||
- Read the actual diff (`git diff origin/main...HEAD`) — never document features that aren't in the code
|
||||
|
||||
2. **Inventory the doc surfaces**:
|
||||
- `README.md` at the repo root
|
||||
- Any docs site (commonly `www/`, `docs/`, `site/`) — list its pages and identify the one(s) most thematically related to the change
|
||||
- Inline godoc / API reference comments on the new exported symbols
|
||||
- `CHANGELOG.md` if the project keeps one
|
||||
- Any `examples/` directory entries that demonstrate the affected area
|
||||
|
||||
3. **Audit each surface** with `grep`:
|
||||
- Search for the names of related existing APIs (e.g. if you added `IterTools`, grep for `ListTools`) to find every page that already discusses the area
|
||||
- Decide for each hit: does it need a cross-reference, a side-by-side comparison, or to stay untouched?
|
||||
|
||||
4. **Decide where new content lives**:
|
||||
- Prefer extending an existing page over creating a new one
|
||||
- For a docs site, place new sections near related content (check the page's `## Heading` outline first)
|
||||
- Skip surfaces that genuinely don't apply (e.g. a server-focused README for a client-only change) and say so explicitly
|
||||
|
||||
5. **Draft the updates**:
|
||||
- Lead with a one-sentence statement of what's new and why
|
||||
- Show concrete code examples copied from real signatures — verify against the source files
|
||||
- Include a comparison / "when to use which" table when adding an alternative to an existing API
|
||||
- Note backwards-compatibility behaviour if relevant
|
||||
|
||||
6. **Verify the docs build** before committing:
|
||||
- For vocs / docusaurus / mkdocs sites, run the local build command (e.g. `npx vocs build`, `mkdocs build`) and fix any MDX/markdown errors
|
||||
- For godoc, run `go vet ./...` and `go doc <pkg> <Symbol>` to sanity-check rendering
|
||||
|
||||
7. **Report**:
|
||||
- List every file changed and every file deliberately left alone (with a one-line reason)
|
||||
- Suggest the next step (typically `/commit-push`) — do not auto-commit unless asked
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff before writing anything — invented API names erode trust faster than missing docs
|
||||
- One change per doc commit; keep doc updates separate from code changes when possible
|
||||
- Match the existing voice and formatting of each surface (headings, code-fence languages, table styles)
|
||||
- Prefer linking between pages over duplicating content
|
||||
|
||||
$@
|
||||
@@ -29,7 +29,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
- **Session Management**: Tree-based conversation history with branching support
|
||||
- **Non-Interactive Mode**: Script-friendly positional args with JSON output
|
||||
- **ACP Server**: Run Kit as an [Agent Client Protocol](https://agentclientprotocol.com) agent over stdio
|
||||
- **Go SDK**: Embed Kit in your own applications
|
||||
- **Go SDK**: Embed Kit in your own applications with full agent lifecycle events (30+ event types) and behavior-modifying hooks
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -162,6 +162,11 @@ mcpServers:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers that don't require auth
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
url: "https://builds.mcp.example.com"
|
||||
tasksMode: always # async task execution — see MCP Tasks below
|
||||
```
|
||||
|
||||
## CLI Reference
|
||||
@@ -626,6 +631,36 @@ in a custom `MCPTokenStoreFactory` for encrypted, DB-backed, or in-memory
|
||||
storage. See the [SDK options docs](/sdk/options#mcp-oauth-authorization) for
|
||||
the full matrix.
|
||||
|
||||
### MCP Tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`, so cooperating MCP servers can respond to `tools/call`
|
||||
with a `taskId` instead of blocking the connection. Kit then polls
|
||||
`tasks/get` / `tasks/result` until the task reaches a terminal state, and
|
||||
best-effort `tasks/cancel`s on context cancellation.
|
||||
|
||||
Defaults are safe — a server that doesn't advertise task capability runs
|
||||
synchronously, exactly as before. Opt in per server via `tasksMode` in
|
||||
`.kit.yml` (`auto` | `never` | `always`) or programmatically through the SDK:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskTimeout: 15 * time.Minute,
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s: %s", p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
```
|
||||
|
||||
See the [configuration docs](/configuration#mcp-tasks-long-running-tools) and
|
||||
[SDK options → MCP Tasks](/sdk/options#mcp-tasks) for the full surface.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
@@ -646,7 +681,28 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. Binary data (images, audio, etc.) in `ToolOutput.Data` is automatically forwarded to the LLM when `MediaType` is set. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
#### Return Helpers
|
||||
|
||||
| Helper | Description |
|
||||
| --- | --- |
|
||||
| `kit.TextResult(content)` | Successful text result |
|
||||
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
|
||||
| `kit.ImageResult(content, data, mediaType)` | Image result with binary data (e.g. `"image/png"`) |
|
||||
| `kit.MediaResult(content, data, mediaType)` | Non-image media result (e.g. `"audio/mpeg"`) |
|
||||
|
||||
#### ToolOutput Fields
|
||||
|
||||
```go
|
||||
kit.ToolOutput{
|
||||
Content: "result text", // text returned to the LLM
|
||||
IsError: false, // true = LLM sees this as an error
|
||||
Data: pngBytes, // optional binary data (images, audio)
|
||||
MediaType: "image/png", // MIME type for binary Data
|
||||
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
|
||||
}
|
||||
```
|
||||
|
||||
### With Callbacks
|
||||
|
||||
@@ -663,7 +719,7 @@ unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
unsub3 := host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
@@ -0,0 +1,473 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// extensionContextDeps groups the runtime dependencies needed to wire up
|
||||
// an extensions.Context for the interactive TUI mode.
|
||||
type extensionContextDeps struct {
|
||||
ctx context.Context
|
||||
cwd string
|
||||
modelName string
|
||||
interactive bool
|
||||
kitInstance *kit.Kit
|
||||
appInstance *app.App
|
||||
usageTracker *ui.UsageTracker
|
||||
}
|
||||
|
||||
// buildInteractiveExtensionContext returns an extensions.Context with every
|
||||
// field except Print / PrintInfo / PrintError populated. Callers must set
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
+31
-793
@@ -169,12 +169,6 @@ func InitConfig() {
|
||||
models.ReloadGlobalRegistry()
|
||||
}
|
||||
|
||||
// LoadConfigWithEnvSubstitution loads a config file with environment variable
|
||||
// substitution. Delegates to the SDK implementation.
|
||||
func LoadConfigWithEnvSubstitution(configPath string) error {
|
||||
return kit.LoadConfigWithEnvSubstitution(configPath)
|
||||
}
|
||||
|
||||
// adaptiveOrDefault converts a config.AdaptiveColor to a resolved color.Color,
|
||||
// falling back to fallback when both Light and Dark are empty.
|
||||
func adaptiveOrDefault(ac config.AdaptiveColor, fallback color.Color) color.Color {
|
||||
@@ -850,759 +844,42 @@ func runNormalMode(ctx context.Context) error {
|
||||
// Set up extension context and emit SessionStart.
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
cwd, _ := os.Getwd()
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintInfo: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintError: func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
},
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
// Find skill by name
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
// Inject via SendMessage as a system context message
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: kitInstance.DiscoverSkillsForExtension,
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
extCtx := buildInteractiveExtensionContext(extensionContextDeps{
|
||||
ctx: ctx,
|
||||
cwd: cwd,
|
||||
modelName: modelName,
|
||||
interactive: positionalPrompt == "",
|
||||
kitInstance: kitInstance,
|
||||
appInstance: appInstance,
|
||||
usageTracker: usageTracker,
|
||||
})
|
||||
extCtx.Print = func(text string) {
|
||||
// Capture messages during startup, print after startup banner.
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
}
|
||||
extCtx.PrintInfo = func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
}
|
||||
extCtx.PrintError = func(text string) {
|
||||
startupExtensionMessages = append(startupExtensionMessages, text)
|
||||
}
|
||||
kitInstance.Extensions().SetContext(extCtx)
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
|
||||
// Restore normal print functions for runtime use.
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
CWD: cwd,
|
||||
Model: modelName,
|
||||
Interactive: positionalPrompt == "",
|
||||
Print: func(text string) { appInstance.PrintFromExtension("", text) },
|
||||
PrintInfo: func(text string) { appInstance.PrintFromExtension("info", text) },
|
||||
PrintError: func(text string) { appInstance.PrintFromExtension("error", text) },
|
||||
PrintBlock: appInstance.PrintBlockFromExtension,
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
// In-process subagent via SDK.
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
// Bridge SDK events to extension SubagentEvents.
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: kitInstance.GetChildren,
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ParseTemplate: kit.ParseTemplate,
|
||||
RenderTemplate: kit.RenderTemplate,
|
||||
ParseArguments: kit.ParseArguments,
|
||||
SimpleParseArguments: kit.SimpleParseArguments,
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge) - Second Context
|
||||
// -------------------------------------------------------------------------
|
||||
ResolveModelChain: kit.ResolveModelChain,
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: kit.CheckModelAvailable,
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
extCtx = buildInteractiveExtensionContext(extensionContextDeps{
|
||||
ctx: ctx,
|
||||
cwd: cwd,
|
||||
modelName: modelName,
|
||||
interactive: positionalPrompt == "",
|
||||
kitInstance: kitInstance,
|
||||
appInstance: appInstance,
|
||||
usageTracker: usageTracker,
|
||||
})
|
||||
extCtx.Print = func(text string) { appInstance.PrintFromExtension("", text) }
|
||||
extCtx.PrintInfo = func(text string) { appInstance.PrintFromExtension("info", text) }
|
||||
extCtx.PrintError = func(text string) { appInstance.PrintFromExtension("error", text) }
|
||||
kitInstance.Extensions().SetContext(extCtx)
|
||||
}
|
||||
|
||||
// Convert extension commands to UI-layer type for the interactive TUI.
|
||||
@@ -2184,42 +1461,3 @@ func runInteractiveModeBubbleTea(_ context.Context, appInstance *app.App, modelN
|
||||
_, runErr := program.Run()
|
||||
return runErr
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension-facing
|
||||
// SubagentEvent. Returns a zero-value event (Type=="") for events that
|
||||
// don't map to anything useful.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func main() {
|
||||
}
|
||||
})
|
||||
// Subscribe to streaming chunks.
|
||||
host3.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
host3.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
|
||||
|
||||
@@ -5,24 +5,24 @@ go 1.26.2
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.6
|
||||
charm.land/fantasy v0.19.0
|
||||
charm.land/fantasy v0.23.0
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/alecthomas/chroma/v2 v2.24.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.12.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/coder/acp-go-sdk v0.12.2
|
||||
github.com/fsnotify/fsnotify v1.10.1
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.49.0
|
||||
github.com/mark3labs/mcp-go v0.51.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
@@ -37,21 +37,21 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0 // indirect
|
||||
github.com/aws/smithy-go v1.25.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
||||
github.com/aws/smithy-go v1.25.1 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -59,9 +59,9 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -69,30 +69,31 @@ require (
|
||||
github.com/dlclark/regexp2 v1.12.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.2 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.18 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.8 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.5.2 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.7 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.13 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
@@ -115,10 +116,10 @@ require (
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.276.0 // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/api v0.277.0 // indirect
|
||||
google.golang.org/genai v1.55.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect
|
||||
google.golang.org/grpc v1.81.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -129,7 +130,7 @@ require (
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
|
||||
@@ -2,8 +2,8 @@ charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
|
||||
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
|
||||
charm.land/fantasy v0.19.0 h1:fnNXkIJ/xcIW3sdVtWxjtQGpWWe8pDGhBCWSHkgbrd0=
|
||||
charm.land/fantasy v0.19.0/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/fantasy v0.23.0 h1:pocjwC5CxfEg1Bpwb0raML2d5ijo3op33Mmd6hYJyo4=
|
||||
charm.land/fantasy v0.23.0/go.mod h1:4yzSsd9XmFEVjRnF1P0LTEbLTmQX6OLnPkrHaf7iruo=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
@@ -28,42 +28,42 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1 h1:m5ffpfZbIb++k8AqFEKy9uVgY12xIQtBsQlc6DfZJQM=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1/go.mod h1:l+ohZ9xRXIbGe7cIW+YZgOGbvuVLjMps/FYN/CwuabI=
|
||||
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 h1:adBsCIIpLbLmYnkQU+nAChU5yhVTvu5PerROm+/Kq2A=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9/go.mod h1:uOYhgfgThm/ZyAuJGNQ5YgNyOlYfqnGpTHXvk3cpykg=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16 h1:Q0iQ7quUgJP0F/SCRTieScnaMdXr9h/2+wze1u3cNeM=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16/go.mod h1:duCCnJEFqpt2RC6no1iK6q+8HpwOAkiUua0pY507dQc=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15 h1:fyvgWTszojq8hEnMi8PPBTvZdTtEVmAVyo+NFLHBhH4=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15/go.mod h1:gJiYyMOjNg8OEdRWOf3CrFQxM2a98qmrtjx1zuiQfB8=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22 h1:IOGsJ1xVWhsi+ZO7/NW8OuZZBtMJLZbk4P5HDjJO0jQ=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22/go.mod h1:b+hYdbU+jGKfXE8kKM6g1+h+L/Go3vMvzlxBsiuGsxg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 h1:FPXsW9+gMuIeKmz7j6ENWcWtBGTe1kH8r9thNt5Uxx4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23/go.mod h1:7J8iGMdRKk6lw2C+cMIphgAnT8uTwBwNOsGkyOCm80U=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 h1:HtOTYcbVcGABLOVuPYaIihj6IlkqubBwFj10K5fxRek=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8/go.mod h1:VsK9abqQeGlzPgUr+isNWzPlK2vKe9INMLWnY65f5Xs=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 h1:PUmZeJU6Y1Lbvt9WFuJ0ugUK2xn6hIWUBBbKuOWF30s=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22/go.mod h1:nO6egFBoAaoXze24a2C0NjQCvdpk8OueRoYimvEB9jo=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10 h1:a1Fq/KXn75wSzoJaPQTgZO0wHGqE9mjFnylnqEPTchA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10/go.mod h1:p6+MXNxW7IA6dMgHfTAzljuwSKD0NCm/4lbS4t6+7vI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16 h1:x6bKbmDhsgSZwv6q19wY/u3rLk/3FGjJWyqKcIRufpE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16/go.mod h1:CudnEVKRtLn0+3uMV0yEXZ+YZOKnAtUJ5DmDhilVnIw=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20 h1:oK/njaL8GtyEihkWMD4k3VgHCT64RQKkZwh0DG5j8ak=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20/go.mod h1:JHs8/y1f3zY7U5WcuzoJ/yAYGYtNIVPKLIbp61euvmg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0 h1:ks8KBcZPh3PYISr5dAiXCM5/Thcuxk8l+PG4+A0exds=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0/go.mod h1:pFw33T0WLvXU3rw1WBkpMlkgIn54eCB5FYLhjDc9Foo=
|
||||
github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U=
|
||||
github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
@@ -86,8 +86,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7 h1:PbFxahSfyADcQOp+7WxbeqN3wX37KA/Rk+EXOW1xS9Q=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be h1:j7w8VP/D4lu5+/4GamMmFy8nrtadcl82/fjvDgSHwLo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -98,14 +98,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5 h1:3ElWZRQqSRqML2P/r2TmuSkdXPMDI+Jg3f0bGA6Ekg4=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 h1:rByFKh9JgQScu7oy0+TlUbC2e93woW/QNZmNXbbbw/E=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5 h1:QqpW1CPNAnOpM3Nj0X7IT2IFlR90bLdAkO5+A3Hwbi4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 h1:PMjHdSo8Vpq9psUw9BoHo9JLPMkm9Hqb+Whk64n3AQQ=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -124,8 +124,8 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.12.0 h1:GoIC6RrkPMBIVQ3ckSkl+bO/ERV/IRK6clBdZmx4Uf4=
|
||||
github.com/coder/acp-go-sdk v0.12.0/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/coder/acp-go-sdk v0.12.2 h1:fpRJ8Z5HMSr5cZ5IywzFlFZcIxZOsto+laNVu7XelFA=
|
||||
github.com/coder/acp-go-sdk v0.12.2/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
@@ -146,10 +146,10 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 h1:2WmHkJINIjgXXYDGik8d3oJvFA3DAwPy00csDJ3vo+o=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
|
||||
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -167,8 +167,8 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -187,14 +187,14 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.2 h1:52gGOx4ZwbLEiOyDMNA1ax2WktKlrKsmV6Ydf9Tw3/I=
|
||||
github.com/kaptinlin/go-i18n v0.4.2/go.mod h1:IACLIi+sHn3pGyryFMiqr2N1CJry4OKFD0MAEneEVQk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.18 h1:EDUXT4WKpOKguU7oaFv6VaNatN7uHFe6dEYHX0+OFxs=
|
||||
github.com/kaptinlin/jsonpointer v0.4.18/go.mod h1:ndmfvrqrEDSbV3F7yGaOuDvr29WrxYU1aqkvef9L2do=
|
||||
github.com/kaptinlin/jsonschema v0.7.8 h1:aHv28bYtfLfUXYI/10Phb1nvVyLXNz1lmu73vtKmlOY=
|
||||
github.com/kaptinlin/jsonschema v0.7.8/go.mod h1:cz7SK0jTHdabKdQp+SwBKKmOeZ55txuNo72Jx9Sbb2w=
|
||||
github.com/kaptinlin/messageformat-go v0.5.2 h1:E+D5oQVRepHgyMiLWRHnPXYFbqBDI4Sek7/CTIAByj4=
|
||||
github.com/kaptinlin/messageformat-go v0.5.2/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/kaptinlin/go-i18n v0.4.7 h1:apjIIZHnGRyrkiX3vHj07F1BF6D0JLmV+VGSr1781Jc=
|
||||
github.com/kaptinlin/go-i18n v0.4.7/go.mod h1:+i1J0pFq/9i9ESC5qRMVkKwC+mdQTABhhBExpYOlbeM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 h1:WVkwQbeerbHFcoXG7Yo/mlQhhZjWiTnagECEfwDXXa0=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21/go.mod h1:Mo7+DX8RlQTFqS4dnYJl0izSP4ob+Rl5xO/mGDETgaU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 h1:m9ZE/fCjnsk8bdkv7Qs56L/ZoHbmQqhz9mRZSAQLU5g=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3/go.mod h1:2KOZ/hgo/SveZ+uyi7vPUpUXieX65Mppzbc3VpGyqKs=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -203,10 +203,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.49.0 h1:7Ssx4d7/T86qnWoJIdye7wEEvUzv39UIbnZb/FqUZMY=
|
||||
github.com/mark3labs/mcp-go v0.49.0/go.mod h1:BflTAZAzXlrTpiO44gmjMu89n2FO56rJ9m31fp4zd5k=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mark3labs/mcp-go v0.51.0 h1:e8AhEfxzcYt7XqYzwT7uzWNhnqpu3H1Tn7dEJB9Ygj8=
|
||||
github.com/mark3labs/mcp-go v0.51.0/go.mod h1:Zg9cB2HdwdMMVgY0xtTzq3KvYIOJQDsaut+jWjwDaQY=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
@@ -223,8 +223,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -238,6 +238,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -310,18 +312,18 @@ golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY=
|
||||
google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 h1:XF8+t6QQiS0o9ArVan/HW8Q7cycNPGsJf6GA2nXxYAg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/api v0.277.0 h1:HJfyJUiNeBBUMai7ez8u14wkp/gH/I4wpGbbO9o+cSk=
|
||||
google.golang.org/api v0.277.0/go.mod h1:B9TqLBwJqVjp1mtt7WeoQwWRwvu/400y5lETOql+giQ=
|
||||
google.golang.org/genai v1.55.0 h1:iLHGk4Bj/IZ/GNNZb7hYqwSJMRBvqLeu2Hb6YQ+rYGw=
|
||||
google.golang.org/genai v1.55.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4 h1:2iMJZntwvmfgtse+s744JY7v7PgEdSBuFYXucvpOHNM=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:v14kaaboYyXQ1Gsu489Q+Hg/oN4B33mWtuOhF1HCeXA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw=
|
||||
google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -185,6 +185,26 @@ func (a *Agent) ListSessions(_ context.Context, _ acp.ListSessionsRequest) (acp.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CloseSession cancels any ongoing work for the session and frees its resources.
|
||||
func (a *Agent) CloseSession(_ context.Context, params acp.CloseSessionRequest) (acp.CloseSessionResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
log.Debug("acp: close session", "session", sessionID)
|
||||
sess.cancelPrompt()
|
||||
a.registry.remove(sessionID)
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
// ResumeSession is not supported — Kit doesn't persist sessions across
|
||||
// restarts in ACP mode. Clients should use NewSession instead.
|
||||
func (a *Agent) ResumeSession(_ context.Context, _ acp.ResumeSessionRequest) (acp.ResumeSessionResponse, error) {
|
||||
return acp.ResumeSessionResponse{}, fmt.Errorf("resume session not supported")
|
||||
}
|
||||
|
||||
// SetSessionConfigOption handles session configuration changes. Currently
|
||||
// supports the "model" config option to change the active model for a session.
|
||||
func (a *Agent) SetSessionConfigOption(ctx context.Context, params acp.SetSessionConfigOptionRequest) (acp.SetSessionConfigOptionResponse, error) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -152,38 +153,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
@@ -232,6 +202,20 @@ func (r *sessionRegistry) closeAll() {
|
||||
}
|
||||
}
|
||||
|
||||
// remove closes and removes a single session by ID.
|
||||
func (r *sessionRegistry) remove(sessionID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
sess, ok := r.sessions[sessionID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if sess.kit != nil {
|
||||
_ = sess.kit.Close()
|
||||
}
|
||||
delete(r.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cancelPrompt cancels the current prompt for a session, if any.
|
||||
func (s *acpSession) cancelPrompt() {
|
||||
s.cancelMu.Lock()
|
||||
@@ -255,40 +239,3 @@ func (s *acpSession) clearCancel() {
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
+302
-70
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
@@ -58,6 +59,11 @@ type AgentConfig struct {
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour for any
|
||||
// server that didn't advertise task support during initialize.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -126,6 +132,76 @@ type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// StepStartHandler is called when a new LLM step begins within a turn.
|
||||
type StepStartHandler func(stepNumber int)
|
||||
|
||||
// StepFinishHandler is called when a step completes with full context.
|
||||
type StepFinishHandler func(stepNumber int, hasToolCalls bool, finishReason string, usage fantasy.Usage)
|
||||
|
||||
// TextStartHandler is called when the LLM begins generating text content.
|
||||
type TextStartHandler func(id string)
|
||||
|
||||
// TextEndHandler is called when the LLM finishes generating text content.
|
||||
type TextEndHandler func(id string)
|
||||
|
||||
// ReasoningStartHandler is called when the LLM begins reasoning/thinking.
|
||||
type ReasoningStartHandler func(id string)
|
||||
|
||||
// WarningsHandler is called when the LLM provider returns warnings.
|
||||
type WarningsHandler func(warnings []string)
|
||||
|
||||
// SourceHandler is called when the LLM references a source.
|
||||
type SourceHandler func(sourceType, id, url, title string)
|
||||
|
||||
// StreamFinishHandler is called when a per-step LLM stream completes.
|
||||
type StreamFinishHandler func(usage fantasy.Usage, finishReason string)
|
||||
|
||||
// ErrorHandler is called when an agent-level error occurs.
|
||||
type ErrorHandler func(err error)
|
||||
|
||||
// RetryHandler is called when the LLM request is retried.
|
||||
type RetryHandler func(attempt int, err error)
|
||||
|
||||
// PrepareStepHandler is called between steps to allow message modification.
|
||||
// It receives the step number and current messages, and returns replacement
|
||||
// messages (or nil to keep unchanged).
|
||||
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
|
||||
|
||||
// GenerateCallbacks consolidates all callback functions for
|
||||
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
|
||||
// 16+ positional callback parameters, making it easier to add new callbacks
|
||||
// without breaking existing callers (new fields default to nil).
|
||||
type GenerateCallbacks struct {
|
||||
OnToolCall ToolCallHandler
|
||||
OnToolExecution ToolExecutionHandler
|
||||
OnToolResult ToolResultHandler
|
||||
OnResponse ResponseHandler
|
||||
OnToolCallContent ToolCallContentHandler
|
||||
OnStreamingResponse StreamingResponseHandler
|
||||
OnReasoningDelta ReasoningDeltaHandler
|
||||
OnReasoningComplete ReasoningCompleteHandler
|
||||
OnToolOutput ToolOutputHandler
|
||||
OnStepMessages StepMessagesHandler
|
||||
OnStepUsage StepUsageHandler
|
||||
OnPasswordPrompt PasswordPromptHandler
|
||||
OnToolCallStart ToolCallStartHandler
|
||||
OnToolCallDelta ToolCallDeltaHandler
|
||||
OnToolCallEnd ToolCallEndHandler
|
||||
|
||||
// New callbacks for previously unwired Fantasy lifecycle events.
|
||||
OnStepStart StepStartHandler
|
||||
OnStepFinish StepFinishHandler
|
||||
OnTextStart TextStartHandler
|
||||
OnTextEnd TextEndHandler
|
||||
OnReasoningStart ReasoningStartHandler
|
||||
OnWarnings WarningsHandler
|
||||
OnSource SourceHandler
|
||||
OnStreamFinish StreamFinishHandler
|
||||
OnError ErrorHandler
|
||||
OnRetry RetryHandler
|
||||
OnPrepareStep PrepareStepHandler
|
||||
}
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
@@ -160,6 +236,10 @@ type Agent struct {
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpTaskConfig is stored from AgentConfig so AddMCPServer() can
|
||||
// propagate it to a lazily-created MCPToolManager.
|
||||
mcpTaskConfig tools.MCPTaskConfig
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
@@ -258,6 +338,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
mcpTaskConfig: agentConfig.MCPTaskConfig,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
@@ -277,6 +358,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
// Apply task-augmented tool execution config (zero value = no-op).
|
||||
toolManager.SetTaskConfig(agentConfig.MCPTaskConfig)
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
@@ -423,13 +506,20 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
// The agent handles the tool call loop internally.
|
||||
//
|
||||
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
|
||||
// struct and is easier to extend with new callbacks.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
@@ -444,6 +534,31 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onToolCallDelta ToolCallDeltaHandler,
|
||||
onToolCallEnd ToolCallEndHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
OnStreamingResponse: onStreamingResponse,
|
||||
OnReasoningDelta: onReasoningDelta,
|
||||
OnReasoningComplete: onReasoningComplete,
|
||||
OnToolOutput: onToolOutput,
|
||||
OnStepMessages: onStepMessages,
|
||||
OnStepUsage: onStepUsage,
|
||||
OnPasswordPrompt: onPasswordPrompt,
|
||||
OnToolCallStart: onToolCallStart,
|
||||
OnToolCallDelta: onToolCallDelta,
|
||||
OnToolCallEnd: onToolCallEnd,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Message,
|
||||
cb GenerateCallbacks,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Wait for background MCP tool loading to complete and rebuild the
|
||||
// fantasy agent with the full tool set. This is a no-op when no MCP
|
||||
@@ -451,13 +566,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
a.ensureMCPTools()
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
if cb.OnToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, cb.OnToolOutput)
|
||||
}
|
||||
|
||||
// Inject password prompt handler into context for use by bash tool.
|
||||
if onPasswordPrompt != nil {
|
||||
ctx = core.ContextWithPasswordPrompt(ctx, onPasswordPrompt)
|
||||
if cb.OnPasswordPrompt != nil {
|
||||
ctx = core.ContextWithPasswordPrompt(ctx, cb.OnPasswordPrompt)
|
||||
}
|
||||
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
@@ -477,9 +592,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
onToolCallContent != nil || onStreamingResponse != nil || onReasoningDelta != nil ||
|
||||
onToolCallStart != nil || onToolCallDelta != nil || onToolCallEnd != nil
|
||||
hasCallbacks := cb.OnToolCall != nil || cb.OnToolExecution != nil || cb.OnToolResult != nil ||
|
||||
cb.OnToolCallContent != nil || cb.OnStreamingResponse != nil || cb.OnReasoningDelta != nil ||
|
||||
cb.OnToolCallStart != nil || cb.OnToolCallDelta != nil || cb.OnToolCallEnd != nil ||
|
||||
cb.OnStepStart != nil || cb.OnStepFinish != nil || cb.OnTextStart != nil ||
|
||||
cb.OnTextEnd != nil || cb.OnReasoningStart != nil || cb.OnWarnings != nil ||
|
||||
cb.OnSource != nil || cb.OnStreamFinish != nil || cb.OnError != nil ||
|
||||
cb.OnRetry != nil || cb.OnPrepareStep != nil
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
@@ -488,9 +607,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// input) were persisted incrementally via cb.OnStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
// stepCounter tracks the current step number for StepStart/StepFinish events.
|
||||
var stepCounter int
|
||||
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
@@ -503,8 +624,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onToolCallStart != nil {
|
||||
onToolCallStart(id, toolName)
|
||||
if cb.OnToolCallStart != nil {
|
||||
cb.OnToolCallStart(id, toolName)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -512,8 +633,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onToolCallDelta != nil {
|
||||
onToolCallDelta(id, delta)
|
||||
if cb.OnToolCallDelta != nil {
|
||||
cb.OnToolCallDelta(id, delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -521,8 +642,39 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onToolCallEnd != nil {
|
||||
onToolCallEnd(id)
|
||||
if cb.OnToolCallEnd != nil {
|
||||
cb.OnToolCallEnd(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text start/end callbacks
|
||||
OnTextStart: func(id string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnTextStart != nil {
|
||||
cb.OnTextStart(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnTextEnd: func(id string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnTextEnd != nil {
|
||||
cb.OnTextEnd(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning start callback
|
||||
OnReasoningStart: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnReasoningStart != nil {
|
||||
cb.OnReasoningStart(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -532,8 +684,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningDelta != nil {
|
||||
onReasoningDelta(delta)
|
||||
if cb.OnReasoningDelta != nil {
|
||||
cb.OnReasoningDelta(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -543,8 +695,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
if cb.OnReasoningComplete != nil {
|
||||
cb.OnReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -554,8 +706,64 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onStreamingResponse != nil {
|
||||
onStreamingResponse(text)
|
||||
if cb.OnStreamingResponse != nil {
|
||||
cb.OnStreamingResponse(text)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Warnings callback
|
||||
OnWarnings: func(warnings []fantasy.CallWarning) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnWarnings != nil {
|
||||
strs := make([]string, len(warnings))
|
||||
for i, w := range warnings {
|
||||
strs[i] = w.Message
|
||||
}
|
||||
cb.OnWarnings(strs)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Source callback
|
||||
OnSource: func(source fantasy.SourceContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnSource != nil {
|
||||
cb.OnSource(string(source.SourceType), source.ID, source.URL, source.Title)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Stream finish callback (per-step stream completion)
|
||||
OnStreamFinish: func(usage fantasy.Usage, finishReason fantasy.FinishReason, _ fantasy.ProviderMetadata) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnStreamFinish != nil {
|
||||
cb.OnStreamFinish(usage, string(finishReason))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Error callback
|
||||
OnError: func(err error) {
|
||||
if cb.OnError != nil {
|
||||
cb.OnError(err)
|
||||
}
|
||||
},
|
||||
|
||||
// Step start callback
|
||||
OnStepStart: func(stepNumber int) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
stepCounter = stepNumber
|
||||
if cb.OnStepStart != nil {
|
||||
cb.OnStepStart(stepNumber)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -568,13 +776,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
if onToolCall != nil {
|
||||
onToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
|
||||
if cb.OnToolCall != nil {
|
||||
cb.OnToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
|
||||
}
|
||||
|
||||
// Notify tool execution starting
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -586,14 +794,14 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return ctx.Err()
|
||||
}
|
||||
// Notify tool execution finished
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
}
|
||||
|
||||
if onToolResult != nil {
|
||||
if cb.OnToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
onToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -607,8 +815,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
if cb.OnStepMessages != nil && len(step.Messages) > 0 {
|
||||
cb.OnStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
@@ -618,65 +826,88 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// Check if step has text content alongside tool calls
|
||||
text := step.Content.Text()
|
||||
toolCalls := step.Content.ToolCalls()
|
||||
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
|
||||
onToolCallContent(text)
|
||||
if text != "" && len(toolCalls) > 0 && cb.OnToolCallContent != nil {
|
||||
cb.OnToolCallContent(text)
|
||||
}
|
||||
// Emit step usage for real-time cost tracking
|
||||
if onStepUsage != nil {
|
||||
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
if cb.OnStepUsage != nil {
|
||||
cb.OnStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
|
||||
}
|
||||
// Emit unified step finish event
|
||||
if cb.OnStepFinish != nil {
|
||||
cb.OnStepFinish(stepCounter, len(toolCalls) > 0, string(step.FinishReason), step.Usage)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
// Always wire up PrepareStep to handle both steering and the
|
||||
// OnPrepareStep hook. Steering drains its channel first, then
|
||||
// OnPrepareStep hooks run against the (possibly already steered)
|
||||
// messages.
|
||||
steerCh := steerChFromContext(ctx)
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
hasSteering := steerCh != nil
|
||||
hasPrepareStepHook := cb.OnPrepareStep != nil
|
||||
|
||||
if hasSteering || hasPrepareStepHook {
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
|
||||
// Phase 1: Drain steering channel (if present).
|
||||
if hasSteering {
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
done:
|
||||
if len(steered) > 0 {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Run OnPrepareStep hook (if registered).
|
||||
if hasPrepareStepHook {
|
||||
if replacement := cb.OnPrepareStep(opts.StepNumber, result.Messages); replacement != nil {
|
||||
result.Messages = replacement
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Wire OnRetry callback if provided.
|
||||
if cb.OnRetry != nil {
|
||||
streamCall.OnRetry = func(err *fantasy.ProviderError, _ time.Duration) {
|
||||
// Use the retry number from the error if available; Fantasy
|
||||
// doesn't pass a counter directly, so we approximate with a
|
||||
// counter incremented on each call.
|
||||
cb.OnRetry(0, err)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
@@ -702,8 +933,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// empty (e.g. reasoning-only responses) so the UI properly resets
|
||||
// the stream component and avoids duplicate content on the next
|
||||
// flush.
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
if cb.OnResponse != nil {
|
||||
cb.OnResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
r := convertAgentResult(result, messages)
|
||||
@@ -723,8 +954,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
// For non-streaming, fire the response callback so callers can reset
|
||||
// streaming state (see streaming path comment above).
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
if cb.OnResponse != nil {
|
||||
cb.OnResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
@@ -915,6 +1146,7 @@ func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPSer
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetTaskConfig(a.mcpTaskConfig)
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
|
||||
@@ -56,6 +56,8 @@ type AgentCreationOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -76,6 +78,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
+26
-7
@@ -923,7 +923,7 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
a.recordStepUsage(ev, stepUsageSeen, sendFn)
|
||||
case kit.PasswordPromptEvent:
|
||||
// Convert SDK PasswordPromptEvent to app PasswordPromptEvent
|
||||
// The TUI will handle this and send the response back
|
||||
@@ -1241,7 +1241,16 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
//
|
||||
// Both session totals (cost, token counts) and the context window fill level
|
||||
// are updated here so the status bar reflects progress after every LLM call,
|
||||
// not just at the end of the full turn. Context fill monotonically increases
|
||||
// across steps because each step re-sends the entire conversation plus any
|
||||
// new tool results, so the numbers only go up.
|
||||
//
|
||||
// sendFn is called with a UsageUpdatedEvent to trigger a TUI re-render so
|
||||
// the updated values are visible immediately.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool, sendFn func(tea.Msg)) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
@@ -1262,11 +1271,21 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
// Update context window fill from this step's usage. Each step sends
|
||||
// the full conversation to the LLM, so the reported token counts
|
||||
// represent the actual context utilization at that point.
|
||||
contextFill := int(ev.InputTokens) + int(ev.CacheReadTokens) + int(ev.CacheWriteTokens) + int(ev.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheWrite=%d + Output=%d)",
|
||||
contextFill, ev.InputTokens, ev.CacheReadTokens, ev.CacheWriteTokens, ev.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
// Notify the TUI so it re-renders the status bar with updated values.
|
||||
if sendFn != nil {
|
||||
sendFn(UsageUpdatedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
|
||||
@@ -534,9 +534,9 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
// recorded immediately for cost tracking. Context tokens are also updated so
|
||||
// the status bar reflects context fill after every LLM call in a multi-step
|
||||
// turn, not just at the end.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
@@ -547,7 +547,7 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
@@ -559,9 +559,13 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
// Context tokens should now be updated per-step (Input + CacheRead + CacheWrite + Output).
|
||||
if usage.contextCalls != 1 {
|
||||
t.Fatalf("expected 1 context token update from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
expectedContext := 120 + 45 + 5 + 2
|
||||
if usage.lastContextTokens != expectedContext {
|
||||
t.Fatalf("expected context tokens %d, got %d", expectedContext, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -210,6 +210,12 @@ type ModelChangedEvent struct {
|
||||
ModelName string
|
||||
}
|
||||
|
||||
// UsageUpdatedEvent is sent after each completed LLM step to notify the TUI
|
||||
// that token counts and costs have changed. The UsageTracker is updated
|
||||
// in-place before this event is sent; the TUI just needs to re-render to
|
||||
// reflect the new values in the status bar.
|
||||
type UsageUpdatedEvent struct{}
|
||||
|
||||
// WidgetUpdateEvent is sent when an extension adds, updates, or removes a
|
||||
// widget via ctx.SetWidget or ctx.RemoveWidget. The TUI re-reads widget state
|
||||
// from its WidgetProvider on the next render cycle.
|
||||
|
||||
@@ -38,6 +38,23 @@ type MCPServerConfig struct {
|
||||
// servers that don't support it.
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
|
||||
// TasksMode controls when this server's tools/call requests are augmented
|
||||
// with MCP task metadata (turning a synchronous call into an asynchronous,
|
||||
// pollable job — see https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks).
|
||||
//
|
||||
// Valid values:
|
||||
// - "" or "auto": (default) augment requests with task metadata only
|
||||
// when the server advertises tasks/toolCalls capability during initialize.
|
||||
// - "never": never augment — every tool call is synchronous, regardless
|
||||
// of server capability.
|
||||
// - "always": always augment, even when the server didn't advertise
|
||||
// task support. The server may still respond synchronously; this just
|
||||
// opts in unconditionally on the client side.
|
||||
//
|
||||
// In all modes, when the server returns a CreateTaskResult the client polls
|
||||
// tasks/get / tasks/result until the task reaches a terminal state.
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
|
||||
// InProcessServer holds a live *server.MCPServer for in-process transport.
|
||||
// When set (and Type is "inprocess"), the connection pool creates an
|
||||
// in-process client instead of spawning a subprocess or making HTTP calls.
|
||||
@@ -68,6 +85,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -80,6 +98,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Try new format first
|
||||
@@ -96,6 +115,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
s.NoOAuth = newConfig.NoOAuth
|
||||
s.TasksMode = newConfig.TasksMode
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,6 +136,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = legacyConfig.Headers
|
||||
s.AllowedTools = legacyConfig.AllowedTools
|
||||
s.ExcludedTools = legacyConfig.ExcludedTools
|
||||
s.TasksMode = legacyConfig.TasksMode
|
||||
|
||||
// Infer type from legacy format for better compatibility
|
||||
// Only set Type when it doesn't change existing transport behavior
|
||||
@@ -324,6 +345,17 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("server %s: allowedTools and excludedTools are mutually exclusive", serverName)
|
||||
}
|
||||
|
||||
// Reject unknown tasksMode values up front so a typo (e.g. "alwasy")
|
||||
// fails loud here instead of being silently downgraded to "auto" by
|
||||
// the runtime parser. Comparison is case-insensitive to match
|
||||
// tools.ParseTaskMode.
|
||||
switch strings.ToLower(strings.TrimSpace(serverConfig.TasksMode)) {
|
||||
case "", "auto", "never", "always":
|
||||
// ok
|
||||
default:
|
||||
return fmt.Errorf("server %s: invalid tasksMode %q (expected one of: auto, never, always)", serverName, serverConfig.TasksMode)
|
||||
}
|
||||
|
||||
transport := serverConfig.GetTransportType()
|
||||
switch transport {
|
||||
case "stdio":
|
||||
|
||||
@@ -627,3 +627,92 @@ func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_NewFormat(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://my-mcp-server.com",
|
||||
"tasksMode": "always"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "always" {
|
||||
t.Errorf("expected TasksMode 'always', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_LegacyFormat(t *testing.T) {
|
||||
// tasksMode also recognised in the legacy unmarshal path so users on
|
||||
// the older command/args shape can opt in without migrating.
|
||||
jsonData := `{
|
||||
"command": "npx",
|
||||
"args": ["@modelcontextprotocol/server-filesystem", "/path"],
|
||||
"tasksMode": "never"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "never" {
|
||||
t.Errorf("expected TasksMode 'never', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_DefaultEmpty(t *testing.T) {
|
||||
// When tasksMode is not set the field stays empty, which downstream
|
||||
// resolves to "auto" via tools.ParseTaskMode.
|
||||
jsonData := `{"type":"remote","url":"https://x.example"}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "" {
|
||||
t.Errorf("expected default TasksMode to be empty, got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_TasksMode(t *testing.T) {
|
||||
t.Run("empty is valid", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example"},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("empty TasksMode should validate, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("known values are valid", func(t *testing.T) {
|
||||
for _, mode := range []string{"auto", "never", "always", "AUTO", " always "} {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example", TasksMode: mode},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("TasksMode=%q should validate, got %v", mode, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("typo is rejected with a clear error", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"buildbot": {Type: "remote", URL: "https://x.example", TasksMode: "alwasy"},
|
||||
},
|
||||
}
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for invalid TasksMode")
|
||||
}
|
||||
// Error must mention the server name AND the bad value so the
|
||||
// user knows where to look.
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "buildbot") || !strings.Contains(msg, `"alwasy"`) {
|
||||
t.Errorf("error %q should mention both server name and bad value", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+6
-42
@@ -21,12 +21,9 @@ type Edit struct {
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
Path string `json:"path"`
|
||||
Edits []Edit `json:"edits"`
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
@@ -52,20 +49,12 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Description: "Edit a file by replacing exact text. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit (relative or absolute)",
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
@@ -85,7 +74,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
Required: []string{"path", "edits"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -163,36 +152,11 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
if len(args.Edits) == 0 {
|
||||
return nil, fmt.Errorf("edits array is required and must not be empty")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
|
||||
+62
-44
@@ -389,9 +389,11 @@ func TestExecuteEdit_ExactMatch(t *testing.T) {
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -426,9 +428,11 @@ func TestExecuteEdit_ExactMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
target := lines[49]
|
||||
replacement := "REPLACED_LINE_50"
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -470,9 +474,11 @@ func TestExecuteEdit_FuzzyMatch_TrailingWhitespace(t *testing.T) {
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -519,9 +525,11 @@ func TestExecuteEdit_FuzzyMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
search := strings.Repeat("x", 10) + "\n" + strings.Repeat("x", 10)
|
||||
// But this matches lines 1-2, 2-3, etc. — should fail due to ambiguity.
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -546,9 +554,11 @@ func TestExecuteEdit_MultipleMatches_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -575,9 +585,11 @@ func TestExecuteEdit_NoMatch_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -601,9 +613,11 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
writeFileOrFail(t, path, "line1\r\nline2\r\nline3\r\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -622,8 +636,10 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -636,9 +652,11 @@ func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_NonexistentFile(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -661,9 +679,11 @@ func TestExecuteEdit_DiffContainsHunkHeader(t *testing.T) {
|
||||
writeFileOrFail(t, path, strings.Join(lines, "\n")+"\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -684,9 +704,11 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
writeFileOrFail(t, path, "old content\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -905,18 +927,14 @@ func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
func TestExecuteEdit_EmptyEditsArray_Fails(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -924,10 +942,10 @@ func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
if !strings.Contains(resp.Content, "required") {
|
||||
t.Errorf("expected 'required' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// Package extbridge wires the public Kit SDK to the internal extensions
|
||||
// package. It exists so that cmd/ and internal/acpserver/ don't both
|
||||
// reimplement the same SDK→extension event/subagent conversions.
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
|
||||
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
|
||||
// (Type=="") for events that don't map to anything useful — callers should
|
||||
// drop those.
|
||||
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
|
||||
// the result/events back into the extension-facing types. The returned
|
||||
// handle is always nil — the SDK path runs synchronously and does not
|
||||
// expose a separate process handle. Callers that need non-blocking
|
||||
// behaviour should run this in their own goroutine.
|
||||
//
|
||||
// This function consolidates the previously-duplicated wiring in
|
||||
// cmd/root.go (interactive + runtime contexts) and
|
||||
// internal/acpserver/session.go.
|
||||
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: cfg.Prompt,
|
||||
Model: cfg.Model,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := SDKEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
cfg.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := k.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
}
|
||||
@@ -1094,6 +1094,14 @@ type API struct {
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
onStepStart func(func(StepStartEvent, Context))
|
||||
onStepFinish func(func(StepFinishEvent, Context))
|
||||
onReasoningStart func(func(ReasoningStartEvent, Context))
|
||||
onWarnings func(func(WarningsEvent, Context))
|
||||
onSource func(func(SourceEvent, Context))
|
||||
onError func(func(ErrorEvent, Context))
|
||||
onRetry func(func(RetryEvent, Context))
|
||||
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -1301,6 +1309,56 @@ func (a *API) OnBeforeCompact(handler func(BeforeCompactEvent, Context) *BeforeC
|
||||
a.onBeforeCompact(handler)
|
||||
}
|
||||
|
||||
// OnStepStart registers a handler that fires when a new LLM call begins
|
||||
// within a multi-step agent turn.
|
||||
func (a *API) OnStepStart(handler func(StepStartEvent, Context)) {
|
||||
a.onStepStart(handler)
|
||||
}
|
||||
|
||||
// OnStepFinish registers a handler that fires when a step completes,
|
||||
// providing step number, finish reason, and decomposed token usage.
|
||||
func (a *API) OnStepFinish(handler func(StepFinishEvent, Context)) {
|
||||
a.onStepFinish(handler)
|
||||
}
|
||||
|
||||
// OnReasoningStart registers a handler that fires when the LLM begins
|
||||
// reasoning/thinking.
|
||||
func (a *API) OnReasoningStart(handler func(ReasoningStartEvent, Context)) {
|
||||
a.onReasoningStart(handler)
|
||||
}
|
||||
|
||||
// OnWarnings registers a handler that fires when the LLM provider returns
|
||||
// warnings about the request.
|
||||
func (a *API) OnWarnings(handler func(WarningsEvent, Context)) {
|
||||
a.onWarnings(handler)
|
||||
}
|
||||
|
||||
// OnSource registers a handler that fires when the LLM references a source
|
||||
// (e.g. from web search tools).
|
||||
func (a *API) OnSource(handler func(SourceEvent, Context)) {
|
||||
a.onSource(handler)
|
||||
}
|
||||
|
||||
// OnError registers a handler that fires when an agent-level error occurs
|
||||
// during streaming.
|
||||
func (a *API) OnError(handler func(ErrorEvent, Context)) {
|
||||
a.onError(handler)
|
||||
}
|
||||
|
||||
// OnRetry registers a handler that fires when the LLM provider request is
|
||||
// retried after a transient error.
|
||||
func (a *API) OnRetry(handler func(RetryEvent, Context)) {
|
||||
a.onRetry(handler)
|
||||
}
|
||||
|
||||
// OnPrepareStep registers a handler that fires between steps within a
|
||||
// multi-step agent turn, after steering messages are injected and before
|
||||
// messages are sent to the LLM. Return a non-nil PrepareStepResult with
|
||||
// Messages to replace the context window for this step.
|
||||
func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStepResult) {
|
||||
a.onPrepareStep(handler)
|
||||
}
|
||||
|
||||
// RegisterToolRenderer registers a custom renderer for a specific tool's
|
||||
// display in the TUI. The renderer controls the header (parameter summary)
|
||||
// and/or body (result display) of the tool's output block. If multiple
|
||||
@@ -2253,6 +2311,98 @@ type SubagentEndEvent struct {
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Step lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// StepStartEvent fires when a new LLM call begins within a multi-step agent turn.
|
||||
type StepStartEvent struct {
|
||||
StepNumber int
|
||||
}
|
||||
|
||||
func (e StepStartEvent) Type() EventType { return StepStart }
|
||||
|
||||
// StepFinishEvent fires when a step completes, providing step metadata and
|
||||
// token usage. Usage fields are plain int64 (not LLMUsage) because Yaegi
|
||||
// cannot handle fantasy types across the interpreter boundary.
|
||||
type StepFinishEvent struct {
|
||||
StepNumber int
|
||||
HasToolCalls bool
|
||||
FinishReason string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CacheReadTokens int64
|
||||
CacheWriteTokens int64
|
||||
}
|
||||
|
||||
func (e StepFinishEvent) Type() EventType { return StepFinish }
|
||||
|
||||
// ReasoningStartEvent fires when the LLM begins reasoning/thinking.
|
||||
type ReasoningStartEvent struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func (e ReasoningStartEvent) Type() EventType { return ReasoningStart }
|
||||
|
||||
// WarningsEvent fires when the LLM provider returns warnings about the request.
|
||||
type WarningsEvent struct {
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
func (e WarningsEvent) Type() EventType { return Warnings }
|
||||
|
||||
// SourceEvent fires when the LLM references a source (e.g. from web search).
|
||||
type SourceEvent struct {
|
||||
SourceType string
|
||||
ID string
|
||||
URL string
|
||||
Title string
|
||||
}
|
||||
|
||||
func (e SourceEvent) Type() EventType { return Source }
|
||||
|
||||
// ErrorEvent fires when an agent-level error occurs during streaming.
|
||||
// Uses string instead of error because Yaegi cannot handle the error
|
||||
// interface reliably across the interpreter boundary.
|
||||
type ErrorEvent struct {
|
||||
Error string
|
||||
}
|
||||
|
||||
func (e ErrorEvent) Type() EventType { return Error }
|
||||
|
||||
// RetryEvent fires when the LLM provider request is retried after a
|
||||
// transient error.
|
||||
type RetryEvent struct {
|
||||
Attempt int
|
||||
Error string
|
||||
}
|
||||
|
||||
func (e RetryEvent) Type() EventType { return Retry }
|
||||
|
||||
// PrepareStepEvent fires between steps within a multi-step agent turn,
|
||||
// after steering messages are injected and before messages are sent to
|
||||
// the LLM. Handlers can inspect and replace the context window.
|
||||
type PrepareStepEvent struct {
|
||||
// StepNumber is the zero-based step index within the current turn.
|
||||
StepNumber int
|
||||
// Messages is the current context window that will be sent to the LLM.
|
||||
Messages []ContextMessage
|
||||
}
|
||||
|
||||
func (e PrepareStepEvent) Type() EventType { return PrepareStep }
|
||||
|
||||
// PrepareStepResult allows extensions to replace the context window between
|
||||
// steps. Return nil Messages to leave the context unchanged.
|
||||
type PrepareStepResult struct {
|
||||
// Messages replaces the entire context window for this step. If nil,
|
||||
// the original messages are used unchanged. Messages with a non-negative
|
||||
// Index reuse the original message at that position; messages with
|
||||
// Index < 0 are created fresh from Role + Content.
|
||||
Messages []ContextMessage
|
||||
}
|
||||
|
||||
func (PrepareStepResult) isResult() {}
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -96,6 +96,35 @@ const (
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
|
||||
// StepStart fires when a new LLM call begins within a multi-step
|
||||
// agent turn.
|
||||
StepStart EventType = "step_start"
|
||||
|
||||
// StepFinish fires when a step completes, providing step number,
|
||||
// finish reason, and token usage.
|
||||
StepFinish EventType = "step_finish"
|
||||
|
||||
// ReasoningStart fires when the LLM begins reasoning/thinking.
|
||||
ReasoningStart EventType = "reasoning_start"
|
||||
|
||||
// Warnings fires when the LLM provider returns warnings.
|
||||
Warnings EventType = "warnings"
|
||||
|
||||
// Source fires when the LLM references a source (e.g. web search).
|
||||
Source EventType = "source"
|
||||
|
||||
// Error fires when an agent-level error occurs during streaming.
|
||||
Error EventType = "error"
|
||||
|
||||
// Retry fires when the LLM provider request is retried after a
|
||||
// transient error.
|
||||
Retry EventType = "retry"
|
||||
|
||||
// PrepareStep fires between steps within a multi-step agent turn,
|
||||
// after steering messages are injected and before messages are sent
|
||||
// to the LLM. Handlers can replace the context window for this step.
|
||||
PrepareStep EventType = "prepare_step"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -109,6 +138,8 @@ func AllEventTypes() []EventType {
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
|
||||
PrepareStep,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 24 {
|
||||
t.Fatalf("expected 24 event types, got %d", len(all))
|
||||
if len(all) != 32 {
|
||||
t.Fatalf("expected 32 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
|
||||
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
|
||||
// saveManifest round-trip against a temp directory, ensuring an absent
|
||||
// manifest loads as empty and a saved manifest reads back identically.
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
err = installer.saveManifest(manifest, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
t.Fatalf("saveManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was written to expected path
|
||||
if _, err := os.Stat(manifestPath); err != nil {
|
||||
t.Fatalf("manifest file not created: %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
loaded, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
|
||||
// followed by removeFromManifest leaves the manifest in its original
|
||||
// (empty) state, using a temp-directory installer scope.
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
|
||||
t.Fatalf("addToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
|
||||
t.Fatalf("removeFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
manifest, err = installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindInManifest writes a manifest file directly to the path
|
||||
// resolved by the package-level manifestPathForScope helper and then
|
||||
// confirms FindInManifest locates the entry by identity (and returns
|
||||
// nil for a non-existent identity).
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
t.Setenv("XDG_DATA_HOME", tempDir)
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
// Write a manifest entry directly via the package-level path resolver
|
||||
// so FindInManifest (which uses manifestPathForScope) can read it back.
|
||||
manifestPath := manifestPathForScope(ScopeGlobal)
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
|
||||
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
|
||||
@@ -618,6 +618,57 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onStepStart: func(h func(StepStartEvent, Context)) {
|
||||
reg(StepStart, func(e Event, c Context) Result {
|
||||
h(e.(StepStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onStepFinish: func(h func(StepFinishEvent, Context)) {
|
||||
reg(StepFinish, func(e Event, c Context) Result {
|
||||
h(e.(StepFinishEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onReasoningStart: func(h func(ReasoningStartEvent, Context)) {
|
||||
reg(ReasoningStart, func(e Event, c Context) Result {
|
||||
h(e.(ReasoningStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onWarnings: func(h func(WarningsEvent, Context)) {
|
||||
reg(Warnings, func(e Event, c Context) Result {
|
||||
h(e.(WarningsEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSource: func(h func(SourceEvent, Context)) {
|
||||
reg(Source, func(e Event, c Context) Result {
|
||||
h(e.(SourceEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onError: func(h func(ErrorEvent, Context)) {
|
||||
reg(Error, func(e Event, c Context) Result {
|
||||
h(e.(ErrorEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onRetry: func(h func(RetryEvent, Context)) {
|
||||
reg(Retry, func(e Event, c Context) Result {
|
||||
h(e.(RetryEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onPrepareStep: func(h func(PrepareStepEvent, Context) *PrepareStepResult) {
|
||||
reg(PrepareStep, func(e Event, c Context) Result {
|
||||
r := h(e.(PrepareStepEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
|
||||
@@ -2,22 +2,15 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
|
||||
@@ -172,6 +172,17 @@ func Symbols() interp.Exports {
|
||||
"SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)),
|
||||
"SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)),
|
||||
"ModelChangeEvent": reflect.ValueOf((*ModelChangeEvent)(nil)),
|
||||
|
||||
// Step lifecycle events
|
||||
"StepStartEvent": reflect.ValueOf((*StepStartEvent)(nil)),
|
||||
"StepFinishEvent": reflect.ValueOf((*StepFinishEvent)(nil)),
|
||||
"ReasoningStartEvent": reflect.ValueOf((*ReasoningStartEvent)(nil)),
|
||||
"WarningsEvent": reflect.ValueOf((*WarningsEvent)(nil)),
|
||||
"SourceEvent": reflect.ValueOf((*SourceEvent)(nil)),
|
||||
"ErrorEvent": reflect.ValueOf((*ErrorEvent)(nil)),
|
||||
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
|
||||
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
|
||||
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,8 +90,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
// 0. Check if tool is disabled via SetActiveTools.
|
||||
if w.runner.IsToolDisabled(toolName) {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)),
|
||||
fmt.Errorf("tool %q disabled by extension", toolName)
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
@@ -111,8 +110,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
if reason == "" {
|
||||
reason = "blocked by extension"
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
fmt.Errorf("tool blocked by extension: %s", reason)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +236,7 @@ func (t *extensionTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), err
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(result), nil
|
||||
}
|
||||
|
||||
@@ -142,8 +142,8 @@ func TestWrappedTool_BlockExecution(t *testing.T) {
|
||||
if toolRan {
|
||||
t.Error("tool should not have run after block")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("expected error from blocked tool")
|
||||
if err != nil {
|
||||
t.Error("expected nil error for blocked tool (error is conveyed via IsError response)")
|
||||
}
|
||||
if resp.IsError != true {
|
||||
t.Error("expected IsError=true from blocked response")
|
||||
@@ -234,8 +234,8 @@ func TestExtensionTool_Error(t *testing.T) {
|
||||
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
if err != nil {
|
||||
t.Error("expected nil error (error is conveyed via IsError response)")
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
|
||||
@@ -72,6 +72,9 @@ type AgentSetupOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -229,6 +232,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package models
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -111,13 +112,30 @@ func NewModelsRegistry() *ModelsRegistry {
|
||||
}
|
||||
|
||||
// buildFromModelsDB converts models.dev provider data into our internal format.
|
||||
// It tries the on-disk cache first and falls back to the embedded database.
|
||||
// It starts from the compile-time embedded database and merges on-disk cached
|
||||
// data from `kit update-models` on top. Cached provider metadata replaces
|
||||
// embedded metadata, and model entries are merged with cached models taking
|
||||
// precedence. This means newly synced models are available while embedded
|
||||
// models that haven't been synced yet are still reachable.
|
||||
func buildFromModelsDB() map[string]ProviderInfo {
|
||||
// Try cached data first (from `kit update-models`)
|
||||
dbProviders, _ := LoadCachedProviders()
|
||||
if len(dbProviders) == 0 {
|
||||
// Fall back to compile-time embedded data
|
||||
dbProviders = loadEmbeddedProviders()
|
||||
// Start with compile-time embedded data as the base.
|
||||
dbProviders := loadEmbeddedProviders()
|
||||
if dbProviders == nil {
|
||||
dbProviders = make(ModelsDBProviders)
|
||||
}
|
||||
|
||||
// Merge on-disk cached data on top (cached takes precedence).
|
||||
if cached, _ := LoadCachedProviders(); len(cached) > 0 {
|
||||
for providerID, cp := range cached {
|
||||
if existing, ok := dbProviders[providerID]; ok {
|
||||
// Merge models: embedded base + cached overrides.
|
||||
mergedModels := make(map[string]modelsDBModel, len(existing.Models)+len(cp.Models))
|
||||
maps.Copy(mergedModels, existing.Models)
|
||||
maps.Copy(mergedModels, cp.Models)
|
||||
cp.Models = mergedModels
|
||||
}
|
||||
dbProviders[providerID] = cp
|
||||
}
|
||||
}
|
||||
|
||||
providers := make(map[string]ProviderInfo, len(dbProviders))
|
||||
|
||||
@@ -179,31 +179,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
|
||||
// It returns the deduplicated list and diagnostics for any collisions.
|
||||
// This is a standalone function for when you need to deduplicate an existing list.
|
||||
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
|
||||
seen := make(map[string]*PromptTemplate)
|
||||
var result []*PromptTemplate
|
||||
var diagnostics []Diagnostic
|
||||
|
||||
for _, tpl := range templates {
|
||||
if existing, ok := seen[tpl.Name]; ok {
|
||||
diagnostics = append(diagnostics, Diagnostic{
|
||||
Name: tpl.Name,
|
||||
KeptPath: existing.FilePath,
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: "duplicate template name (first-match-wins)",
|
||||
})
|
||||
} else {
|
||||
seen[tpl.Name] = tpl
|
||||
result = append(result, tpl)
|
||||
}
|
||||
}
|
||||
|
||||
return result, diagnostics
|
||||
}
|
||||
|
||||
// loadDefaultTemplates returns the built-in default templates.
|
||||
// These are embedded templates that ship with Kit.
|
||||
func loadDefaultTemplates() []*PromptTemplate {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeCwdForDir verifies the working-directory → session-directory
|
||||
// name encoding strips characters that are illegal on Windows (notably the
|
||||
// drive-letter colon, see issue #18) while preserving the previous output
|
||||
// for the typical Unix paths.
|
||||
func TestEncodeCwdForDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "unix absolute path",
|
||||
cwd: "/home/user/proj",
|
||||
want: "home--user--proj",
|
||||
},
|
||||
{
|
||||
name: "unix relative path",
|
||||
cwd: "proj/sub",
|
||||
want: "proj--sub",
|
||||
},
|
||||
{
|
||||
name: "windows drive root",
|
||||
cwd: `C:\test`,
|
||||
want: "C--test",
|
||||
},
|
||||
{
|
||||
name: "windows nested path",
|
||||
cwd: `C:\Users\User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows secondary drive",
|
||||
cwd: `S:\work\repo`,
|
||||
want: "S--work--repo",
|
||||
},
|
||||
{
|
||||
name: "windows mixed separators",
|
||||
cwd: `C:\Users/User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows other illegal chars stripped",
|
||||
cwd: `C:\a<b>c|d?e*f"g`,
|
||||
want: "C--abcdefg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := encodeCwdForDir(tc.cwd)
|
||||
if got != tc.want {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q, want %q", tc.cwd, got, tc.want)
|
||||
}
|
||||
// Encoded directory must never contain characters that are
|
||||
// illegal in Windows directory names.
|
||||
for _, bad := range []string{":", "<", ">", "\"", "|", "?", "*", "\\", "/"} {
|
||||
if strings.Contains(got, bad) {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q contains illegal char %q", tc.cwd, got, bad)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -63,6 +63,11 @@ type TreeManager struct {
|
||||
|
||||
// file is the open file handle for appending entries. Nil for in-memory.
|
||||
file *os.File
|
||||
|
||||
// writer is a buffered writer wrapping file. Writes go through this
|
||||
// buffer and are flushed to disk at explicit sync points (after each
|
||||
// public Append* call, in Close, etc.) to reduce syscall overhead.
|
||||
writer *bufio.Writer
|
||||
}
|
||||
|
||||
// --- Constructors ---
|
||||
@@ -105,11 +110,16 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
|
||||
return nil, fmt.Errorf("failed to create session file: %w", err)
|
||||
}
|
||||
tm.file = f
|
||||
tm.writer = bufio.NewWriter(f)
|
||||
|
||||
if err := tm.writeEntry(&header); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to write session header: %w", err)
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to flush session header: %w", err)
|
||||
}
|
||||
|
||||
return tm, nil
|
||||
}
|
||||
@@ -150,6 +160,7 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
|
||||
return nil, fmt.Errorf("failed to recreate session file: %w", err)
|
||||
}
|
||||
newTm.file = f
|
||||
newTm.writer = bufio.NewWriter(f)
|
||||
|
||||
if err := newTm.writeEntry(&newTm.header); err != nil {
|
||||
_ = f.Close()
|
||||
@@ -289,6 +300,12 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all buffered writes from the fork in a single syscall.
|
||||
if err := newTm.flushLocked(); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to flush forked session: %w", err)
|
||||
}
|
||||
|
||||
// Set the leaf to the last entry in the new session.
|
||||
newTm.leafID = prevNewID
|
||||
|
||||
@@ -374,6 +391,7 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
return nil, fmt.Errorf("failed to open session file for append: %w", err)
|
||||
}
|
||||
tm.file = f
|
||||
tm.writer = bufio.NewWriter(f)
|
||||
|
||||
return tm, nil
|
||||
}
|
||||
@@ -427,6 +445,9 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush message: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -451,6 +472,9 @@ func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, erro
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush model change: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -465,6 +489,9 @@ func (tm *TreeManager) AppendBranchSummary(fromID, summary string) (string, erro
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush branch summary: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -479,6 +506,9 @@ func (tm *TreeManager) AppendLabel(targetID, label string) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush label: %w", err)
|
||||
}
|
||||
|
||||
tm.labels[targetID] = label
|
||||
tm.leafID = entry.ID
|
||||
@@ -494,6 +524,9 @@ func (tm *TreeManager) AppendSessionInfo(name string) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush session info: %w", err)
|
||||
}
|
||||
|
||||
tm.sessionName = name
|
||||
tm.leafID = entry.ID
|
||||
@@ -510,6 +543,9 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush extension data: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -541,6 +577,9 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush compaction: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -926,11 +965,31 @@ func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
// Flush writes any buffered data to the underlying file.
|
||||
func (tm *TreeManager) Flush() error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// flushLocked writes buffered data to disk. Caller must hold the lock.
|
||||
func (tm *TreeManager) flushLocked() error {
|
||||
if tm.writer != nil {
|
||||
return tm.writer.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close flushes any buffered writes and closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
if tm.file != nil {
|
||||
// Flush buffered data before closing.
|
||||
if tm.writer != nil {
|
||||
_ = tm.writer.Flush()
|
||||
tm.writer = nil
|
||||
}
|
||||
err := tm.file.Close()
|
||||
tm.file = nil
|
||||
return err
|
||||
@@ -1090,13 +1149,22 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
// All entries are buffered and flushed to disk in a single batch.
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
entry, err := NewMessageEntry(tm.leafID, message.FromLLMMessage(msg))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return err
|
||||
}
|
||||
tm.leafID = entry.ID
|
||||
}
|
||||
return nil
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
@@ -1148,12 +1216,20 @@ func (tm *TreeManager) appendAndPersist(entry any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeEntry serializes an entry and appends it as a line to the file.
|
||||
// writeEntry serializes an entry and appends it to the buffered writer.
|
||||
// The data is not flushed to disk until flushLocked is called.
|
||||
func (tm *TreeManager) writeEntry(entry any) error {
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal entry: %w", err)
|
||||
}
|
||||
if tm.writer != nil {
|
||||
if _, err := tm.writer.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return tm.writer.WriteByte('\n')
|
||||
}
|
||||
// Fallback for direct file writes (shouldn't happen in normal flow).
|
||||
data = append(data, '\n')
|
||||
_, err = tm.file.Write(data)
|
||||
return err
|
||||
@@ -1274,15 +1350,44 @@ func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[stri
|
||||
// --- Path conventions ---
|
||||
|
||||
// DefaultSessionDir returns the default session storage directory for a cwd.
|
||||
// Convention: ~/.kit/sessions/--<cwd-path>--/
|
||||
// Convention: ~/.kit/sessions/<encoded-cwd>, where path separators are
|
||||
// encoded as "--" with no leading or trailing dashes — e.g.
|
||||
// /home/user/proj becomes home--user--proj. See encodeCwdForDir for the
|
||||
// full encoding rules (including Windows path handling).
|
||||
func DefaultSessionDir(cwd string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
// Convert path separators to double dashes.
|
||||
safeCwd := strings.ReplaceAll(cwd, string(filepath.Separator), "--")
|
||||
return filepath.Join(home, ".kit", "sessions", encodeCwdForDir(cwd))
|
||||
}
|
||||
|
||||
// encodeCwdForDir converts a working-directory path into a single, filesystem-
|
||||
// safe directory name. Path separators are replaced with double dashes and
|
||||
// characters that are illegal in Windows directory names — most importantly
|
||||
// the colon that follows the drive letter (e.g. `C:\foo` → `C--foo`) — are
|
||||
// stripped. The result is identical to the previous Unix-only encoding for
|
||||
// paths that do not contain such characters, so existing session directories
|
||||
// are preserved.
|
||||
func encodeCwdForDir(cwd string) string {
|
||||
// Convert both `/` and `\` to double dashes so encoding is stable across
|
||||
// platforms and remains correct on Windows where `filepath.Separator`
|
||||
// would otherwise miss forward-slash style paths.
|
||||
safeCwd := strings.ReplaceAll(cwd, "\\", "--")
|
||||
safeCwd = strings.ReplaceAll(safeCwd, "/", "--")
|
||||
// Remove leading separator replacement.
|
||||
safeCwd = strings.TrimPrefix(safeCwd, "--")
|
||||
return filepath.Join(home, ".kit", "sessions", safeCwd)
|
||||
// Strip characters that are illegal in directory names on Windows
|
||||
// (`< > : " | ? *`). On Unix these characters are legal but rare in
|
||||
// practice; stripping them keeps the encoding portable.
|
||||
replacer := strings.NewReplacer(
|
||||
":", "",
|
||||
"<", "",
|
||||
">", "",
|
||||
"\"", "",
|
||||
"|", "",
|
||||
"?", "",
|
||||
"*", "",
|
||||
)
|
||||
return replacer.Replace(safeCwd)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type MCPConnection struct {
|
||||
client client.MCPClient
|
||||
serverName string
|
||||
serverConfig config.MCPServerConfig
|
||||
initResult *mcp.InitializeResult // captured at handshake; nil before initialize
|
||||
lastUsed time.Time
|
||||
isHealthy bool
|
||||
errorCount int
|
||||
@@ -262,7 +263,9 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
conn := &MCPConnection{}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if oauthEnabled && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
@@ -270,7 +273,7 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
@@ -280,15 +283,11 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
isHealthy: true,
|
||||
errorCount: 0,
|
||||
lastError: nil,
|
||||
}
|
||||
conn.client = mcpClient
|
||||
conn.serverName = serverName
|
||||
conn.serverConfig = serverConfig
|
||||
conn.lastUsed = time.Now()
|
||||
conn.isHealthy = true
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Created connection for %s", serverName))
|
||||
@@ -484,8 +483,10 @@ func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenS
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
// initializeClient initializes the client and captures the server's
|
||||
// initialize result on the supplied connection so callers can later
|
||||
// inspect advertised capabilities (e.g. task support).
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, c client.MCPClient, conn *MCPConnection) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -495,12 +496,21 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
Name: "kit",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
|
||||
// Advertise task support so servers may return CreateTaskResult for
|
||||
// long-running tools/call requests instead of blocking the connection
|
||||
// until completion. The client is responsible for polling tasks/get and
|
||||
// tasks/result until the task reaches a terminal state.
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{
|
||||
Tasks: mcp.NewTasksCapability(),
|
||||
}
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
initResult, err := c.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
if conn != nil {
|
||||
conn.initResult = initResult
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug("[POOL] Initialized MCP client")
|
||||
@@ -615,6 +625,54 @@ func (c *MCPConnection) ServerName() string {
|
||||
return c.serverName
|
||||
}
|
||||
|
||||
// InitializeResult returns the result captured from the server's initialize
|
||||
// response, or nil if the connection was created before initialize completed.
|
||||
// Callers can inspect ServerCapabilities.Tasks to discover task-related
|
||||
// capability advertisements.
|
||||
func (c *MCPConnection) InitializeResult() *mcp.InitializeResult {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.initResult
|
||||
}
|
||||
|
||||
// SupportsToolTasks reports whether the server advertised support for
|
||||
// task-augmented tools/call requests. Returns false when the connection has
|
||||
// not yet completed initialization or when the server omitted task
|
||||
// capabilities.
|
||||
func (c *MCPConnection) SupportsToolTasks() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return supportsToolTasksFromInit(c.initResult)
|
||||
}
|
||||
|
||||
// supportsToolTasksFromInit reports whether the supplied InitializeResult
|
||||
// advertises task-augmented tools/call support. Extracted to a free function
|
||||
// for unit testing without standing up a connection.
|
||||
func supportsToolTasksFromInit(init *mcp.InitializeResult) bool {
|
||||
if init == nil || init.Capabilities.Tasks == nil {
|
||||
return false
|
||||
}
|
||||
req := init.Capabilities.Tasks.Requests
|
||||
if req == nil || req.Tools == nil {
|
||||
return false
|
||||
}
|
||||
return req.Tools.Call != nil
|
||||
}
|
||||
|
||||
// ServerSupportsToolTasks reports whether the named server's connection
|
||||
// advertises task-augmented tools/call support. Returns false when no
|
||||
// connection exists for the server or when the server didn't advertise the
|
||||
// capability.
|
||||
func (p *MCPConnectionPool) ServerSupportsToolTasks(serverName string) bool {
|
||||
p.mu.RLock()
|
||||
conn, ok := p.connections[serverName]
|
||||
p.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return conn.SupportsToolTasks()
|
||||
}
|
||||
|
||||
// GetClients returns a map of all MCP clients currently in the pool.
|
||||
// The map keys are server names and values are the corresponding MCP client instances.
|
||||
// The returned map is a copy and modifications won't affect the pool.
|
||||
|
||||
+229
-27
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
log "github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
@@ -141,6 +143,11 @@ type MCPToolManager struct {
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
|
||||
// taskCfg controls task-augmented tools/call execution. The zero value
|
||||
// means: auto-detect server capability, no progress callback, default
|
||||
// poll/timeout.
|
||||
taskCfg MCPTaskConfig
|
||||
|
||||
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
||||
// Called with server name, tool count, and error (nil on success).
|
||||
onServerLoaded func(serverName string, toolCount int, err error)
|
||||
@@ -220,6 +227,21 @@ func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
||||
m.onToolsChanged = cb
|
||||
}
|
||||
|
||||
// SetTaskConfig sets the task-augmented tools/call configuration. Call
|
||||
// this before LoadTools / AddServer if you want the per-server mode
|
||||
// override and progress handler to take effect for the very first call.
|
||||
// Subsequent calls replace the previous configuration wholesale.
|
||||
func (m *MCPToolManager) SetTaskConfig(cfg MCPTaskConfig) {
|
||||
m.taskCfg = cfg
|
||||
}
|
||||
|
||||
// TaskConfig returns the manager's current task-augmented tools/call
|
||||
// configuration. The zero value means: defer to per-server config and
|
||||
// auto-detected capability, with no progress callback and default polling.
|
||||
func (m *MCPToolManager) TaskConfig() MCPTaskConfig {
|
||||
return m.taskCfg
|
||||
}
|
||||
|
||||
// AddServer connects to a new MCP server at runtime and loads its tools.
|
||||
// The server's tools are immediately available to the agent after this call.
|
||||
// Returns the number of tools loaded from the server.
|
||||
@@ -551,6 +573,14 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
// checks, OAuth re-authorization, and connection error tracking.
|
||||
// The inputJSON parameter is the raw JSON arguments from the LLM.
|
||||
// Returns the result content, error flag, and any execution error.
|
||||
//
|
||||
// When the per-server TasksMode resolves to "always", or to "auto" and the
|
||||
// server advertised tasks/toolCalls capability during initialize, the call
|
||||
// is augmented with TaskParams. If the server elects to respond with a
|
||||
// CreateTaskResult the manager polls tasks/get / tasks/result until the
|
||||
// task reaches a terminal state, transparently presenting the final
|
||||
// CallToolResult-equivalent content to the agent layer. Context
|
||||
// cancellation triggers a best-effort tasks/cancel.
|
||||
func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*MCPToolResult, error) {
|
||||
m.mu.Lock()
|
||||
mapping, ok := m.toolMap[prefixedName]
|
||||
@@ -582,49 +612,221 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
return nil, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
||||
}
|
||||
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
callParams := mcp.CallToolParams{
|
||||
Name: mapping.originalName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
// Call the MCP tool using the original (unprefixed) name
|
||||
result, err := conn.client.CallTool(ctx, callRequest)
|
||||
if err != nil {
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, err); flowErr != nil {
|
||||
// Decide whether to augment the request with TaskParams. Modes:
|
||||
// never — never augment (synchronous-only).
|
||||
// always — always augment, even without server capability.
|
||||
// auto — augment only when the server advertised tasks/toolCalls.
|
||||
mode := m.resolveTaskMode(mapping.serverName, mapping.serverConfig)
|
||||
useTask := mode == MCPTaskModeAlways ||
|
||||
(mode == MCPTaskModeAuto && conn.SupportsToolTasks())
|
||||
if useTask {
|
||||
var ttl *int64
|
||||
if m.taskCfg.DefaultTTL > 0 {
|
||||
ms := m.taskCfg.DefaultTTL.Milliseconds()
|
||||
ttl = &ms
|
||||
}
|
||||
callParams.Task = &mcp.TaskParams{TTL: ttl}
|
||||
}
|
||||
|
||||
// Synchronous fast path: no task augmentation. Use the upstream client
|
||||
// helper which keeps content-block typing identical to historical
|
||||
// behaviour.
|
||||
if !useTask {
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Task-augmented path. Bypass the upstream CallTool helper because its
|
||||
// ParseCallToolResult requires a "content" field that is absent from a
|
||||
// CreateTaskResult.
|
||||
rawClient, ok := conn.client.(*client.Client)
|
||||
if !ok {
|
||||
// Older client implementations — fall back to the synchronous shape.
|
||||
callParams.Task = nil
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
||||
}
|
||||
|
||||
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, callRequest)
|
||||
if err != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, err)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, err)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
marshaledResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
// Server chose to answer synchronously — same shape as the no-task path.
|
||||
if callResult != nil {
|
||||
marshaledResult, mErr := json.Marshal(callResult)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: callResult.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Asynchronous task path: poll until terminal, then return the result.
|
||||
if taskResult == nil {
|
||||
return nil, errors.New("mcp tools/call returned neither result nor task")
|
||||
}
|
||||
final, pollErr := pollTaskUntilTerminal(
|
||||
ctx, rawClient, mapping.serverName, taskResult.Task,
|
||||
m.taskCfg, m.taskCfg.Progress,
|
||||
)
|
||||
if pollErr != nil {
|
||||
return nil, fmt.Errorf("task execution failed: %w", pollErr)
|
||||
}
|
||||
|
||||
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
||||
adapted := &mcp.CallToolResult{
|
||||
Content: final.Content,
|
||||
StructuredContent: final.StructuredContent,
|
||||
IsError: final.IsError,
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(adapted)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
IsError: final.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveTaskMode resolves the effective task mode for a given server.
|
||||
// Programmatic overrides via SetTaskConfig take precedence over the
|
||||
// per-server TasksMode in MCPServerConfig. Empty / unknown values map to
|
||||
// MCPTaskModeAuto.
|
||||
func (m *MCPToolManager) resolveTaskMode(name string, cfg config.MCPServerConfig) MCPTaskMode {
|
||||
if m.taskCfg.PerServerMode != nil {
|
||||
if v, ok := m.taskCfg.PerServerMode[name]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ParseTaskMode(cfg.TasksMode)
|
||||
}
|
||||
|
||||
// ListServerTasks queries tasks/list on the named server and returns the
|
||||
// active and recent tasks the server is willing to disclose. Errors are
|
||||
// returned untouched (callers commonly ignore METHOD_NOT_FOUND when the
|
||||
// server didn't advertise tasks/list capability).
|
||||
func (m *MCPToolManager) ListServerTasks(ctx context.Context, serverName string) ([]MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := c.ListTasks(ctx, mcp.ListTasksRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tasks/list on %s: %w", serverName, err)
|
||||
}
|
||||
out := make([]MCPTaskInfo, 0, len(res.Tasks))
|
||||
for _, t := range res.Tasks {
|
||||
out = append(out, taskFromMCP(serverName, t))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetServerTask queries tasks/get for a single task on the named server.
|
||||
func (m *MCPToolManager) GetServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, err
|
||||
}
|
||||
res, err := c.GetTask(ctx, mcp.GetTaskRequest{Params: mcp.GetTaskParams{TaskId: taskID}})
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, fmt.Errorf("tasks/get on %s: %w", serverName, err)
|
||||
}
|
||||
return taskFromMCP(serverName, res.Task), nil
|
||||
}
|
||||
|
||||
// CancelServerTask issues tasks/cancel for a task on the named server.
|
||||
// Returns the post-cancel task state when the server responded with one.
|
||||
func (m *MCPToolManager) CancelServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, err
|
||||
}
|
||||
res, err := c.CancelTask(ctx, mcp.CancelTaskRequest{Params: mcp.CancelTaskParams{TaskId: taskID}})
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, fmt.Errorf("tasks/cancel on %s: %w", serverName, err)
|
||||
}
|
||||
return taskFromMCP(serverName, res.Task), nil
|
||||
}
|
||||
|
||||
// taskClient returns the *client.Client for a server. Tasks endpoints are
|
||||
// not part of the upstream MCPClient interface so callers must work with
|
||||
// the concrete client. Returns an error when the connection is missing
|
||||
// or backed by a non-standard client type.
|
||||
func (m *MCPToolManager) taskClient(serverName string) (*client.Client, error) {
|
||||
if m.connectionPool == nil {
|
||||
return nil, fmt.Errorf("no connection pool available")
|
||||
}
|
||||
clients := m.connectionPool.GetClients()
|
||||
raw, ok := clients[serverName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("MCP server %q not loaded", serverName)
|
||||
}
|
||||
c, ok := raw.(*client.Client)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("MCP server %q does not support task RPCs", serverName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetTools returns all loaded MCP tools from all configured MCP servers.
|
||||
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
||||
func (m *MCPToolManager) GetTools() []MCPTool {
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// MCPTaskMode controls when the connection pool augments tools/call requests
|
||||
// with MCP task metadata. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks.
|
||||
type MCPTaskMode string
|
||||
|
||||
const (
|
||||
// MCPTaskModeAuto augments tools/call with task metadata only when the
|
||||
// server advertises tasks/toolCalls capability during initialize.
|
||||
MCPTaskModeAuto MCPTaskMode = "auto"
|
||||
// MCPTaskModeNever forces every tools/call to be issued synchronously
|
||||
// (no Task field in the request), regardless of server capability.
|
||||
MCPTaskModeNever MCPTaskMode = "never"
|
||||
// MCPTaskModeAlways always sets a Task field on the tools/call request,
|
||||
// even when the server didn't advertise task support. The server may
|
||||
// still respond synchronously; this just opts in unconditionally on
|
||||
// the client side.
|
||||
MCPTaskModeAlways MCPTaskMode = "always"
|
||||
)
|
||||
|
||||
// ParseTaskMode normalises a per-server tasks-mode string from
|
||||
// configuration. Empty input maps to MCPTaskModeAuto. Unknown values are
|
||||
// also treated as MCPTaskModeAuto so a stray config typo never breaks
|
||||
// existing flows.
|
||||
func ParseTaskMode(s string) MCPTaskMode {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "auto":
|
||||
return MCPTaskModeAuto
|
||||
case "never", "off", "disabled":
|
||||
return MCPTaskModeNever
|
||||
case "always", "force":
|
||||
return MCPTaskModeAlways
|
||||
default:
|
||||
return MCPTaskModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// MCPTaskInfo is the connection-layer view of an MCP Task. It mirrors the
|
||||
// upstream mcp.Task but exposes Go-native types and includes the originating
|
||||
// server name. SDK-level wrappers re-export this under public-facing names.
|
||||
type MCPTaskInfo struct {
|
||||
// Server is the configured MCP server name this task lives on.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the task.
|
||||
TaskID string
|
||||
// Status is the current task lifecycle state.
|
||||
Status mcp.TaskStatus
|
||||
// StatusMessage is an optional human-readable description.
|
||||
StatusMessage string
|
||||
// CreatedAt is the wall-clock time the task was created (best-effort
|
||||
// parsed from the server's ISO-8601 timestamp; zero on parse failure).
|
||||
CreatedAt time.Time
|
||||
// UpdatedAt is the wall-clock time the task was last updated (best-
|
||||
// effort parsed; zero on parse failure).
|
||||
UpdatedAt time.Time
|
||||
// TTL is the time-to-live the server intends to retain the task after
|
||||
// creation. Zero means the server did not advertise a TTL.
|
||||
TTL time.Duration
|
||||
// PollInterval is the suggested polling interval. Zero means use the
|
||||
// client's default.
|
||||
PollInterval time.Duration
|
||||
}
|
||||
|
||||
// MCPTaskProgress is emitted while the connection pool is waiting on a
|
||||
// task-augmented tool call. It provides minimal feedback for SDK consumers
|
||||
// that want to render progress widgets without subscribing to the full
|
||||
// notifications/tasks/status channel (Phase 2).
|
||||
type MCPTaskProgress struct {
|
||||
Server string
|
||||
TaskID string
|
||||
Status mcp.TaskStatus
|
||||
Message string
|
||||
}
|
||||
|
||||
// MCPTaskProgressHandler is invoked once after a task is accepted and on
|
||||
// every status transition observed by the polling loop. The final
|
||||
// invocation always carries a terminal status. Implementations must not
|
||||
// block; long work should be queued on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// MCPTaskConfig configures task-aware tool execution on the manager.
|
||||
// All fields are optional; the zero value disables progress callbacks and
|
||||
// applies sensible defaults.
|
||||
type MCPTaskConfig struct {
|
||||
// PerServerMode overrides the per-server TasksMode resolved from
|
||||
// MCPServerConfig. Keys are server names. Missing entries fall back
|
||||
// to the value from config. Used by SDK consumers that want to set
|
||||
// modes programmatically.
|
||||
PerServerMode map[string]MCPTaskMode
|
||||
|
||||
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
|
||||
// tools/call. Zero means omit the TTL — let the server pick its own.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// PollInterval is the fallback interval between tasks/get requests
|
||||
// when the server does not suggest one. Zero defaults to 1 second.
|
||||
PollInterval time.Duration
|
||||
|
||||
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
|
||||
MaxPollInterval time.Duration
|
||||
|
||||
// Timeout is the maximum wall-clock duration to wait for a task to
|
||||
// reach a terminal state. Zero defaults to 15 minutes. Independent
|
||||
// of the per-call context deadline; whichever fires first wins.
|
||||
Timeout time.Duration
|
||||
|
||||
// Progress, if non-nil, receives every status transition observed by
|
||||
// the polling loop.
|
||||
Progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
func (c MCPTaskConfig) resolved() MCPTaskConfig {
|
||||
if c.PollInterval <= 0 {
|
||||
c.PollInterval = 1 * time.Second
|
||||
}
|
||||
if c.MaxPollInterval <= 0 {
|
||||
c.MaxPollInterval = 5 * time.Second
|
||||
}
|
||||
if c.Timeout <= 0 {
|
||||
c.Timeout = 15 * time.Minute
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// requestIDCounter generates monotonically increasing JSON-RPC request IDs
|
||||
// for low-level tools/call invocations that bypass the upstream client's
|
||||
// ParseCallToolResult helper (necessary because that helper rejects task
|
||||
// responses for lacking a "content" field).
|
||||
//
|
||||
// The counter is process-wide rather than per-manager so multiple managers
|
||||
// or repeated calls within the same connection produce unique IDs.
|
||||
var requestIDCounter atomic.Int64
|
||||
|
||||
func nextRequestID() mcp.RequestId {
|
||||
return mcp.NewRequestId(requestIDCounter.Add(1))
|
||||
}
|
||||
|
||||
// callToolWithTask issues tools/call directly on the transport so we can
|
||||
// observe both response shapes:
|
||||
//
|
||||
// - {"content": [...], ...} — synchronous CallToolResult.
|
||||
// - {"task": {...}, ...} — asynchronous CreateTaskResult.
|
||||
//
|
||||
// On success exactly one of (callResult, taskResult) is non-nil. The
|
||||
// upstream client.CallTool helper parses the response with
|
||||
// mcp.ParseCallToolResult which requires a "content" field, so it cannot
|
||||
// be used for task-augmented calls.
|
||||
func callToolWithTask(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
params mcp.CallToolParams,
|
||||
) (callResult *mcp.CallToolResult, taskResult *mcp.CreateTaskResult, err error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
Params: params,
|
||||
}
|
||||
|
||||
resp, sendErr := tr.SendRequest(ctx, req)
|
||||
if sendErr != nil {
|
||||
return nil, nil, sendErr
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Peek at the raw result to decide which shape we got.
|
||||
var probe struct {
|
||||
Task json.RawMessage `json:"task"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
raw := resp.Result
|
||||
if len(raw) == 0 {
|
||||
return nil, nil, errors.New("empty tools/call result")
|
||||
}
|
||||
if uErr := json.Unmarshal(raw, &probe); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode tools/call result: %w", uErr)
|
||||
}
|
||||
|
||||
if len(probe.Task) > 0 && string(probe.Task) != "null" {
|
||||
// Task-augmented response.
|
||||
var ct mcp.CreateTaskResult
|
||||
if uErr := json.Unmarshal(raw, &ct); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode CreateTaskResult: %w", uErr)
|
||||
}
|
||||
return nil, &ct, nil
|
||||
}
|
||||
|
||||
// Synchronous response — defer to the upstream parser so content blocks
|
||||
// are typed correctly (TextContent, ImageContent, ResourceLink, etc.).
|
||||
cr, pErr := mcp.ParseCallToolResult(&raw)
|
||||
if pErr != nil {
|
||||
return nil, nil, fmt.Errorf("parse CallToolResult: %w", pErr)
|
||||
}
|
||||
return cr, nil, nil
|
||||
}
|
||||
|
||||
// pollTaskUntilTerminal blocks until the task reaches a terminal status,
|
||||
// the context is cancelled, or the configured timeout elapses. On
|
||||
// cancellation it best-effort issues tasks/cancel before returning.
|
||||
func pollTaskUntilTerminal(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
serverName string,
|
||||
task mcp.Task,
|
||||
cfg MCPTaskConfig,
|
||||
progress MCPTaskProgressHandler,
|
||||
) (*mcp.TaskResultResult, error) {
|
||||
cfg = cfg.resolved()
|
||||
deadline := time.Now().Add(cfg.Timeout)
|
||||
|
||||
emit := func(status mcp.TaskStatus, msg string) {
|
||||
if progress != nil {
|
||||
progress(MCPTaskProgress{Server: serverName, TaskID: task.TaskId, Status: status, Message: msg})
|
||||
}
|
||||
}
|
||||
|
||||
emit(task.Status, task.StatusMessage)
|
||||
|
||||
current := task
|
||||
interval := cfg.PollInterval
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = time.Duration(*current.PollInterval) * time.Millisecond
|
||||
}
|
||||
if interval > cfg.MaxPollInterval {
|
||||
interval = cfg.MaxPollInterval
|
||||
}
|
||||
|
||||
for !current.Status.IsTerminal() {
|
||||
if time.Now().After(deadline) {
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, fmt.Errorf("task %s timed out after %s", current.TaskId, cfg.Timeout)
|
||||
}
|
||||
|
||||
// Wait between polls or abort early on context cancellation.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
}
|
||||
|
||||
got, err := c.GetTask(ctx, mcp.GetTaskRequest{
|
||||
Params: mcp.GetTaskParams{TaskId: current.TaskId},
|
||||
})
|
||||
if err != nil {
|
||||
// Transient transport hiccup — propagate immediately. The
|
||||
// upstream agent layer treats this like any other tool error.
|
||||
return nil, fmt.Errorf("tasks/get failed: %w", err)
|
||||
}
|
||||
current = got.Task
|
||||
if current.Status != task.Status || current.StatusMessage != task.StatusMessage {
|
||||
emit(current.Status, current.StatusMessage)
|
||||
task = current
|
||||
}
|
||||
|
||||
// Honour any updated suggested poll interval, capped at the limit.
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = min(time.Duration(*current.PollInterval)*time.Millisecond, cfg.MaxPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// Terminal state reached. Emit one last progress event and fetch the
|
||||
// definitive tool result.
|
||||
emit(current.Status, current.StatusMessage)
|
||||
|
||||
if current.Status == mcp.TaskStatusCancelled {
|
||||
return nil, fmt.Errorf("task %s was cancelled", current.TaskId)
|
||||
}
|
||||
|
||||
res, err := fetchTaskResult(ctx, c, current.TaskId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tasks/result failed: %w", err)
|
||||
}
|
||||
if current.Status == mcp.TaskStatusFailed && res != nil && !res.IsError {
|
||||
// The server flagged the task as failed but didn't decorate the
|
||||
// result. Surface the status message so the caller still sees a
|
||||
// useful tool-error.
|
||||
return nil, fmt.Errorf("task %s failed: %s", current.TaskId, current.StatusMessage)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// fetchTaskResult issues tasks/result on the transport and parses the raw
|
||||
// response. The upstream client.TaskResult helper delegates to
|
||||
// mcp.ParseTaskResultResult which (as of mcp-go v0.51.0) looks for the
|
||||
// content array under a nested "result" key that never exists in the
|
||||
// wire format — leading to systematically empty Content. Doing the
|
||||
// parse here keeps the polling path working until that is fixed upstream.
|
||||
func fetchTaskResult(ctx context.Context, c *client.Client, taskID string) (*mcp.TaskResultResult, error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodTasksResult),
|
||||
Params: mcp.TaskResultParams{TaskId: taskID},
|
||||
}
|
||||
resp, err := tr.SendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Manually decode the wire shape: {"_meta": {...}, "content": [...],
|
||||
// "structuredContent": ..., "isError": bool}.
|
||||
var shape struct {
|
||||
Meta json.RawMessage `json:"_meta"`
|
||||
Content []json.RawMessage `json:"content"`
|
||||
StructuredContent any `json:"structuredContent"`
|
||||
IsError bool `json:"isError"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Result, &shape); err != nil {
|
||||
return nil, fmt.Errorf("decode tasks/result: %w", err)
|
||||
}
|
||||
|
||||
out := &mcp.TaskResultResult{
|
||||
StructuredContent: shape.StructuredContent,
|
||||
IsError: shape.IsError,
|
||||
}
|
||||
if len(shape.Meta) > 0 && string(shape.Meta) != "null" {
|
||||
var metaMap map[string]any
|
||||
if err := json.Unmarshal(shape.Meta, &metaMap); err == nil {
|
||||
out.Meta = mcp.NewMetaFromMap(metaMap)
|
||||
}
|
||||
}
|
||||
for _, raw := range shape.Content {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(raw, &contentMap); err != nil {
|
||||
return nil, fmt.Errorf("decode content block: %w", err)
|
||||
}
|
||||
parsed, err := mcp.ParseContent(contentMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse content block: %w", err)
|
||||
}
|
||||
out.Content = append(out.Content, parsed)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cancelTaskBestEffort issues tasks/cancel and ignores any error. Used on
|
||||
// context cancellation paths where the connection is already going away.
|
||||
func cancelTaskBestEffort(c *client.Client, taskID string) {
|
||||
if c == nil || taskID == "" {
|
||||
return
|
||||
}
|
||||
cancelCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _ = c.CancelTask(cancelCtx, mcp.CancelTaskRequest{
|
||||
Params: mcp.CancelTaskParams{TaskId: taskID},
|
||||
})
|
||||
}
|
||||
|
||||
// taskFromMCP converts a wire-format mcp.Task to our richer connection-
|
||||
// layer view. Unparseable timestamps surface as the zero time.
|
||||
func taskFromMCP(serverName string, t mcp.Task) MCPTaskInfo {
|
||||
out := MCPTaskInfo{
|
||||
Server: serverName,
|
||||
TaskID: t.TaskId,
|
||||
Status: t.Status,
|
||||
StatusMessage: t.StatusMessage,
|
||||
}
|
||||
if t.CreatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.CreatedAt); err == nil {
|
||||
out.CreatedAt = v
|
||||
}
|
||||
}
|
||||
if t.LastUpdatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.LastUpdatedAt); err == nil {
|
||||
out.UpdatedAt = v
|
||||
}
|
||||
}
|
||||
if t.TTL != nil {
|
||||
out.TTL = time.Duration(*t.TTL) * time.Millisecond
|
||||
}
|
||||
if t.PollInterval != nil {
|
||||
out.PollInterval = time.Duration(*t.PollInterval) * time.Millisecond
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTaskTestInProcessServer builds an in-process MCP server with a
|
||||
// task-augmented tool. The handler simulates work by sleeping briefly
|
||||
// before completing.
|
||||
//
|
||||
// Important: the upstream mcp-go server cancels the request context as
|
||||
// soon as the synchronous part of the tools/call returns (see
|
||||
// request_handler.go:85, `defer cancel()`). Task goroutines spawned by
|
||||
// AddTaskTool inherit that context and therefore see context.Canceled
|
||||
// the instant they start. Real-world transports (stdio, SSE, streamable
|
||||
// HTTP) don't trip this because they keep the connection — and a
|
||||
// background context — alive across the async work, but the in-process
|
||||
// transport runs entirely on the request goroutine. To test the polling
|
||||
// path realistically we detach from the request context here.
|
||||
func newTaskTestInProcessServer(t *testing.T, workDuration time.Duration) *server.MCPServer {
|
||||
t.Helper()
|
||||
srv := server.NewMCPServer("task-test", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
// list=true, cancel=true, toolCallTasks=true so capability detection,
|
||||
// cancellation, and tool augmentation all flow through.
|
||||
server.WithTaskCapabilities(true, true, true),
|
||||
)
|
||||
srv.AddTaskTool(
|
||||
mcp.Tool{
|
||||
Name: "long_running",
|
||||
Description: "Sleep, then echo the input string.",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"msg": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
Execution: &mcp.ToolExecution{
|
||||
TaskSupport: mcp.TaskSupportRequired,
|
||||
},
|
||||
},
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CreateTaskResult, error) {
|
||||
msg, _ := req.GetArguments()["msg"].(string)
|
||||
// Detach from the request context so the task handler can
|
||||
// outlive the synchronous request — see comment above.
|
||||
time.Sleep(workDuration)
|
||||
_ = ctx
|
||||
return &mcp.CreateTaskResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "echo:" + msg},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
// newSyncOnlyServer is a server that does NOT advertise task capability.
|
||||
// Used to verify the auto-detect path keeps the sync semantics.
|
||||
func newSyncOnlyServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("sync-only", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required()),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("hi " + name), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestConnectionPoolAdvertisesTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
srv := newTaskTestInProcessServer(t, 0)
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: srv}
|
||||
|
||||
conn, err := pool.GetConnection(context.Background(), "tasks", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
|
||||
init := conn.InitializeResult()
|
||||
if init == nil {
|
||||
t.Fatal("InitializeResult is nil after GetConnection")
|
||||
}
|
||||
if init.Capabilities.Tasks == nil {
|
||||
t.Fatal("server did not advertise Tasks capability — initialize handshake regressed")
|
||||
}
|
||||
if !conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be true for a server with toolCallTasks=true")
|
||||
}
|
||||
if !pool.ServerSupportsToolTasks("tasks") {
|
||||
t.Error("ServerSupportsToolTasks should mirror the connection's value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolDetectsAbsentTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: newSyncOnlyServer()}
|
||||
conn, err := pool.GetConnection(context.Background(), "sync", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
if conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be false for a server that didn't advertise the capability")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsToolTasksFromInit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *mcp.InitializeResult
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"no tasks", &mcp.InitializeResult{}, false},
|
||||
{"tasks no requests", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: &mcp.TasksCapability{}},
|
||||
}, false},
|
||||
{"tasks with toolCalls", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: mcp.NewTasksCapability()},
|
||||
}, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := supportsToolTasksFromInit(tc.in); got != tc.want {
|
||||
t.Errorf("supportsToolTasksFromInit() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTaskMode(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want MCPTaskMode
|
||||
}{
|
||||
{"", MCPTaskModeAuto},
|
||||
{"auto", MCPTaskModeAuto},
|
||||
{"AUTO", MCPTaskModeAuto},
|
||||
{"never", MCPTaskModeNever},
|
||||
{"off", MCPTaskModeNever},
|
||||
{"always", MCPTaskModeAlways},
|
||||
{"force", MCPTaskModeAlways},
|
||||
{"bogus", MCPTaskModeAuto},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := ParseTaskMode(tc.in); got != tc.want {
|
||||
t.Errorf("ParseTaskMode(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolPollsTaskToCompletion(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 20 * time.Millisecond,
|
||||
MaxPollInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hello"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if res.IsError {
|
||||
t.Fatalf("expected non-error result, got %s", res.Content)
|
||||
}
|
||||
if !strings.Contains(res.Content, "echo:hello") {
|
||||
t.Errorf("expected result to contain 'echo:hello', got %s", res.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolHonorsNeverMode(t *testing.T) {
|
||||
// Even though the server advertises tasks/toolCalls, "never" should
|
||||
// keep the call synchronous. Since the tool is TaskSupportRequired,
|
||||
// the server returns an error rather than running it sync — we just
|
||||
// verify the error surfaces (not a poll-loop hang).
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PerServerMode: map[string]MCPTaskMode{"tasks": MCPTaskModeNever},
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// We don't care which way the server fails the sync call; we just want
|
||||
// to confirm we didn't hang in the polling loop and didn't panic.
|
||||
_, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"x"}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error when forcing sync execution of a task-required tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolEmitsProgress(t *testing.T) {
|
||||
var statuses []mcp.TaskStatus
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 10 * time.Millisecond,
|
||||
MaxPollInterval: 25 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
Progress: func(p MCPTaskProgress) {
|
||||
statuses = append(statuses, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 30*time.Millisecond),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hi"}`); err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if len(statuses) == 0 {
|
||||
t.Fatal("expected at least one progress event")
|
||||
}
|
||||
last := statuses[len(statuses)-1]
|
||||
if !last.IsTerminal() {
|
||||
t.Errorf("last progress event should be terminal, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGetCancelMCPTasksOnLoadedServer(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// tasks/list — no in-flight tasks yet, so we just verify the call
|
||||
// succeeds and returns an empty slice (or any slice; the exact length
|
||||
// depends on server retention policy).
|
||||
if _, err := mgr.ListServerTasks(ctx, "tasks"); err != nil {
|
||||
t.Errorf("ListServerTasks: %v", err)
|
||||
}
|
||||
|
||||
// Unknown server should error cleanly without panicking.
|
||||
if _, err := mgr.GetServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("GetServerTask on unknown server should error")
|
||||
}
|
||||
if _, err := mgr.CancelServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("CancelServerTask on unknown server should error")
|
||||
}
|
||||
}
|
||||
@@ -28,15 +28,6 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoBorder returns a renderingOption that disables all borders on the
|
||||
// block, rendering content with only padding.
|
||||
func WithNoBorder() renderingOption {
|
||||
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackground returns a renderingOption that sets the background color
|
||||
// for the entire block. The color parameter accepts any color.Color value,
|
||||
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
|
||||
func WithBackground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.background = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
}
|
||||
}
|
||||
|
||||
// renderContentBlock renders content with configurable styling options
|
||||
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
|
||||
renderer := &blockRenderer{
|
||||
|
||||
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
|
||||
@@ -199,18 +199,6 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
names = append(names, cmd.Name)
|
||||
names = append(names, cmd.Aliases...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ExtensionCommand is a slash command registered by an extension. Unlike
|
||||
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
|
||||
// extension commands carry their own Execute callback.
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
}
|
||||
|
||||
// Format the message to include all the debug info in a structured way
|
||||
var formattedMessage string
|
||||
|
||||
// Check if this is a multi-line debug output (like connection info)
|
||||
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
|
||||
// Extract the tag and content
|
||||
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
|
||||
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Add appropriate emoji based on the message content
|
||||
if strings.Contains(content, "Creating new connection") {
|
||||
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
|
||||
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
|
||||
} else if strings.Contains(content, "Reusing") {
|
||||
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
|
||||
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
|
||||
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
|
||||
} else if strings.Contains(content, "closed") {
|
||||
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Failed to close") {
|
||||
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
|
||||
} else {
|
||||
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
|
||||
// Use the CLI's debug message rendering
|
||||
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileSuggestion represents a single file, directory, or MCP resource
|
||||
@@ -31,6 +33,51 @@ type FileSuggestion struct {
|
||||
// maxFileSuggestions is the maximum number of file suggestions returned.
|
||||
const maxFileSuggestions = 20
|
||||
|
||||
// fileListCache caches the result of listFiles() keyed by directory to avoid
|
||||
// re-running git subprocesses on every keystroke during @file completion.
|
||||
var fileListCache struct {
|
||||
mu sync.Mutex
|
||||
dir string // searchDir that produced the cached entries
|
||||
cwd string // cwd used for the git query
|
||||
entries []FileSuggestion // cached file list
|
||||
expireAt time.Time // when the cache entry expires
|
||||
}
|
||||
|
||||
// fileListCacheTTL controls how long a cached file list stays valid.
|
||||
// During rapid typing the list is reused; after the TTL a fresh git
|
||||
// ls-files is executed so newly created files become visible.
|
||||
const fileListCacheTTL = 3 * time.Second
|
||||
|
||||
// getCachedFileList returns the file list for searchDir, using a short-lived
|
||||
// cache to avoid repeated subprocess calls during @file autocompletion.
|
||||
func getCachedFileList(searchDir, cwd string) []FileSuggestion {
|
||||
fileListCache.mu.Lock()
|
||||
defer fileListCache.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if fileListCache.dir == searchDir &&
|
||||
fileListCache.cwd == cwd &&
|
||||
now.Before(fileListCache.expireAt) {
|
||||
// Return a copy so callers can mutate (e.g. prepend baseDir).
|
||||
cp := make([]FileSuggestion, len(fileListCache.entries))
|
||||
copy(cp, fileListCache.entries)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Cache miss or expired — run the real (potentially expensive) lookup.
|
||||
files := listFiles(searchDir, cwd)
|
||||
|
||||
fileListCache.dir = searchDir
|
||||
fileListCache.cwd = cwd
|
||||
fileListCache.entries = files
|
||||
fileListCache.expireAt = now.Add(fileListCacheTTL)
|
||||
|
||||
// Return a copy.
|
||||
cp := make([]FileSuggestion, len(files))
|
||||
copy(cp, files)
|
||||
return cp
|
||||
}
|
||||
|
||||
// ExtractAtPrefix checks the current line for an @-file trigger at cursorCol.
|
||||
// It returns:
|
||||
// - hasAt: true if a valid @ trigger was found
|
||||
@@ -99,7 +146,7 @@ func GetFileSuggestions(prefix string, cwd string) []FileSuggestion {
|
||||
}
|
||||
}
|
||||
|
||||
files := listFiles(searchDir, cwd)
|
||||
files := getCachedFileList(searchDir, cwd)
|
||||
if len(files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -25,17 +25,6 @@ type TextMessageItem struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
@@ -109,8 +98,8 @@ func (m *TextMessageItem) renderContent(width int) string {
|
||||
// It accumulates content chunks and re-renders on each update for live display.
|
||||
type StreamingMessageItem struct {
|
||||
id string
|
||||
role string // "assistant" or "reasoning"
|
||||
content string // Accumulated streaming content
|
||||
role string // "assistant" or "reasoning"
|
||||
content strings.Builder // Accumulated streaming content
|
||||
timestamp time.Time
|
||||
startTime time.Time // When streaming started (for live duration counter)
|
||||
modelName string
|
||||
@@ -156,10 +145,10 @@ func (s *StreamingMessageItem) Render(width int) string {
|
||||
durationMs = time.Since(s.startTime).Milliseconds()
|
||||
}
|
||||
ty := createTypography(style.GetTheme())
|
||||
rendered = render.ReasoningBlock(s.content, durationMs, width, ty, style.GetTheme())
|
||||
rendered = render.ReasoningBlock(s.content.String(), durationMs, width, ty, style.GetTheme())
|
||||
} else {
|
||||
// Render as assistant message
|
||||
rendered = render.AssistantBlock(s.content, width, style.GetTheme())
|
||||
rendered = render.AssistantBlock(s.content.String(), width, style.GetTheme())
|
||||
}
|
||||
|
||||
// Cache and return (but reasoning is never cached due to live duration)
|
||||
@@ -187,7 +176,7 @@ func (s *StreamingMessageItem) Height() int {
|
||||
|
||||
// AppendChunk adds a content chunk and invalidates the render cache.
|
||||
func (s *StreamingMessageItem) AppendChunk(chunk string) {
|
||||
s.content += chunk
|
||||
s.content.WriteString(chunk)
|
||||
s.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
@@ -243,9 +232,7 @@ func (m *StreamingBashOutputItem) Render(width int) string {
|
||||
|
||||
// Header with command
|
||||
if m.command != "" {
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
headerStyle := style.GetCachedStyles().BashHeader
|
||||
parts = append(parts, headerStyle.Render(fmt.Sprintf("▸ %s", m.command)))
|
||||
}
|
||||
|
||||
@@ -318,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -88,13 +88,9 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
"content": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
for key, val := range params {
|
||||
@@ -338,7 +334,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
headerLine += " " + style.GetCachedStyles().ToolMuted.Render(params)
|
||||
}
|
||||
|
||||
// Get body content
|
||||
|
||||
@@ -1881,6 +1881,10 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
} else {
|
||||
bashItem.AppendStdout(msg.Chunk)
|
||||
}
|
||||
// Invalidate cached height after mutation.
|
||||
if m.scrollList != nil {
|
||||
m.scrollList.InvalidateItemHeight(bashItem.ID())
|
||||
}
|
||||
|
||||
// Check height and cap if needed - we don't want streaming output to grow forever
|
||||
const maxStreamingBashHeight = 20 // Max lines to show during streaming
|
||||
@@ -2072,6 +2076,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.providerName = msg.ProviderName
|
||||
m.modelName = msg.ModelName
|
||||
|
||||
case app.UsageUpdatedEvent:
|
||||
// Token usage was updated after a completed LLM step. No state
|
||||
// changes needed — the UsageTracker was already mutated in-place.
|
||||
// Returning from Update() triggers View() which re-renders the
|
||||
// status bar with the latest token counts, cost, and context %.
|
||||
|
||||
case app.WidgetUpdateEvent:
|
||||
// Extension widget changed — recalculate height distribution so the
|
||||
// stream region accounts for widget space. View() will read the
|
||||
@@ -3696,6 +3706,10 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
// If last message is a StreamingMessageItem with matching role, append to it
|
||||
if streamMsg, ok := lastMsg.(*StreamingMessageItem); ok && streamMsg.role == role {
|
||||
streamMsg.AppendChunk(content)
|
||||
// Invalidate cached height so GotoBottom sees the new size.
|
||||
if m.scrollList != nil {
|
||||
m.scrollList.InvalidateItemHeight(streamMsg.ID())
|
||||
}
|
||||
// Auto-scroll to bottom if enabled (iteratr pattern)
|
||||
// Don't call SetItems() - the slice reference hasn't changed
|
||||
if m.scrollList != nil {
|
||||
|
||||
@@ -19,33 +19,12 @@ import (
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
// Any @file tokens in the content are highlighted with the theme accent color.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
// Highlight @file tokens with accent color so file references are
|
||||
// visually distinct from surrounding prompt text.
|
||||
content = HighlightFileTokens(content, theme)
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
// UserBlock-related rendering helpers and herald typography.
|
||||
|
||||
// HighlightFileTokens wraps @file tokens in the given text with the theme
|
||||
// accent color so they stand out visually in rendered user messages.
|
||||
func HighlightFileTokens(text string, theme style.Theme) string {
|
||||
accentStyle := lipgloss.NewStyle().Foreground(theme.Accent).Bold(true)
|
||||
accentStyle := style.GetCachedStyles().FileTokenAccent
|
||||
return fileTokenPattern.ReplaceAllStringFunc(text, func(token string) string {
|
||||
return accentStyle.Render(token)
|
||||
})
|
||||
@@ -75,8 +54,8 @@ func ReasoningBlock(content string, duration int64, width int, ty *herald.Typogr
|
||||
if width > 4 {
|
||||
contentStr = wrapText(contentStr, width-4)
|
||||
}
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
contentRendered := mutedStyle.Render(ty.Italic(contentStr))
|
||||
cs := style.GetCachedStyles()
|
||||
contentRendered := cs.Muted.Render(ty.Italic(contentStr))
|
||||
|
||||
// Build label based on duration
|
||||
if duration > 0 {
|
||||
@@ -86,14 +65,14 @@ func ReasoningBlock(content string, duration int64, width int, ty *herald.Typogr
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", float64(duration)/1000)
|
||||
}
|
||||
labelPart := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationPart := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
labelPart := cs.VeryMuted.Render("Thought for ")
|
||||
durationPart := cs.Accent.Render(durationStr)
|
||||
label := labelPart + durationPart
|
||||
rendered := contentRendered + "\n" + label
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought")
|
||||
label := cs.VeryMuted.Render("Thought")
|
||||
rendered := contentRendered + "\n" + label
|
||||
|
||||
return styleMarginBottom(theme, rendered)
|
||||
@@ -154,47 +133,9 @@ func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) strin
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ToolBlock renders a tool execution result with header and body.
|
||||
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
return styleMarginBottom(theme, fullContent)
|
||||
}
|
||||
|
||||
// styleMarginBottom applies a 1-line margin bottom using the theme.
|
||||
func styleMarginBottom(theme style.Theme, content string) string {
|
||||
return lipgloss.NewStyle().MarginBottom(1).Render(content)
|
||||
return style.GetCachedStyles().MarginBottom1.Render(content)
|
||||
}
|
||||
|
||||
// wrapText soft-wraps a string to the given width using lipgloss, which is
|
||||
|
||||
@@ -4,30 +4,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// testTypography creates a herald Typography for tests.
|
||||
func testTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertLabel(herald.AlertTip, ""),
|
||||
herald.WithAlertIcon(herald.AlertTip, ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHighlightFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
|
||||
@@ -88,24 +67,25 @@ func TestHighlightFileTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserBlockHighlightsFileTokens(t *testing.T) {
|
||||
// TestHighlightFileTokensInjectsANSI verifies that HighlightFileTokens
|
||||
// preserves the original @file references in the output and wraps each
|
||||
// token with ANSI escape codes for the theme accent color.
|
||||
func TestHighlightFileTokensInjectsANSI(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
ty := testTypography(theme)
|
||||
|
||||
// A user message with @file tokens should contain ANSI escapes around the token.
|
||||
content := "refactor @main.go and @utils.go"
|
||||
result := UserBlock(content, 80, ty, theme)
|
||||
result := HighlightFileTokens(content, theme)
|
||||
|
||||
// The rendered output should contain both file references.
|
||||
// The output should still contain both file references.
|
||||
if !strings.Contains(result, "@main.go") {
|
||||
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @main.go, got:\n%s", result)
|
||||
}
|
||||
if !strings.Contains(result, "@utils.go") {
|
||||
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @utils.go, got:\n%s", result)
|
||||
}
|
||||
|
||||
// Verify ANSI codes are present (the tokens are styled).
|
||||
if !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
|
||||
t.Errorf("HighlightFileTokens output should contain ANSI escape codes for styled @file tokens")
|
||||
}
|
||||
}
|
||||
|
||||
+96
-66
@@ -35,6 +35,12 @@ type ScrollList struct {
|
||||
autoScroll bool // Whether to auto-scroll to bottom on new content
|
||||
itemGap int // Number of blank lines between items (0 = no gap)
|
||||
|
||||
// heightCache maps item ID → rendered line count at current width.
|
||||
// Avoids redundant Render() calls in GotoBottom/clampOffset/AtBottom.
|
||||
// Invalidated on width change; individual entries are refreshed in
|
||||
// View() when an item is actually rendered.
|
||||
heightCache map[string]int
|
||||
|
||||
// Character-level text selection (crush-style).
|
||||
sel selection.State
|
||||
}
|
||||
@@ -42,13 +48,14 @@ type ScrollList struct {
|
||||
// NewScrollList creates a new ScrollList with the given dimensions.
|
||||
func NewScrollList(width, height int) *ScrollList {
|
||||
return &ScrollList{
|
||||
items: []MessageItem{},
|
||||
offsetIdx: 0,
|
||||
offsetLine: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
autoScroll: true,
|
||||
sel: selection.NewState(),
|
||||
items: []MessageItem{},
|
||||
offsetIdx: 0,
|
||||
offsetLine: 0,
|
||||
width: width,
|
||||
height: height,
|
||||
autoScroll: true,
|
||||
heightCache: make(map[string]int, 64),
|
||||
sel: selection.NewState(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +68,13 @@ func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateItemHeight removes the cached height for the given item ID,
|
||||
// forcing a re-render on the next height query. Call this after mutating
|
||||
// an item's content (e.g. AppendChunk on a streaming message).
|
||||
func (s *ScrollList) InvalidateItemHeight(id string) {
|
||||
delete(s.heightCache, id)
|
||||
}
|
||||
|
||||
// SetHeight updates the viewport height. Called when the terminal is resized.
|
||||
func (s *ScrollList) SetHeight(height int) {
|
||||
s.height = height
|
||||
@@ -68,9 +82,11 @@ func (s *ScrollList) SetHeight(height int) {
|
||||
}
|
||||
|
||||
// SetWidth updates the viewport width. Called when the terminal is resized.
|
||||
// This may invalidate cached renders in MessageItems.
|
||||
// This invalidates the height cache since rendered heights are width-dependent.
|
||||
func (s *ScrollList) SetWidth(width int) {
|
||||
s.width = width
|
||||
// Width change invalidates all cached heights.
|
||||
clear(s.heightCache)
|
||||
s.clampOffset()
|
||||
}
|
||||
|
||||
@@ -338,9 +354,8 @@ func (s *ScrollList) ScrollBy(lines int) {
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
break
|
||||
}
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
remainingLines := itemHeight - s.offsetLine
|
||||
ih := s.itemHeight(s.items[s.offsetIdx])
|
||||
remainingLines := ih - s.offsetLine
|
||||
|
||||
if lines >= remainingLines {
|
||||
// Move to next item
|
||||
@@ -387,14 +402,13 @@ func (s *ScrollList) ScrollBy(lines int) {
|
||||
// Move to previous item
|
||||
s.offsetIdx--
|
||||
if s.offsetIdx < len(s.items) {
|
||||
currentItem := s.items[s.offsetIdx]
|
||||
itemHeight := currentItem.Height()
|
||||
ih := s.itemHeight(s.items[s.offsetIdx])
|
||||
|
||||
if lines >= itemHeight {
|
||||
lines -= itemHeight
|
||||
if lines >= ih {
|
||||
lines -= ih
|
||||
s.offsetLine = 0
|
||||
} else {
|
||||
s.offsetLine = itemHeight - lines
|
||||
s.offsetLine = ih - lines
|
||||
lines = 0
|
||||
}
|
||||
}
|
||||
@@ -405,6 +419,8 @@ func (s *ScrollList) ScrollBy(lines int) {
|
||||
}
|
||||
|
||||
// GotoBottom scrolls to the end of the list.
|
||||
// Uses cached heights and walks backwards from the end to avoid rendering
|
||||
// every item in the list.
|
||||
func (s *ScrollList) GotoBottom() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
@@ -412,42 +428,31 @@ func (s *ScrollList) GotoBottom() {
|
||||
return
|
||||
}
|
||||
|
||||
// Calculate total height including gaps
|
||||
totalHeight := 0
|
||||
for i, item := range s.items {
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
totalHeight += itemHeight
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
totalHeight += s.itemGap
|
||||
// Walk backwards from the last item, accumulating height until we
|
||||
// exceed the viewport. This is O(visible) instead of O(all items).
|
||||
budget := s.height
|
||||
for idx := len(s.items) - 1; idx >= 0; idx-- {
|
||||
ih := s.itemHeight(s.items[idx])
|
||||
|
||||
// Account for gap *above* this item (gap between idx-1 and idx).
|
||||
gap := 0
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
gap = s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// If content fits in viewport, start at top
|
||||
if totalHeight <= s.height {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, position viewport at bottom
|
||||
remaining := totalHeight - s.height
|
||||
for idx := 0; idx < len(s.items); idx++ {
|
||||
rendered := s.items[idx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
if remaining < itemHeight {
|
||||
if ih+gap >= budget {
|
||||
// This item (partially) fills the remaining budget.
|
||||
// When the gap consumed part of the budget, offsetLine would go
|
||||
// negative — clamp to 0 so the item is shown fully.
|
||||
s.offsetIdx = idx
|
||||
s.offsetLine = remaining
|
||||
s.offsetLine = max(0, ih-budget)
|
||||
return
|
||||
}
|
||||
remaining -= itemHeight
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
remaining -= s.itemGap
|
||||
}
|
||||
budget -= ih + gap
|
||||
}
|
||||
|
||||
// Fallback: show last item
|
||||
s.offsetIdx = max(0, len(s.items)-1)
|
||||
// All content fits in viewport — start at top.
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
@@ -465,14 +470,12 @@ func (s *ScrollList) AtBottom() bool {
|
||||
|
||||
visibleHeight := 0
|
||||
for idx := s.offsetIdx; idx < len(s.items); idx++ {
|
||||
item := s.items[idx]
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
ih := s.itemHeight(s.items[idx])
|
||||
|
||||
if idx == s.offsetIdx {
|
||||
visibleHeight += itemHeight - s.offsetLine
|
||||
visibleHeight += ih - s.offsetLine
|
||||
} else {
|
||||
visibleHeight += itemHeight
|
||||
visibleHeight += ih
|
||||
}
|
||||
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
@@ -520,6 +523,9 @@ func (s *ScrollList) View() string {
|
||||
content := item.Render(s.width)
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
// Refresh height cache from the actual render (authoritative).
|
||||
s.heightCache[item.ID()] = len(contentLines)
|
||||
|
||||
startLine := 0
|
||||
if idx == s.offsetIdx {
|
||||
startLine = s.offsetLine
|
||||
@@ -568,7 +574,7 @@ func (s *ScrollList) ScrollPercent() float64 {
|
||||
|
||||
totalHeight := 0
|
||||
for _, item := range s.items {
|
||||
totalHeight += item.Height()
|
||||
totalHeight += s.itemHeight(item)
|
||||
}
|
||||
|
||||
if totalHeight <= s.height {
|
||||
@@ -577,7 +583,7 @@ func (s *ScrollList) ScrollPercent() float64 {
|
||||
|
||||
linesAbove := 0
|
||||
for i := 0; i < s.offsetIdx && i < len(s.items); i++ {
|
||||
linesAbove += s.items[i].Height()
|
||||
linesAbove += s.itemHeight(s.items[i])
|
||||
}
|
||||
linesAbove += s.offsetLine
|
||||
|
||||
@@ -597,7 +603,8 @@ func (s *ScrollList) ScrollPercent() float64 {
|
||||
}
|
||||
|
||||
// clampOffset ensures the offset values are within valid bounds after
|
||||
// resizing or scrolling operations.
|
||||
// resizing or scrolling operations. Uses cached heights to avoid
|
||||
// redundant Render() calls.
|
||||
func (s *ScrollList) clampOffset() {
|
||||
if len(s.items) == 0 {
|
||||
s.offsetIdx = 0
|
||||
@@ -605,6 +612,7 @@ func (s *ScrollList) clampOffset() {
|
||||
return
|
||||
}
|
||||
|
||||
// Clamp offsetIdx to valid item range.
|
||||
if s.offsetIdx >= len(s.items) {
|
||||
s.offsetIdx = len(s.items) - 1
|
||||
}
|
||||
@@ -612,37 +620,38 @@ func (s *ScrollList) clampOffset() {
|
||||
s.offsetIdx = 0
|
||||
}
|
||||
|
||||
// Clamp offsetLine within current item.
|
||||
if s.offsetIdx < len(s.items) {
|
||||
rendered := s.items[s.offsetIdx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
if s.offsetLine >= itemHeight {
|
||||
s.offsetLine = max(0, itemHeight-1)
|
||||
ih := s.itemHeight(s.items[s.offsetIdx])
|
||||
if s.offsetLine >= ih {
|
||||
s.offsetLine = max(0, ih-1)
|
||||
}
|
||||
}
|
||||
if s.offsetLine < 0 {
|
||||
s.offsetLine = 0
|
||||
}
|
||||
|
||||
// Prevent scrolling past the bottom
|
||||
// Prevent scrolling past the bottom — compute total height and check
|
||||
// whether remaining content from the current offset fills the viewport.
|
||||
totalHeight := 0
|
||||
for i, item := range s.items {
|
||||
rendered := item.Render(s.width)
|
||||
totalHeight += strings.Count(rendered, "\n") + 1
|
||||
totalHeight += s.itemHeight(item)
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
totalHeight += s.itemGap
|
||||
}
|
||||
}
|
||||
|
||||
// If content fits in viewport, force start at top.
|
||||
if totalHeight <= s.height {
|
||||
s.offsetIdx = 0
|
||||
s.offsetLine = 0
|
||||
return
|
||||
}
|
||||
|
||||
// Compute lines above the viewport.
|
||||
linesAbove := 0
|
||||
for i := 0; i < s.offsetIdx; i++ {
|
||||
rendered := s.items[i].Render(s.width)
|
||||
linesAbove += strings.Count(rendered, "\n") + 1
|
||||
linesAbove += s.itemHeight(s.items[i])
|
||||
if s.itemGap > 0 && i < len(s.items)-1 {
|
||||
linesAbove += s.itemGap
|
||||
}
|
||||
@@ -651,20 +660,21 @@ func (s *ScrollList) clampOffset() {
|
||||
|
||||
linesFromCurrentToEnd := totalHeight - linesAbove
|
||||
if linesFromCurrentToEnd < s.height {
|
||||
// We've scrolled past the bottom — reposition so the last line
|
||||
// of content sits at the bottom of the viewport.
|
||||
targetLine := totalHeight - s.height
|
||||
currentLine := 0
|
||||
|
||||
for idx := 0; idx < len(s.items); idx++ {
|
||||
rendered := s.items[idx].Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
ih := s.itemHeight(s.items[idx])
|
||||
|
||||
if currentLine+itemHeight > targetLine {
|
||||
if currentLine+ih > targetLine {
|
||||
s.offsetIdx = idx
|
||||
s.offsetLine = targetLine - currentLine
|
||||
return
|
||||
}
|
||||
|
||||
currentLine += itemHeight
|
||||
currentLine += ih
|
||||
if s.itemGap > 0 && idx < len(s.items)-1 {
|
||||
currentLine += s.itemGap
|
||||
}
|
||||
@@ -672,6 +682,26 @@ func (s *ScrollList) clampOffset() {
|
||||
}
|
||||
}
|
||||
|
||||
// itemHeight returns the cached rendered height for an item, computing and
|
||||
// caching it on first access. This avoids calling Render() purely to
|
||||
// count lines — the most common source of redundant work in the scroll
|
||||
// list (GotoBottom, clampOffset, AtBottom, ScrollBy all need heights but
|
||||
// never use the rendered content).
|
||||
//
|
||||
// The cache is invalidated wholesale on width changes (SetWidth) and
|
||||
// individual entries are refreshed in View() after an item is actually
|
||||
// rendered, so stale entries are self-correcting within one frame.
|
||||
func (s *ScrollList) itemHeight(item MessageItem) int {
|
||||
id := item.ID()
|
||||
if h, ok := s.heightCache[id]; ok {
|
||||
return h
|
||||
}
|
||||
// Cache miss — render to measure.
|
||||
h := s.renderedHeight(item)
|
||||
s.heightCache[id] = h
|
||||
return h
|
||||
}
|
||||
|
||||
// renderedHeight returns the height of a message item in lines by actually
|
||||
// rendering it. This is the single source of truth for item height — it
|
||||
// matches exactly what View() produces, unlike item.Height() which may
|
||||
|
||||
+9
-11
@@ -21,12 +21,11 @@ func knightRiderFrames() []string {
|
||||
const numDots = 8
|
||||
const dot = "▪"
|
||||
|
||||
theme := style.GetTheme()
|
||||
|
||||
bright := lipgloss.NewStyle().Foreground(theme.Primary)
|
||||
med := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
dim := lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
off := lipgloss.NewStyle().Foreground(theme.MutedBorder)
|
||||
cs := style.GetCachedStyles()
|
||||
bright := cs.SpinnerBright
|
||||
med := cs.SpinnerMed
|
||||
dim := cs.SpinnerDim
|
||||
off := cs.SpinnerOff
|
||||
|
||||
// Scanner bounces: 0→7→0
|
||||
positions := make([]int, 0, 2*numDots-2)
|
||||
@@ -476,9 +475,8 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
if s.width > 4 {
|
||||
content = lipgloss.NewStyle().Width(s.width - 4).Render(content)
|
||||
}
|
||||
theme := GetTheme()
|
||||
mutedStyle := lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
parts = append(parts, mutedStyle.Render(s.ty.Italic(content)))
|
||||
cs := style.GetCachedStyles()
|
||||
parts = append(parts, cs.Muted.Render(s.ty.Italic(content)))
|
||||
|
||||
// Duration footer with VeryMuted label and Accent duration.
|
||||
var duration time.Duration
|
||||
@@ -494,8 +492,8 @@ func (s *StreamComponent) renderReasoningBlock(reasoning string) string {
|
||||
} else {
|
||||
durationStr = fmt.Sprintf("%.1fs", duration.Seconds())
|
||||
}
|
||||
label := lipgloss.NewStyle().Foreground(theme.VeryMuted).Render("Thought for ")
|
||||
durationStyled := lipgloss.NewStyle().Foreground(theme.Accent).Render(durationStr)
|
||||
label := cs.VeryMuted.Render("Thought for ")
|
||||
durationStyled := cs.Accent.Render(durationStr)
|
||||
parts = append(parts, label+durationStyled)
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,70 @@ func GetTheme() Theme {
|
||||
func SetTheme(theme Theme) {
|
||||
currentTheme = theme
|
||||
markdownTypographyCache = nil // invalidate cached renderer; colors may have changed
|
||||
styleCache = nil // invalidate cached styles; colors may have changed
|
||||
}
|
||||
|
||||
// CachedStyles holds pre-built lipgloss styles that are reused across
|
||||
// render frames. Invalidated by SetTheme, lazily rebuilt on next access.
|
||||
// Only accessed from BubbleTea's single-threaded Update/View cycle.
|
||||
type CachedStyles struct {
|
||||
// render/blocks.go
|
||||
FileTokenAccent lipgloss.Style // Foreground(Accent).Bold(true)
|
||||
Muted lipgloss.Style // Foreground(Muted)
|
||||
VeryMuted lipgloss.Style // Foreground(VeryMuted)
|
||||
Accent lipgloss.Style // Foreground(Accent)
|
||||
MarginBottom1 lipgloss.Style // MarginBottom(1)
|
||||
|
||||
// stream.go - spinner phases
|
||||
SpinnerBright lipgloss.Style // Foreground(Primary)
|
||||
SpinnerMed lipgloss.Style // Foreground(Muted)
|
||||
SpinnerDim lipgloss.Style // Foreground(VeryMuted)
|
||||
SpinnerOff lipgloss.Style // Foreground(MutedBorder)
|
||||
|
||||
// message_items.go - bash output
|
||||
BashHeader lipgloss.Style // Foreground(Muted).Italic(true)
|
||||
BashStderr lipgloss.Style // Foreground(Error)
|
||||
|
||||
// render/blocks.go - tool block
|
||||
ToolSuccess lipgloss.Style // Foreground(Success)
|
||||
ToolError lipgloss.Style // Foreground(Error)
|
||||
ToolInfo lipgloss.Style // Foreground(Info).Bold(true)
|
||||
ToolMuted lipgloss.Style // Foreground(Muted)
|
||||
|
||||
// common
|
||||
ErrorFg lipgloss.Style // Foreground(Error)
|
||||
TextBold lipgloss.Style // Foreground(Text).Bold(true)
|
||||
}
|
||||
|
||||
var styleCache *CachedStyles
|
||||
|
||||
// GetCachedStyles returns the pre-built style cache, creating it lazily
|
||||
// from the current theme. Invalidated by SetTheme.
|
||||
func GetCachedStyles() *CachedStyles {
|
||||
if styleCache != nil {
|
||||
return styleCache
|
||||
}
|
||||
theme := GetTheme()
|
||||
styleCache = &CachedStyles{
|
||||
FileTokenAccent: lipgloss.NewStyle().Foreground(theme.Accent).Bold(true),
|
||||
Muted: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
VeryMuted: lipgloss.NewStyle().Foreground(theme.VeryMuted),
|
||||
Accent: lipgloss.NewStyle().Foreground(theme.Accent),
|
||||
MarginBottom1: lipgloss.NewStyle().MarginBottom(1),
|
||||
SpinnerBright: lipgloss.NewStyle().Foreground(theme.Primary),
|
||||
SpinnerMed: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
SpinnerDim: lipgloss.NewStyle().Foreground(theme.VeryMuted),
|
||||
SpinnerOff: lipgloss.NewStyle().Foreground(theme.MutedBorder),
|
||||
BashHeader: lipgloss.NewStyle().Foreground(theme.Muted).Italic(true),
|
||||
BashStderr: lipgloss.NewStyle().Foreground(theme.Error),
|
||||
ToolSuccess: lipgloss.NewStyle().Foreground(theme.Success),
|
||||
ToolError: lipgloss.NewStyle().Foreground(theme.Error),
|
||||
ToolInfo: lipgloss.NewStyle().Foreground(theme.Info).Bold(true),
|
||||
ToolMuted: lipgloss.NewStyle().Foreground(theme.Muted),
|
||||
ErrorFg: lipgloss.NewStyle().Foreground(theme.Error),
|
||||
TextBold: lipgloss.NewStyle().Foreground(theme.Text).Bold(true),
|
||||
}
|
||||
return styleCache
|
||||
}
|
||||
|
||||
// MarkdownThemeColors defines colors for markdown rendering and syntax highlighting.
|
||||
@@ -147,106 +211,11 @@ func DefaultTheme() Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
|
||||
// padding, and appropriate width. Used for grouping related content in a visually
|
||||
// distinct box.
|
||||
func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Border).
|
||||
Padding(1, 2).
|
||||
MarginBottom(1)
|
||||
}
|
||||
|
||||
// IsDarkBackground returns the cached terminal background detection result.
|
||||
func IsDarkBackground() bool {
|
||||
return isDarkBg
|
||||
}
|
||||
|
||||
// StyleHeader creates a lipgloss style for primary headers using the theme's
|
||||
// primary color with bold text for emphasis and hierarchy.
|
||||
func StyleHeader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
|
||||
// secondary color with bold text, providing visual hierarchy below primary headers.
|
||||
func StyleSubheader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
|
||||
// and italic formatting, suitable for supplementary or less important information.
|
||||
func StyleMuted(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
}
|
||||
|
||||
// StyleSuccess creates a lipgloss style for success messages using green colors
|
||||
// with bold text to indicate successful operations or positive outcomes.
|
||||
func StyleSuccess(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleError creates a lipgloss style for error messages using red colors
|
||||
// with bold text to ensure visibility of problems or failures.
|
||||
func StyleError(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
|
||||
// colors with bold text to draw attention to potential issues or cautions.
|
||||
func StyleWarning(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleInfo creates a lipgloss style for informational messages using blue colors
|
||||
// with bold text for general notifications and status updates.
|
||||
func StyleInfo(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// CreateSeparator generates a horizontal separator line with the specified width,
|
||||
// character, and color. Useful for visually dividing sections of content in the UI.
|
||||
func CreateSeparator(width int, char string, c color.Color) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Width(width).
|
||||
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
|
||||
}
|
||||
|
||||
// CreateProgressBar generates a visual progress bar with filled and empty segments
|
||||
// based on the percentage complete. The bar uses Unicode block characters for smooth
|
||||
// appearance and theme colors to indicate progress.
|
||||
func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
filled := int(float64(width) * percentage / 100)
|
||||
empty := width - filled
|
||||
|
||||
filledBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Render(lipgloss.PlaceHorizontal(filled, lipgloss.Left, "█"))
|
||||
|
||||
emptyBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Render(lipgloss.PlaceHorizontal(empty, lipgloss.Left, "░"))
|
||||
|
||||
return filledBar + emptyBar
|
||||
}
|
||||
|
||||
// CreateBadge generates a styled badge or label with inverted colors (text on
|
||||
// colored background) for highlighting important tags, statuses, or categories.
|
||||
func CreateBadge(text string, c color.Color) string {
|
||||
|
||||
@@ -6,13 +6,6 @@ import (
|
||||
heraldmd "github.com/indaco/herald-md"
|
||||
)
|
||||
|
||||
// BaseStyle returns a new, empty lipgloss style that can be customized with
|
||||
// additional styling methods. This serves as the foundation for building more
|
||||
// complex styled components.
|
||||
func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// markdownTypographyCache holds the last-created Typography instance for
|
||||
// herald-md rendering. It is cached to avoid re-initialization on every
|
||||
// streaming flush tick. The cache is invalidated by SetTheme when the
|
||||
|
||||
@@ -543,12 +543,6 @@ func ApplyThemeWithoutSave(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshThemeRegistry re-scans the themes directory. Call after the user
|
||||
// drops a new file into ~/.config/kit/themes/.
|
||||
func RefreshThemeRegistry() {
|
||||
initThemeRegistry()
|
||||
}
|
||||
|
||||
// RegisterThemeFromConfig adds a theme to the runtime registry from an
|
||||
// extension's ThemeColorConfig (string hex pairs). Replaces any existing
|
||||
// entry with the same name. The theme is immediately available via
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
type ToolApprovalInput struct {
|
||||
textarea textarea.Model
|
||||
toolName string
|
||||
toolArgs string
|
||||
width int
|
||||
selected bool // true when "yes" is highlighted and false when "no" is
|
||||
approved bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &ToolApprovalInput{
|
||||
textarea: ta,
|
||||
toolName: toolName,
|
||||
toolArgs: toolArgs,
|
||||
width: width,
|
||||
selected: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyPressMsg:
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
t.approved = true
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "n", "N":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "left":
|
||||
t.selected = true
|
||||
return t, nil
|
||||
case "right":
|
||||
t.selected = false
|
||||
return t, nil
|
||||
case "enter":
|
||||
t.approved = t.selected
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "esc", "ctrl+c":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) View() tea.View {
|
||||
if t.done {
|
||||
return tea.NewView("we are done")
|
||||
}
|
||||
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(t.width - 1) // full width minus left border
|
||||
|
||||
// Style for the currently selected/highlighted option
|
||||
selectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true).
|
||||
Underline(true)
|
||||
|
||||
// Style for the unselected/unhighlighted option
|
||||
unselectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted)
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render("Allow tool execution"))
|
||||
view.WriteString("\n")
|
||||
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
|
||||
view.WriteString(details)
|
||||
view.WriteString("Allow tool execution: ")
|
||||
|
||||
var yesText, noText string
|
||||
if t.selected {
|
||||
yesText = selectedStyle.Render("[y]es")
|
||||
noText = unselectedStyle.Render("[n]o")
|
||||
} else {
|
||||
yesText = unselectedStyle.Render("[y]es")
|
||||
noText = selectedStyle.Render("[n]o")
|
||||
}
|
||||
view.WriteString(yesText + "/" + noText + "\n")
|
||||
|
||||
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
|
||||
}
|
||||
@@ -79,8 +79,7 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// Edit tool — side-by-side diff
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
// renderEditBody renders a side-by-side diff from the edits array in toolArgs.
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
@@ -90,35 +89,28 @@ func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
editsArr, ok := args["edits"].([]any)
|
||||
if !ok || len(editsArr) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractDiffStartLine parses the first @@ hunk header from a unified diff
|
||||
|
||||
+36
-1
@@ -106,7 +106,7 @@ unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
unsub3 := host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
@@ -190,6 +190,41 @@ msg, err := host.GetMCPPrompt(ctx, "server-name", "prompt-name", map[string]stri
|
||||
})
|
||||
```
|
||||
|
||||
### MCP Tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`. Cooperating servers can respond to `tools/call` with a
|
||||
`taskId` immediately; Kit then polls `tasks/get` / `tasks/result` until the
|
||||
task reaches a terminal state, and best-effort `tasks/cancel`s on context
|
||||
cancellation. Servers that don't advertise the capability keep their previous
|
||||
synchronous behaviour.
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
// Per-server mode: auto (default), never, or always.
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskTimeout: 15 * time.Minute, // total wall-clock cap
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s/%s: %s", p.Server, p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
// Inspect / cancel in-flight tasks
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
t, _ := host.GetMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
if !t.Status.IsTerminal() {
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", t.TaskID)
|
||||
}
|
||||
```
|
||||
|
||||
The progress handler fires once when a task is accepted and again on every
|
||||
observed status transition; the final invocation always carries a terminal
|
||||
status (`MCPTaskStatusCompleted`, `MCPTaskStatusFailed`, or
|
||||
`MCPTaskStatusCancelled`). Don't block in the handler — dispatch long work on
|
||||
a goroutine.
|
||||
|
||||
### Session Management
|
||||
|
||||
Maintain conversation context:
|
||||
|
||||
+339
-3
@@ -58,6 +58,31 @@ const (
|
||||
// EventSteerConsumed fires when one or more steering messages have been
|
||||
// injected into the agent turn via PrepareStep.
|
||||
EventSteerConsumed EventType = "steer_consumed"
|
||||
// EventStepStart fires when a new LLM call begins within a turn.
|
||||
EventStepStart EventType = "step_start"
|
||||
// EventStepFinish fires when a step completes, providing full step context
|
||||
// including whether tool calls were made, the finish reason, and usage stats.
|
||||
EventStepFinish EventType = "step_finish"
|
||||
// EventTextStart fires when the LLM begins generating text content.
|
||||
EventTextStart EventType = "text_start"
|
||||
// EventTextEnd fires when the LLM finishes generating text content.
|
||||
EventTextEnd EventType = "text_end"
|
||||
// EventReasoningStart fires when the LLM begins reasoning/thinking.
|
||||
EventReasoningStart EventType = "reasoning_start"
|
||||
// EventWarnings fires when the LLM provider returns warnings.
|
||||
EventWarnings EventType = "warnings"
|
||||
// EventSource fires when the LLM references a source (e.g. from web search).
|
||||
EventSource EventType = "source"
|
||||
// EventStreamFinish fires when a per-step LLM stream completes with
|
||||
// usage stats and a finish reason.
|
||||
EventStreamFinish EventType = "stream_finish"
|
||||
// EventError fires when an agent-level error occurs during streaming.
|
||||
// This is distinct from TurnEndEvent.Error — it fires at the point of
|
||||
// failure, before the turn ends.
|
||||
EventError EventType = "error"
|
||||
// EventRetry fires when the LLM provider request is retried after a
|
||||
// transient error.
|
||||
EventRetry EventType = "retry"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -123,9 +148,9 @@ func parseToolArgs(toolArgs string) map[string]any {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Finish reasons reported by the LLM provider on a completed turn. These
|
||||
// mirror fantasy.FinishReason string values so comparisons against
|
||||
// TurnEndEvent.StopReason / TurnResult.StopReason are stable across
|
||||
// providers.
|
||||
// mirror the underlying provider's finish reason string values so
|
||||
// comparisons against TurnEndEvent.StopReason / TurnResult.StopReason are
|
||||
// stable across providers.
|
||||
const (
|
||||
// FinishReasonStop: the model produced a natural stop (e.g. stop sequence
|
||||
// or end-of-turn signal).
|
||||
@@ -379,6 +404,100 @@ type SteerConsumedEvent struct {
|
||||
// EventType implements Event.
|
||||
func (e SteerConsumedEvent) EventType() EventType { return EventSteerConsumed }
|
||||
|
||||
// StepStartEvent fires when a new LLM call begins within a multi-step agent turn.
|
||||
type StepStartEvent struct {
|
||||
StepNumber int
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StepStartEvent) EventType() EventType { return EventStepStart }
|
||||
|
||||
// StepFinishEvent fires when a step completes, providing full step context.
|
||||
// This is a unified event that carries the same data as the existing
|
||||
// ToolCallContentEvent and StepUsageEvent, plus additional step metadata.
|
||||
type StepFinishEvent struct {
|
||||
StepNumber int
|
||||
HasToolCalls bool
|
||||
FinishReason string
|
||||
Usage LLMUsage
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StepFinishEvent) EventType() EventType { return EventStepFinish }
|
||||
|
||||
// TextStartEvent fires when the LLM begins generating text content.
|
||||
// Paired with MessageUpdateEvent (deltas) and TextEndEvent.
|
||||
type TextStartEvent struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e TextStartEvent) EventType() EventType { return EventTextStart }
|
||||
|
||||
// TextEndEvent fires when the LLM finishes generating text content.
|
||||
type TextEndEvent struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e TextEndEvent) EventType() EventType { return EventTextEnd }
|
||||
|
||||
// ReasoningStartEvent fires when the LLM begins reasoning/thinking.
|
||||
// Paired with ReasoningDeltaEvent (deltas) and ReasoningCompleteEvent.
|
||||
type ReasoningStartEvent struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e ReasoningStartEvent) EventType() EventType { return EventReasoningStart }
|
||||
|
||||
// WarningsEvent fires when the LLM provider returns warnings about the request.
|
||||
type WarningsEvent struct {
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e WarningsEvent) EventType() EventType { return EventWarnings }
|
||||
|
||||
// SourceEvent fires when the LLM references a source (e.g. from web search tools).
|
||||
type SourceEvent struct {
|
||||
SourceType string
|
||||
ID string
|
||||
URL string
|
||||
Title string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e SourceEvent) EventType() EventType { return EventSource }
|
||||
|
||||
// StreamFinishEvent fires when a per-step LLM stream completes.
|
||||
// Provides per-stream usage stats and finish reason.
|
||||
type StreamFinishEvent struct {
|
||||
Usage LLMUsage
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e StreamFinishEvent) EventType() EventType { return EventStreamFinish }
|
||||
|
||||
// ErrorEvent fires when an agent-level error occurs during streaming.
|
||||
// This is distinct from TurnEndEvent.Error — it fires at the point of failure.
|
||||
type ErrorEvent struct {
|
||||
Error error
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e ErrorEvent) EventType() EventType { return EventError }
|
||||
|
||||
// RetryEvent fires when the LLM provider request is retried after a transient error.
|
||||
type RetryEvent struct {
|
||||
Attempt int
|
||||
Error error
|
||||
}
|
||||
|
||||
// EventType implements Event.
|
||||
func (e RetryEvent) EventType() EventType { return EventRetry }
|
||||
|
||||
// PasswordPromptEvent fires when a sudo command needs a password.
|
||||
// The TUI should display a password prompt and send the result back via ResponseCh.
|
||||
type PasswordPromptEvent struct {
|
||||
@@ -517,7 +636,16 @@ func (m *Kit) OnToolOutput(handler func(ToolOutputEvent)) func() {
|
||||
|
||||
// OnStreaming registers a handler that fires only for MessageUpdateEvent
|
||||
// (streaming text chunks). Returns an unsubscribe function.
|
||||
//
|
||||
// Deprecated: Use OnMessageUpdate instead. OnStreaming will be removed in a
|
||||
// future release.
|
||||
func (m *Kit) OnStreaming(handler func(MessageUpdateEvent)) func() {
|
||||
return m.OnMessageUpdate(handler)
|
||||
}
|
||||
|
||||
// OnMessageUpdate registers a handler that fires only for MessageUpdateEvent
|
||||
// (streaming text chunks). Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageUpdate(handler func(MessageUpdateEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if mu, ok := e.(MessageUpdateEvent); ok {
|
||||
handler(mu)
|
||||
@@ -555,6 +683,214 @@ func (m *Kit) OnTurnEnd(handler func(TurnEndEvent)) func() {
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Typed subscribers for previously unsubscribed event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// OnMessageStart registers a handler that fires only for MessageStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageStart(handler func(MessageStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ms, ok := e.(MessageStartEvent); ok {
|
||||
handler(ms)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnMessageEnd registers a handler that fires only for MessageEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnMessageEnd(handler func(MessageEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if me, ok := e.(MessageEndEvent); ok {
|
||||
handler(me)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnReasoningDelta registers a handler that fires only for ReasoningDeltaEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningDelta(handler func(ReasoningDeltaEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rd, ok := e.(ReasoningDeltaEvent); ok {
|
||||
handler(rd)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnReasoningComplete registers a handler that fires only for ReasoningCompleteEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningComplete(handler func(ReasoningCompleteEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rc, ok := e.(ReasoningCompleteEvent); ok {
|
||||
handler(rc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnToolExecutionStart registers a handler that fires only for ToolExecutionStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolExecutionStart(handler func(ToolExecutionStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tes, ok := e.(ToolExecutionStartEvent); ok {
|
||||
handler(tes)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnToolExecutionEnd registers a handler that fires only for ToolExecutionEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolExecutionEnd(handler func(ToolExecutionEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tee, ok := e.(ToolExecutionEndEvent); ok {
|
||||
handler(tee)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnToolCallContent registers a handler that fires only for ToolCallContentEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnToolCallContent(handler func(ToolCallContentEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if tcc, ok := e.(ToolCallContentEvent); ok {
|
||||
handler(tcc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnStepUsage registers a handler that fires only for StepUsageEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepUsage(handler func(StepUsageEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if su, ok := e.(StepUsageEvent); ok {
|
||||
handler(su)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnCompaction registers a handler that fires only for CompactionEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnCompaction(handler func(CompactionEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ce, ok := e.(CompactionEvent); ok {
|
||||
handler(ce)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnSteerConsumed registers a handler that fires only for SteerConsumedEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnSteerConsumed(handler func(SteerConsumedEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sc, ok := e.(SteerConsumedEvent); ok {
|
||||
handler(sc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Typed subscribers for new event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// OnStepStart registers a handler that fires only for StepStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepStart(handler func(StepStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ss, ok := e.(StepStartEvent); ok {
|
||||
handler(ss)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnStepFinish registers a handler that fires only for StepFinishEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStepFinish(handler func(StepFinishEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sf, ok := e.(StepFinishEvent); ok {
|
||||
handler(sf)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnTextStart registers a handler that fires only for TextStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTextStart(handler func(TextStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ts, ok := e.(TextStartEvent); ok {
|
||||
handler(ts)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnTextEnd registers a handler that fires only for TextEndEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnTextEnd(handler func(TextEndEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if te, ok := e.(TextEndEvent); ok {
|
||||
handler(te)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnReasoningStart registers a handler that fires only for ReasoningStartEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnReasoningStart(handler func(ReasoningStartEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if rs, ok := e.(ReasoningStartEvent); ok {
|
||||
handler(rs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnWarnings registers a handler that fires only for WarningsEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnWarnings(handler func(WarningsEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if w, ok := e.(WarningsEvent); ok {
|
||||
handler(w)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnSource registers a handler that fires only for SourceEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnSource(handler func(SourceEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if s, ok := e.(SourceEvent); ok {
|
||||
handler(s)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnStreamFinish registers a handler that fires only for StreamFinishEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnStreamFinish(handler func(StreamFinishEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if sf, ok := e.(StreamFinishEvent); ok {
|
||||
handler(sf)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnError registers a handler that fires only for ErrorEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnError(handler func(ErrorEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if ee, ok := e.(ErrorEvent); ok {
|
||||
handler(ee)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnRetry registers a handler that fires only for RetryEvent.
|
||||
// Returns an unsubscribe function.
|
||||
func (m *Kit) OnRetry(handler func(RetryEvent)) func() {
|
||||
return m.Subscribe(func(e Event) {
|
||||
if r, ok := e.(RetryEvent); ok {
|
||||
handler(r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent event subscriptions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -190,6 +191,74 @@ func TestEventTypes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewEventTypes verifies that each new event struct returns the correct EventType.
|
||||
func TestNewEventTypes(t *testing.T) {
|
||||
tests := []struct {
|
||||
event Event
|
||||
expected EventType
|
||||
}{
|
||||
{StepStartEvent{StepNumber: 0}, EventStepStart},
|
||||
{StepFinishEvent{StepNumber: 1, HasToolCalls: true}, EventStepFinish},
|
||||
{TextStartEvent{ID: "text-1"}, EventTextStart},
|
||||
{TextEndEvent{ID: "text-1"}, EventTextEnd},
|
||||
{ReasoningStartEvent{ID: "reason-1"}, EventReasoningStart},
|
||||
{WarningsEvent{Warnings: []string{"test"}}, EventWarnings},
|
||||
{SourceEvent{URL: "https://example.com", Title: "Example"}, EventSource},
|
||||
{StreamFinishEvent{FinishReason: "stop"}, EventStreamFinish},
|
||||
{ErrorEvent{Error: fmt.Errorf("test error")}, EventError},
|
||||
{RetryEvent{Attempt: 1, Error: fmt.Errorf("retry error")}, EventRetry},
|
||||
{ToolCallStartEvent{}, EventToolCallStart},
|
||||
{ToolCallDeltaEvent{}, EventToolCallDelta},
|
||||
{ToolCallEndEvent{}, EventToolCallEnd},
|
||||
{PasswordPromptEvent{}, EventPasswordPrompt},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := tt.event.EventType(); got != tt.expected {
|
||||
t.Errorf("%T.EventType() = %q, want %q", tt.event, got, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewEventEmission verifies that new event types are properly emitted and received.
|
||||
func TestNewEventEmission(t *testing.T) {
|
||||
bus := newEventBus()
|
||||
var received []Event
|
||||
|
||||
bus.subscribe(func(e Event) {
|
||||
received = append(received, e)
|
||||
})
|
||||
|
||||
bus.emit(StepStartEvent{StepNumber: 0})
|
||||
bus.emit(TextStartEvent{ID: "text-1"})
|
||||
bus.emit(TextEndEvent{ID: "text-1"})
|
||||
bus.emit(ReasoningStartEvent{ID: "reason-1"})
|
||||
bus.emit(WarningsEvent{Warnings: []string{"low confidence"}})
|
||||
bus.emit(SourceEvent{URL: "https://example.com", Title: "Example"})
|
||||
bus.emit(StreamFinishEvent{FinishReason: "stop"})
|
||||
bus.emit(StepFinishEvent{StepNumber: 0, HasToolCalls: false, FinishReason: "stop"})
|
||||
bus.emit(ErrorEvent{Error: fmt.Errorf("test error")})
|
||||
bus.emit(RetryEvent{Attempt: 1, Error: fmt.Errorf("retry")})
|
||||
|
||||
if len(received) != 10 {
|
||||
t.Fatalf("expected 10 events, got %d", len(received))
|
||||
}
|
||||
|
||||
// Verify specific event fields
|
||||
if ss, ok := received[0].(StepStartEvent); !ok || ss.StepNumber != 0 {
|
||||
t.Errorf("event 0: expected StepStartEvent{StepNumber:0}, got %T %+v", received[0], received[0])
|
||||
}
|
||||
if ts, ok := received[1].(TextStartEvent); !ok || ts.ID != "text-1" {
|
||||
t.Errorf("event 1: expected TextStartEvent{ID:text-1}, got %T %+v", received[1], received[1])
|
||||
}
|
||||
if w, ok := received[4].(WarningsEvent); !ok || len(w.Warnings) != 1 || w.Warnings[0] != "low confidence" {
|
||||
t.Errorf("event 4: expected WarningsEvent with 1 warning, got %T %+v", received[4], received[4])
|
||||
}
|
||||
if sf, ok := received[7].(StepFinishEvent); !ok || sf.StepNumber != 0 || sf.HasToolCalls {
|
||||
t.Errorf("event 7: expected StepFinishEvent{StepNumber:0, HasToolCalls:false}, got %T %+v", received[7], received[7])
|
||||
}
|
||||
}
|
||||
|
||||
// TestEventBusListenerCanUnsubscribeInCallback verifies that a listener can
|
||||
// safely call its own unsubscribe function from within the callback.
|
||||
func TestEventBusListenerCanUnsubscribeInCallback(t *testing.T) {
|
||||
|
||||
@@ -356,4 +356,134 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// --- Step lifecycle observation events ---
|
||||
|
||||
if runner.HasHandlers(extensions.StepStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.StepFinish) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepFinishEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepFinishEvent{
|
||||
StepNumber: ev.StepNumber,
|
||||
HasToolCalls: ev.HasToolCalls,
|
||||
FinishReason: ev.FinishReason,
|
||||
InputTokens: ev.Usage.InputTokens,
|
||||
OutputTokens: ev.Usage.OutputTokens,
|
||||
CacheReadTokens: ev.Usage.CacheReadTokens,
|
||||
CacheWriteTokens: ev.Usage.CacheCreationTokens,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.ReasoningStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ReasoningStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.Warnings) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(WarningsEvent); ok {
|
||||
_, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.Source) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(SourceEvent); ok {
|
||||
_, _ = runner.Emit(extensions.SourceEvent{
|
||||
SourceType: ev.SourceType,
|
||||
ID: ev.ID,
|
||||
URL: ev.URL,
|
||||
Title: ev.Title,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.Error) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ErrorEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if runner.HasHandlers(extensions.Retry) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(RetryEvent); ok {
|
||||
_, _ = runner.Emit(extensions.RetryEvent{
|
||||
Attempt: ev.Attempt,
|
||||
Error: ev.Error.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// --- PrepareStep hook ---
|
||||
// Extension PrepareStep → SDK PrepareStep hook.
|
||||
// Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage.
|
||||
if runner.HasHandlers(extensions.PrepareStep) {
|
||||
m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
result, _ := runner.Emit(extensions.PrepareStepEvent{
|
||||
StepNumber: h.StepNumber,
|
||||
Messages: extMsgs,
|
||||
})
|
||||
r, ok := result.(extensions.PrepareStepResult)
|
||||
if !ok || r.Messages == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]LLMMessage, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
rebuilt = append(rebuilt, h.Messages[cm.Index])
|
||||
} else {
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &PrepareStepResult{Messages: rebuilt}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+48
-11
@@ -5,8 +5,6 @@ import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -121,6 +119,32 @@ type BeforeCompactResult struct {
|
||||
Summary string
|
||||
}
|
||||
|
||||
// PrepareStepHook is the input for hooks that fire between steps within a
|
||||
// multi-step agent turn, with full message replacement capability. This is
|
||||
// the most powerful interception point — it fires after the existing steering
|
||||
// logic (if any) and before the messages are sent to the LLM.
|
||||
//
|
||||
// Use cases:
|
||||
// - Transforming tool results (e.g. converting image tool results to FilePart
|
||||
// user messages for vision models that don't support media in tool results)
|
||||
// - Dynamic tool filtering per step
|
||||
// - Mid-turn context injection beyond simple steering
|
||||
// - Custom stop conditions that inspect message history
|
||||
type PrepareStepHook struct {
|
||||
// StepNumber is the zero-based step index within the current turn.
|
||||
StepNumber int
|
||||
// Messages is the current context window that will be sent to the LLM.
|
||||
// This includes any steering messages already injected in this step.
|
||||
Messages []LLMMessage
|
||||
}
|
||||
|
||||
// PrepareStepResult can replace the context window between steps.
|
||||
type PrepareStepResult struct {
|
||||
// Messages replaces the entire context window for this step. If nil,
|
||||
// the original messages (including any steering) are used unchanged.
|
||||
Messages []LLMMessage
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Generic hook registry with priority ordering
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -248,6 +272,19 @@ func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeC
|
||||
return m.beforeCompact.register(p, h)
|
||||
}
|
||||
|
||||
// OnPrepareStep registers a hook that fires between steps within a multi-step
|
||||
// agent turn, after steering messages are injected and before the messages are
|
||||
// sent to the LLM. Return a non-nil PrepareStepResult with Messages to replace
|
||||
// the entire context window for this step. Hooks execute in priority order;
|
||||
// the first non-nil result wins. Returns an unregister function.
|
||||
//
|
||||
// This is the most powerful interception point in the agent lifecycle. It
|
||||
// enables patterns like transforming tool results, dynamic tool filtering,
|
||||
// and mid-turn context injection.
|
||||
func (m *Kit) OnPrepareStep(p HookPriority, h func(PrepareStepHook) *PrepareStepResult) func() {
|
||||
return m.prepareStep.register(p, h)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tool wrapping via hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -256,16 +293,16 @@ func (m *Kit) OnBeforeCompact(p HookPriority, h func(BeforeCompactHook) *BeforeC
|
||||
// AfterToolResult hooks around each execution. The registries are referenced
|
||||
// by pointer so hooks added after agent creation are still invoked.
|
||||
type hookedTool struct {
|
||||
inner fantasy.AgentTool
|
||||
inner Tool
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
|
||||
}
|
||||
|
||||
func (h *hookedTool) Info() fantasy.ToolInfo { return h.inner.Info() }
|
||||
func (h *hookedTool) ProviderOptions() fantasy.ProviderOptions { return h.inner.ProviderOptions() }
|
||||
func (h *hookedTool) SetProviderOptions(o fantasy.ProviderOptions) { h.inner.SetProviderOptions(o) }
|
||||
func (h *hookedTool) Info() LLMToolInfo { return h.inner.Info() }
|
||||
func (h *hookedTool) ProviderOptions() LLMProviderOptions { return h.inner.ProviderOptions() }
|
||||
func (h *hookedTool) SetProviderOptions(o LLMProviderOptions) { h.inner.SetProviderOptions(o) }
|
||||
|
||||
func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
func (h *hookedTool) Run(ctx context.Context, call LLMToolCall) (LLMToolResponse, error) {
|
||||
toolName := h.inner.Info().Name
|
||||
|
||||
// 1. BeforeToolCall — can block execution.
|
||||
@@ -279,7 +316,7 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
|
||||
if reason == "" {
|
||||
reason = "blocked by hook"
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
return newLLMTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
fmt.Errorf("tool blocked by hook: %s", reason)
|
||||
}
|
||||
}
|
||||
@@ -314,9 +351,9 @@ func (h *hookedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.To
|
||||
func hookToolWrapper(
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult],
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult],
|
||||
) func([]fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
wrapped := make([]fantasy.AgentTool, len(tools))
|
||||
) func([]Tool) []Tool {
|
||||
return func(tools []Tool) []Tool {
|
||||
wrapped := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
wrapped[i] = &hookedTool{
|
||||
inner: tool,
|
||||
|
||||
+98
-28
@@ -5,8 +5,6 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -177,20 +175,20 @@ func TestHookRegistry_ConcurrentAccess(t *testing.T) {
|
||||
// mockAgentTool implements the AgentTool interface for testing.
|
||||
type mockAgentTool struct {
|
||||
name string
|
||||
runFn func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error)
|
||||
popts fantasy.ProviderOptions
|
||||
runFn func(ctx context.Context, call LLMToolCall) (LLMToolResponse, error)
|
||||
popts LLMProviderOptions
|
||||
}
|
||||
|
||||
func (m *mockAgentTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{Name: m.name, Description: "mock tool"}
|
||||
func (m *mockAgentTool) Info() LLMToolInfo {
|
||||
return LLMToolInfo{Name: m.name, Description: "mock tool"}
|
||||
}
|
||||
func (m *mockAgentTool) ProviderOptions() fantasy.ProviderOptions { return m.popts }
|
||||
func (m *mockAgentTool) SetProviderOptions(o fantasy.ProviderOptions) { m.popts = o }
|
||||
func (m *mockAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
func (m *mockAgentTool) ProviderOptions() LLMProviderOptions { return m.popts }
|
||||
func (m *mockAgentTool) SetProviderOptions(o LLMProviderOptions) { m.popts = o }
|
||||
func (m *mockAgentTool) Run(ctx context.Context, call LLMToolCall) (LLMToolResponse, error) {
|
||||
if m.runFn != nil {
|
||||
return m.runFn(ctx, call)
|
||||
}
|
||||
return fantasy.NewTextResponse("default output"), nil
|
||||
return newLLMTextResponse("default output"), nil
|
||||
}
|
||||
|
||||
// newEmptyHookedTool creates a hookedTool with empty hook registries and the given mock tool.
|
||||
@@ -203,14 +201,14 @@ func newEmptyHookedTool(mock *mockAgentTool) *hookedTool {
|
||||
func TestHookedTool_Passthrough(t *testing.T) {
|
||||
mock := &mockAgentTool{
|
||||
name: "test_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("hello world"), nil
|
||||
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
|
||||
return newLLMTextResponse("hello world"), nil
|
||||
},
|
||||
}
|
||||
|
||||
ht := newEmptyHookedTool(mock)
|
||||
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -226,9 +224,9 @@ func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
|
||||
toolRan := false
|
||||
mock := &mockAgentTool{
|
||||
name: "dangerous_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
|
||||
toolRan = true
|
||||
return fantasy.NewTextResponse("should not run"), nil
|
||||
return newLLMTextResponse("should not run"), nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -241,7 +239,7 @@ func TestHookedTool_BeforeToolCallBlock(t *testing.T) {
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from blocked tool")
|
||||
}
|
||||
@@ -263,7 +261,7 @@ func TestHookedTool_BeforeToolCallBlockDefaultReason(t *testing.T) {
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, _ := ht.Run(context.Background(), fantasy.ToolCall{})
|
||||
resp, _ := ht.Run(context.Background(), LLMToolCall{})
|
||||
if resp.Content != "Error: blocked by hook" {
|
||||
t.Errorf("expected default block reason, got %q", resp.Content)
|
||||
}
|
||||
@@ -275,8 +273,8 @@ func TestHookedTool_AfterToolResultModify(t *testing.T) {
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("secret data"), nil
|
||||
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
|
||||
return newLLMTextResponse("secret data"), nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -286,7 +284,7 @@ func TestHookedTool_AfterToolResultModify(t *testing.T) {
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{Input: "{}"})
|
||||
resp, err := ht.Run(context.Background(), LLMToolCall{Input: "{}"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -301,8 +299,8 @@ func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("ok"), nil
|
||||
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
|
||||
return newLLMTextResponse("ok"), nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -312,7 +310,7 @@ func TestHookedTool_AfterToolResultModifyIsError(t *testing.T) {
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
resp, err := ht.Run(context.Background(), fantasy.ToolCall{})
|
||||
resp, err := ht.Run(context.Background(), LLMToolCall{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -327,8 +325,8 @@ func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
|
||||
|
||||
mock := &mockAgentTool{
|
||||
name: "my_tool",
|
||||
runFn: func(_ context.Context, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return fantasy.NewTextResponse("result"), nil
|
||||
runFn: func(_ context.Context, _ LLMToolCall) (LLMToolResponse, error) {
|
||||
return newLLMTextResponse("result"), nil
|
||||
},
|
||||
}
|
||||
|
||||
@@ -345,7 +343,7 @@ func TestHookedTool_HookReceivesToolInfo(t *testing.T) {
|
||||
})
|
||||
|
||||
ht := &hookedTool{inner: mock, beforeToolCall: before, afterToolResult: after}
|
||||
_, _ = ht.Run(context.Background(), fantasy.ToolCall{Input: `{"key":"value"}`})
|
||||
_, _ = ht.Run(context.Background(), LLMToolCall{Input: `{"key":"value"}`})
|
||||
|
||||
if capturedBefore.ToolName != "my_tool" {
|
||||
t.Errorf("BeforeToolCall: expected tool name 'my_tool', got %q", capturedBefore.ToolName)
|
||||
@@ -380,7 +378,7 @@ func TestHookToolWrapper(t *testing.T) {
|
||||
|
||||
wrapper := hookToolWrapper(before, after)
|
||||
|
||||
tools := []fantasy.AgentTool{
|
||||
tools := []Tool{
|
||||
&mockAgentTool{name: "tool_a"},
|
||||
&mockAgentTool{name: "tool_b"},
|
||||
}
|
||||
@@ -407,7 +405,7 @@ func TestHookToolWrapper(t *testing.T) {
|
||||
return &BeforeToolCallResult{Block: true, Reason: "late hook"}
|
||||
})
|
||||
|
||||
_, err := wrapped[0].Run(context.Background(), fantasy.ToolCall{})
|
||||
_, err := wrapped[0].Run(context.Background(), LLMToolCall{})
|
||||
if err == nil {
|
||||
t.Error("expected error from late-registered blocking hook")
|
||||
}
|
||||
@@ -538,3 +536,75 @@ func TestKit_HookMethodsExist(t *testing.T) {
|
||||
u3()
|
||||
u4()
|
||||
}
|
||||
|
||||
// TestPrepareStepHookRegistry verifies registration and execution of PrepareStep hooks.
|
||||
func TestPrepareStepHookRegistry(t *testing.T) {
|
||||
hr := newHookRegistry[PrepareStepHook, PrepareStepResult]()
|
||||
|
||||
// Register a hook that appends a message.
|
||||
hr.register(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
|
||||
if h.StepNumber == 0 {
|
||||
// On step 0, prepend a system message.
|
||||
newMsgs := make([]LLMMessage, 0, len(h.Messages)+1)
|
||||
newMsgs = append(newMsgs, NewLLMSystemMessage("injected"))
|
||||
newMsgs = append(newMsgs, h.Messages...)
|
||||
return &PrepareStepResult{Messages: newMsgs}
|
||||
}
|
||||
return nil // No modification for other steps.
|
||||
})
|
||||
|
||||
// Test step 0 — should modify messages.
|
||||
input := PrepareStepHook{
|
||||
StepNumber: 0,
|
||||
Messages: []LLMMessage{NewLLMUserMessage("hello")},
|
||||
}
|
||||
result := hr.run(input)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result for step 0")
|
||||
}
|
||||
if len(result.Messages) != 2 {
|
||||
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Role != LLMRoleSystem {
|
||||
t.Errorf("expected system message first, got role %q", result.Messages[0].Role)
|
||||
}
|
||||
|
||||
// Test step 1 — should return nil (no modification).
|
||||
input.StepNumber = 1
|
||||
result = hr.run(input)
|
||||
if result != nil {
|
||||
t.Errorf("expected nil result for step 1, got %+v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareStepHookPriority verifies that PrepareStep hooks respect priority ordering.
|
||||
func TestPrepareStepHookPriority(t *testing.T) {
|
||||
hr := newHookRegistry[PrepareStepHook, PrepareStepResult]()
|
||||
|
||||
var order []string
|
||||
|
||||
// Low priority — should run second.
|
||||
hr.register(HookPriorityLow, func(_ PrepareStepHook) *PrepareStepResult {
|
||||
order = append(order, "low")
|
||||
return nil
|
||||
})
|
||||
|
||||
// High priority — should run first and win.
|
||||
hr.register(HookPriorityHigh, func(h PrepareStepHook) *PrepareStepResult {
|
||||
order = append(order, "high")
|
||||
return &PrepareStepResult{Messages: h.Messages}
|
||||
})
|
||||
|
||||
input := PrepareStepHook{
|
||||
StepNumber: 0,
|
||||
Messages: []LLMMessage{NewLLMUserMessage("test")},
|
||||
}
|
||||
result := hr.run(input)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if len(order) != 1 || order[0] != "high" {
|
||||
t.Errorf("expected [high] (first non-nil wins), got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
+142
-28
@@ -66,6 +66,7 @@ type Kit struct {
|
||||
afterTurn *hookRegistry[AfterTurnHook, AfterTurnResult]
|
||||
contextPrepare *hookRegistry[ContextPrepareHook, ContextPrepareResult]
|
||||
beforeCompact *hookRegistry[BeforeCompactHook, BeforeCompactResult]
|
||||
prepareStep *hookRegistry[PrepareStepHook, PrepareStepResult]
|
||||
|
||||
// lastInputTokens stores the API-reported input token count from the
|
||||
// most recent turn. Used by GetContextStats() to return accurate usage
|
||||
@@ -731,7 +732,7 @@ func (m *Kit) ExecuteCompletion(ctx context.Context, req extensions.CompleteRequ
|
||||
llmModel fantasy.LanguageModel
|
||||
closer func()
|
||||
usedModel string
|
||||
providerOps fantasy.ProviderOptions
|
||||
providerOps LLMProviderOptions
|
||||
)
|
||||
|
||||
if req.Model == "" {
|
||||
@@ -1034,6 +1035,41 @@ type Options struct {
|
||||
// real-time progress in the TUI.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// MCPTaskMode overrides the per-server [MCPTaskMode] for task-augmented
|
||||
// tools/call execution. Keys are MCP server names. Servers not present
|
||||
// in the map fall back to the TasksMode field of MCPServerConfig (or
|
||||
// MCPTaskModeAuto when that is empty). See the MCP Tasks spec for the
|
||||
// underlying semantics:
|
||||
// https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks
|
||||
MCPTaskMode map[string]MCPTaskMode
|
||||
|
||||
// MCPTaskTimeout is the maximum wall-clock duration to wait for a
|
||||
// task-augmented tool call to reach a terminal state. Independent of
|
||||
// any per-call context deadline; whichever fires first wins. Zero
|
||||
// means use the default (15 minutes).
|
||||
MCPTaskTimeout time.Duration
|
||||
|
||||
// MCPTaskTTL is the TTL hint sent in TaskParams for every
|
||||
// task-augmented tools/call. Zero omits the TTL and lets the server
|
||||
// pick its own retention policy.
|
||||
MCPTaskTTL time.Duration
|
||||
|
||||
// MCPTaskPollInterval is the fallback interval between tasks/get
|
||||
// requests when the server does not suggest one. Zero means use the
|
||||
// default (1 second).
|
||||
MCPTaskPollInterval time.Duration
|
||||
|
||||
// MCPTaskMaxPollInterval caps the polling interval (a server-supplied
|
||||
// pollInterval can otherwise grow without bound). Zero means use the
|
||||
// default (5 seconds).
|
||||
MCPTaskMaxPollInterval time.Duration
|
||||
|
||||
// MCPTaskProgress, if non-nil, is invoked once when a task is accepted
|
||||
// and on every status transition observed by the polling loop. The
|
||||
// final invocation always carries a terminal status. Implementations
|
||||
// must not block; long work should run on a goroutine.
|
||||
MCPTaskProgress MCPTaskProgressHandler
|
||||
|
||||
// CLI is optional CLI-specific configuration. SDK users leave this nil.
|
||||
CLI *CLIOptions
|
||||
|
||||
@@ -1368,6 +1404,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
afterTurn := newHookRegistry[AfterTurnHook, AfterTurnResult]()
|
||||
contextPrepare := newHookRegistry[ContextPrepareHook, ContextPrepareResult]()
|
||||
beforeCompact := newHookRegistry[BeforeCompactHook, BeforeCompactResult]()
|
||||
prepareStep := newHookRegistry[PrepareStepHook, PrepareStepResult]()
|
||||
|
||||
// Build agent setup options, pulling CLI-specific fields when available.
|
||||
// Pass the pre-built ProviderConfig and scalar viper snapshots so
|
||||
@@ -1385,6 +1422,14 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streaming,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: mcpTaskOptions{
|
||||
perServer: opts.MCPTaskMode,
|
||||
defaultTTL: opts.MCPTaskTTL,
|
||||
pollInterval: opts.MCPTaskPollInterval,
|
||||
maxPollInterval: opts.MCPTaskMaxPollInterval,
|
||||
timeout: opts.MCPTaskTimeout,
|
||||
progress: opts.MCPTaskProgress,
|
||||
}.toToolsConfig(),
|
||||
}
|
||||
|
||||
// Set up OAuth handler for remote MCP servers. The SDK does not create
|
||||
@@ -1461,6 +1506,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
afterTurn: afterTurn,
|
||||
contextPrepare: contextPrepare,
|
||||
beforeCompact: beforeCompact,
|
||||
prepareStep: prepareStep,
|
||||
}
|
||||
|
||||
// Bridge extension events to SDK hooks.
|
||||
@@ -1781,14 +1827,28 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
|
||||
// Create child Kit instance. Pass the parent's loaded MCP config to
|
||||
// avoid re-reading viper (which races with concurrent subagent spawns).
|
||||
// Streaming must be explicitly enabled — Options.Streaming defaults to
|
||||
// false, and New() unconditionally writes viper.Set("stream", opts.Streaming).
|
||||
// Without this, the subagent would (a) pollute viper global state for
|
||||
// other concurrent callers and (b) potentially hit provider-level
|
||||
// differences (e.g. Anthropic non-streaming timeouts with extended
|
||||
// thinking).
|
||||
childOpts := &Options{
|
||||
Model: model,
|
||||
SystemPrompt: systemPrompt,
|
||||
Tools: tools,
|
||||
NoSession: cfg.NoSession,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
MCPConfig: m.mcpConfig,
|
||||
}
|
||||
// Propagate the parent's MCP task configuration so a child subagent
|
||||
// invoking long-running MCP tools observes the same per-server modes,
|
||||
// timeouts, and progress callback as the parent. Without this, child
|
||||
// agents would silently fall back to MCPTaskModeAuto with default
|
||||
// polling and no progress feedback even when the parent had configured
|
||||
// custom values.
|
||||
inheritMCPTaskOptions(childOpts, m.opts)
|
||||
child, err := New(ctx, childOpts)
|
||||
if err != nil {
|
||||
return &SubagentResult{Elapsed: time.Since(start)}, fmt.Errorf("failed to create subagent: %w", err)
|
||||
@@ -1894,21 +1954,21 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
return sr, err
|
||||
})
|
||||
|
||||
return m.agent.GenerateWithLoopAndStreaming(ctx, messages,
|
||||
func(toolCallID, toolName, toolArgs string) {
|
||||
return m.agent.GenerateWithCallbacks(ctx, messages, agent.GenerateCallbacks{
|
||||
OnToolCall: func(toolCallID, toolName, toolArgs string) {
|
||||
m.events.emit(ToolCallEvent{
|
||||
ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName),
|
||||
ToolArgs: toolArgs, ParsedArgs: parseToolArgs(toolArgs),
|
||||
})
|
||||
},
|
||||
func(toolCallID, toolName, toolArgs string, isStarting bool) {
|
||||
OnToolExecution: func(toolCallID, toolName, toolArgs string, isStarting bool) {
|
||||
if isStarting {
|
||||
m.events.emit(ToolExecutionStartEvent{ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName), ToolArgs: toolArgs})
|
||||
} else {
|
||||
m.events.emit(ToolExecutionEndEvent{ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName)})
|
||||
}
|
||||
},
|
||||
func(toolCallID, toolName, toolArgs, resultText, metadata string, isError bool) {
|
||||
OnToolResult: func(toolCallID, toolName, toolArgs, resultText, metadata string, isError bool) {
|
||||
evt := ToolResultEvent{
|
||||
ToolCallID: toolCallID, ToolName: toolName, ToolKind: toolKindFor(toolName),
|
||||
ToolArgs: toolArgs, ParsedArgs: parseToolArgs(toolArgs),
|
||||
@@ -1922,17 +1982,17 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
}
|
||||
m.events.emit(evt)
|
||||
},
|
||||
func(content string) {
|
||||
OnResponse: func(content string) {
|
||||
m.events.emit(ResponseEvent{Content: content})
|
||||
},
|
||||
func(content string) {
|
||||
OnToolCallContent: func(content string) {
|
||||
m.events.emit(ToolCallContentEvent{Content: content})
|
||||
},
|
||||
// <think> tag filtering: models like Qwen/DeepSeek wrap reasoning inside
|
||||
// <think>...</think> tags in the regular text stream. We intercept those
|
||||
// spans here and re-route them as ReasoningDeltaEvent/ReasoningCompleteEvent
|
||||
// so callers always receive clean, tag-free text and structured reasoning.
|
||||
func() func(chunk string) {
|
||||
OnStreamingResponse: func() func(chunk string) {
|
||||
const (
|
||||
thinkOpen = "<think>"
|
||||
thinkClose = "</think>"
|
||||
@@ -1968,14 +2028,13 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
}
|
||||
}
|
||||
}(),
|
||||
func(delta string) {
|
||||
OnReasoningDelta: func(delta string) {
|
||||
m.events.emit(ReasoningDeltaEvent{Delta: delta})
|
||||
},
|
||||
func() {
|
||||
OnReasoningComplete: func() {
|
||||
m.events.emit(ReasoningCompleteEvent{})
|
||||
},
|
||||
func(toolCallID, toolName, chunk string, isStderr bool) {
|
||||
// Emit tool output chunk event for streaming bash output
|
||||
OnToolOutput: func(toolCallID, toolName, chunk string, isStderr bool) {
|
||||
m.events.emit(ToolOutputEvent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
@@ -1984,18 +2043,13 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
})
|
||||
},
|
||||
// Persist step messages incrementally so that progress survives
|
||||
// crashes and long-running turns don't lose work. Each step's
|
||||
// messages are persisted as a unit: for tool-calling steps this is
|
||||
// the assistant message (with tool_use parts) + tool-role message
|
||||
// (with tool_result parts) as a pair; for the final step it's the
|
||||
// assistant text/reasoning message alone.
|
||||
func(stepMessages []fantasy.Message) {
|
||||
// crashes and long-running turns don't lose work.
|
||||
OnStepMessages: func(stepMessages []fantasy.Message) {
|
||||
for _, msg := range stepMessages {
|
||||
_, _ = m.session.AppendMessage(msg)
|
||||
}
|
||||
},
|
||||
func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
// Emit step usage event for real-time cost tracking
|
||||
OnStepUsage: func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64) {
|
||||
if viper.GetBool("debug") {
|
||||
log.Printf("DEBUG Kit.generate emitting StepUsageEvent: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens,
|
||||
@@ -2009,37 +2063,97 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
})
|
||||
},
|
||||
// Password prompt handler for sudo commands
|
||||
func(prompt string) (string, bool) {
|
||||
// Emit event to TUI and wait for response via channel
|
||||
OnPasswordPrompt: func(prompt string) (string, bool) {
|
||||
responseCh := make(chan PasswordPromptResponse, 1)
|
||||
m.events.emit(PasswordPromptEvent{
|
||||
Prompt: prompt,
|
||||
ResponseCh: responseCh,
|
||||
})
|
||||
// Wait for response (TUI will send password or cancel)
|
||||
resp := <-responseCh
|
||||
return resp.Password, resp.Cancelled
|
||||
},
|
||||
// Tool call argument streaming — fire as the LLM generates tool arguments
|
||||
func(toolCallID, toolName string) {
|
||||
// Tool call argument streaming
|
||||
OnToolCallStart: func(toolCallID, toolName string) {
|
||||
m.events.emit(ToolCallStartEvent{
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
ToolKind: toolKindFor(toolName),
|
||||
})
|
||||
},
|
||||
func(toolCallID, delta string) {
|
||||
OnToolCallDelta: func(toolCallID, delta string) {
|
||||
m.events.emit(ToolCallDeltaEvent{
|
||||
ToolCallID: toolCallID,
|
||||
Delta: delta,
|
||||
})
|
||||
},
|
||||
func(toolCallID string) {
|
||||
OnToolCallEnd: func(toolCallID string) {
|
||||
m.events.emit(ToolCallEndEvent{
|
||||
ToolCallID: toolCallID,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
// New callbacks for previously unwired Fantasy lifecycle events.
|
||||
OnStepStart: func(stepNumber int) {
|
||||
m.events.emit(StepStartEvent{StepNumber: stepNumber})
|
||||
},
|
||||
OnStepFinish: func(stepNumber int, hasToolCalls bool, finishReason string, usage fantasy.Usage) {
|
||||
m.events.emit(StepFinishEvent{
|
||||
StepNumber: stepNumber,
|
||||
HasToolCalls: hasToolCalls,
|
||||
FinishReason: finishReason,
|
||||
Usage: usage,
|
||||
})
|
||||
},
|
||||
OnTextStart: func(id string) {
|
||||
m.events.emit(TextStartEvent{ID: id})
|
||||
},
|
||||
OnTextEnd: func(id string) {
|
||||
m.events.emit(TextEndEvent{ID: id})
|
||||
},
|
||||
OnReasoningStart: func(id string) {
|
||||
m.events.emit(ReasoningStartEvent{ID: id})
|
||||
},
|
||||
OnWarnings: func(warnings []string) {
|
||||
m.events.emit(WarningsEvent{Warnings: warnings})
|
||||
},
|
||||
OnSource: func(sourceType, id, url, title string) {
|
||||
m.events.emit(SourceEvent{
|
||||
SourceType: sourceType,
|
||||
ID: id,
|
||||
URL: url,
|
||||
Title: title,
|
||||
})
|
||||
},
|
||||
OnStreamFinish: func(usage fantasy.Usage, finishReason string) {
|
||||
m.events.emit(StreamFinishEvent{
|
||||
Usage: usage,
|
||||
FinishReason: finishReason,
|
||||
})
|
||||
},
|
||||
OnError: func(err error) {
|
||||
m.events.emit(ErrorEvent{Error: err})
|
||||
},
|
||||
OnRetry: func(attempt int, err error) {
|
||||
m.events.emit(RetryEvent{Attempt: attempt, Error: err})
|
||||
},
|
||||
// PrepareStep hook — compose with steering (handled in agent layer)
|
||||
// and then run SDK consumer hooks.
|
||||
OnPrepareStep: func() agent.PrepareStepHandler {
|
||||
if !m.prepareStep.hasHooks() {
|
||||
return nil
|
||||
}
|
||||
return func(stepNumber int, messages []fantasy.Message) []fantasy.Message {
|
||||
hookResult := m.prepareStep.run(PrepareStepHook{
|
||||
StepNumber: stepNumber,
|
||||
Messages: messages,
|
||||
})
|
||||
if hookResult != nil && hookResult.Messages != nil {
|
||||
return hookResult.Messages
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}(),
|
||||
})
|
||||
}
|
||||
|
||||
// runTurn is the shared lifecycle for every prompt mode:
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// MCPTaskStatus represents the lifecycle state of a task-augmented MCP
|
||||
// tool call. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks
|
||||
// for the underlying spec.
|
||||
type MCPTaskStatus string
|
||||
|
||||
const (
|
||||
// MCPTaskStatusWorking indicates the task is currently being processed.
|
||||
MCPTaskStatusWorking MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusWorking)
|
||||
// MCPTaskStatusInputRequired indicates the server is waiting for client
|
||||
// input before it can proceed (rare; typically surfaced via elicitation).
|
||||
MCPTaskStatusInputRequired MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusInputRequired)
|
||||
// MCPTaskStatusCompleted indicates the task finished successfully.
|
||||
MCPTaskStatusCompleted MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCompleted)
|
||||
// MCPTaskStatusFailed indicates the task ended in error.
|
||||
MCPTaskStatusFailed MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusFailed)
|
||||
// MCPTaskStatusCancelled indicates the task was cancelled before completion.
|
||||
MCPTaskStatusCancelled MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCancelled)
|
||||
)
|
||||
|
||||
// IsTerminal reports whether the status represents a final state — that is,
|
||||
// the task will not change again. Terminal states are completed, failed,
|
||||
// and cancelled.
|
||||
func (s MCPTaskStatus) IsTerminal() bool {
|
||||
return mcp.TaskStatus(s).IsTerminal()
|
||||
}
|
||||
|
||||
// MCPTaskMode controls when Kit augments tools/call requests with MCP task
|
||||
// metadata for a specific server.
|
||||
type MCPTaskMode string
|
||||
|
||||
const (
|
||||
// MCPTaskModeAuto augments tools/call with task metadata only when the
|
||||
// server advertises tasks/toolCalls capability during initialize.
|
||||
// This is the default and is safe to leave unconfigured for any
|
||||
// existing MCP server.
|
||||
MCPTaskModeAuto MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAuto)
|
||||
// MCPTaskModeNever forces every tools/call to be issued synchronously
|
||||
// (no Task field), regardless of server capability.
|
||||
MCPTaskModeNever MCPTaskMode = MCPTaskMode(tools.MCPTaskModeNever)
|
||||
// MCPTaskModeAlways always opts into task augmentation, even when the
|
||||
// server didn't advertise the capability. The server may still respond
|
||||
// synchronously; this just expresses client intent unconditionally.
|
||||
MCPTaskModeAlways MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAlways)
|
||||
)
|
||||
|
||||
// MCPTask is the SDK-level view of an MCP Task. Timestamps are best-effort
|
||||
// parsed from the server's ISO-8601 strings; they may be the zero time when
|
||||
// the server omitted them or used a non-RFC3339 format.
|
||||
type MCPTask struct {
|
||||
// Server is the configured MCP server name this task lives on.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the task.
|
||||
TaskID string
|
||||
// Status is the current task lifecycle state.
|
||||
Status MCPTaskStatus
|
||||
// StatusMessage is an optional human-readable description provided by
|
||||
// the server.
|
||||
StatusMessage string
|
||||
// CreatedAt is when the task was created on the server.
|
||||
CreatedAt time.Time
|
||||
// UpdatedAt is when the task was last updated on the server.
|
||||
UpdatedAt time.Time
|
||||
// TTL is how long the server intends to retain this task after creation.
|
||||
// Zero means the server did not advertise a TTL.
|
||||
TTL time.Duration
|
||||
// PollInterval is the suggested time between status checks. Zero means
|
||||
// the client should use its own default.
|
||||
PollInterval time.Duration
|
||||
}
|
||||
|
||||
// MCPTaskProgress is a single status update emitted while Kit is waiting
|
||||
// on a task-augmented tool call.
|
||||
type MCPTaskProgress struct {
|
||||
// Server is the configured MCP server name.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the in-flight task.
|
||||
TaskID string
|
||||
// Status is the most recent task status observed.
|
||||
Status MCPTaskStatus
|
||||
// Message is the optional human-readable status message from the server.
|
||||
Message string
|
||||
}
|
||||
|
||||
// MCPTaskProgressHandler is called once when a task is accepted and again
|
||||
// on every observed status transition. The final invocation always carries
|
||||
// a terminal status. Implementations must not block; long work should be
|
||||
// dispatched on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
|
||||
// Stored on Options as a single value so the public surface stays compact;
|
||||
// individual fields are exposed via WithMCP* builder functions.
|
||||
type mcpTaskOptions struct {
|
||||
perServer map[string]MCPTaskMode
|
||||
defaultTTL time.Duration
|
||||
pollInterval time.Duration
|
||||
maxPollInterval time.Duration
|
||||
timeout time.Duration
|
||||
progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
// toToolsConfig converts the SDK-level config to the internal tools-package
|
||||
// representation. Keeps the dependency arrow internal-only.
|
||||
func (o mcpTaskOptions) toToolsConfig() tools.MCPTaskConfig {
|
||||
cfg := tools.MCPTaskConfig{
|
||||
DefaultTTL: o.defaultTTL,
|
||||
PollInterval: o.pollInterval,
|
||||
MaxPollInterval: o.maxPollInterval,
|
||||
Timeout: o.timeout,
|
||||
}
|
||||
if len(o.perServer) > 0 {
|
||||
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(o.perServer))
|
||||
for k, v := range o.perServer {
|
||||
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
|
||||
}
|
||||
}
|
||||
if o.progress != nil {
|
||||
h := o.progress
|
||||
cfg.Progress = func(p tools.MCPTaskProgress) {
|
||||
h(MCPTaskProgress{
|
||||
Server: p.Server,
|
||||
TaskID: p.TaskID,
|
||||
Status: MCPTaskStatus(p.Status),
|
||||
Message: p.Message,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ListMCPTasks queries tasks/list on the named MCP server and returns the
|
||||
// active and recent tasks the server is willing to disclose. Returns an
|
||||
// error when the server isn't loaded, doesn't expose tasks/list, or the
|
||||
// underlying transport fails.
|
||||
func (m *Kit) ListMCPTasks(ctx context.Context, serverName string) ([]MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
infos, err := mgr.ListServerTasks(ctx, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]MCPTask, len(infos))
|
||||
for i, t := range infos {
|
||||
out[i] = mcpTaskFromInternal(t)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetMCPTask queries tasks/get for a single in-flight task on the named
|
||||
// server. The returned MCPTask reflects the server's current view of the
|
||||
// task.
|
||||
func (m *Kit) GetMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
info, err := mgr.GetServerTask(ctx, serverName, taskID)
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
return mcpTaskFromInternal(info), nil
|
||||
}
|
||||
|
||||
// CancelMCPTask issues tasks/cancel for an in-flight task on the named
|
||||
// server. Returns the post-cancel task state when the server responded
|
||||
// with one. Cancelling an already-terminal task is a no-op on most
|
||||
// servers.
|
||||
func (m *Kit) CancelMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
info, err := mgr.CancelServerTask(ctx, serverName, taskID)
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
return mcpTaskFromInternal(info), nil
|
||||
}
|
||||
|
||||
// mcpToolManager returns the underlying MCP tool manager or an error when
|
||||
// no MCP servers are configured.
|
||||
func (m *Kit) mcpToolManager() (*tools.MCPToolManager, error) {
|
||||
if m == nil || m.agent == nil {
|
||||
return nil, fmt.Errorf("kit instance has no agent")
|
||||
}
|
||||
mgr := m.agent.GetMCPToolManager()
|
||||
if mgr == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// mcpTaskFromInternal adapts the internal tools.MCPTaskInfo to the
|
||||
// SDK-level MCPTask type. Keeps the public surface independent of
|
||||
// internal package types.
|
||||
func mcpTaskFromInternal(t tools.MCPTaskInfo) MCPTask {
|
||||
return MCPTask{
|
||||
Server: t.Server,
|
||||
TaskID: t.TaskID,
|
||||
Status: MCPTaskStatus(t.Status),
|
||||
StatusMessage: t.StatusMessage,
|
||||
CreatedAt: t.CreatedAt,
|
||||
UpdatedAt: t.UpdatedAt,
|
||||
TTL: t.TTL,
|
||||
PollInterval: t.PollInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// inheritMCPTaskOptions copies every MCP task-related field from parent
|
||||
// onto child. Used by Kit.Subagent so child instances observe the same
|
||||
// per-server modes, timeouts, and progress callback as their parent.
|
||||
// A nil parent is a no-op so callers don't have to guard at the call site.
|
||||
func inheritMCPTaskOptions(child, parent *Options) {
|
||||
if child == nil || parent == nil {
|
||||
return
|
||||
}
|
||||
child.MCPTaskMode = parent.MCPTaskMode
|
||||
child.MCPTaskTimeout = parent.MCPTaskTimeout
|
||||
child.MCPTaskTTL = parent.MCPTaskTTL
|
||||
child.MCPTaskPollInterval = parent.MCPTaskPollInterval
|
||||
child.MCPTaskMaxPollInterval = parent.MCPTaskMaxPollInterval
|
||||
child.MCPTaskProgress = parent.MCPTaskProgress
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
func TestMCPTaskStatusIsTerminal(t *testing.T) {
|
||||
cases := []struct {
|
||||
s MCPTaskStatus
|
||||
want bool
|
||||
}{
|
||||
{MCPTaskStatusWorking, false},
|
||||
{MCPTaskStatusInputRequired, false},
|
||||
{MCPTaskStatusCompleted, true},
|
||||
{MCPTaskStatusFailed, true},
|
||||
{MCPTaskStatusCancelled, true},
|
||||
{MCPTaskStatus("unknown"), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := tc.s.IsTerminal(); got != tc.want {
|
||||
t.Errorf("MCPTaskStatus(%q).IsTerminal() = %v, want %v", tc.s, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTaskOptionsToToolsConfig(t *testing.T) {
|
||||
called := 0
|
||||
o := mcpTaskOptions{
|
||||
perServer: map[string]MCPTaskMode{
|
||||
"alpha": MCPTaskModeAlways,
|
||||
"beta": MCPTaskModeNever,
|
||||
},
|
||||
defaultTTL: 30 * time.Second,
|
||||
pollInterval: 250 * time.Millisecond,
|
||||
maxPollInterval: 2 * time.Second,
|
||||
timeout: 5 * time.Minute,
|
||||
progress: func(p MCPTaskProgress) { called++ },
|
||||
}
|
||||
cfg := o.toToolsConfig()
|
||||
|
||||
if cfg.DefaultTTL != 30*time.Second {
|
||||
t.Errorf("DefaultTTL = %v, want 30s", cfg.DefaultTTL)
|
||||
}
|
||||
if cfg.PollInterval != 250*time.Millisecond {
|
||||
t.Errorf("PollInterval = %v, want 250ms", cfg.PollInterval)
|
||||
}
|
||||
if cfg.MaxPollInterval != 2*time.Second {
|
||||
t.Errorf("MaxPollInterval = %v, want 2s", cfg.MaxPollInterval)
|
||||
}
|
||||
if cfg.Timeout != 5*time.Minute {
|
||||
t.Errorf("Timeout = %v, want 5m", cfg.Timeout)
|
||||
}
|
||||
if cfg.PerServerMode["alpha"] != tools.MCPTaskModeAlways {
|
||||
t.Errorf("PerServerMode[alpha] = %q, want always", cfg.PerServerMode["alpha"])
|
||||
}
|
||||
if cfg.PerServerMode["beta"] != tools.MCPTaskModeNever {
|
||||
t.Errorf("PerServerMode[beta] = %q, want never", cfg.PerServerMode["beta"])
|
||||
}
|
||||
|
||||
// Progress conversion: invoking the internal handler must call our
|
||||
// SDK-level callback with the converted struct.
|
||||
if cfg.Progress == nil {
|
||||
t.Fatal("Progress callback was lost in conversion")
|
||||
}
|
||||
cfg.Progress(tools.MCPTaskProgress{
|
||||
Server: "alpha",
|
||||
TaskID: "t1",
|
||||
Status: "working",
|
||||
})
|
||||
if called != 1 {
|
||||
t.Errorf("expected SDK progress handler to be invoked once, got %d", called)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTaskFromInternal(t *testing.T) {
|
||||
in := tools.MCPTaskInfo{
|
||||
Server: "srv",
|
||||
TaskID: "t-1",
|
||||
Status: "working",
|
||||
StatusMessage: "phase 1",
|
||||
CreatedAt: time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2026, 5, 4, 12, 0, 1, 0, time.UTC),
|
||||
TTL: 5 * time.Minute,
|
||||
PollInterval: 500 * time.Millisecond,
|
||||
}
|
||||
out := mcpTaskFromInternal(in)
|
||||
|
||||
if out.Server != "srv" || out.TaskID != "t-1" {
|
||||
t.Errorf("identity fields not copied: %+v", out)
|
||||
}
|
||||
if out.Status != MCPTaskStatusWorking {
|
||||
t.Errorf("Status = %q, want working", out.Status)
|
||||
}
|
||||
if out.StatusMessage != "phase 1" {
|
||||
t.Errorf("StatusMessage = %q, want phase 1", out.StatusMessage)
|
||||
}
|
||||
if out.TTL != 5*time.Minute || out.PollInterval != 500*time.Millisecond {
|
||||
t.Errorf("durations not copied: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKitMCPTasksWithoutAgentReturnsError(t *testing.T) {
|
||||
// A nil/zero Kit must not panic — task RPCs should surface a clear
|
||||
// error instead. Useful for SDK consumers that try task ops on a Kit
|
||||
// constructed without MCP servers.
|
||||
var k *Kit
|
||||
ctx := t.Context()
|
||||
if _, err := k.ListMCPTasks(ctx, "any"); err == nil {
|
||||
t.Error("ListMCPTasks on nil Kit should error")
|
||||
}
|
||||
if _, err := k.GetMCPTask(ctx, "any", "id"); err == nil {
|
||||
t.Error("GetMCPTask on nil Kit should error")
|
||||
}
|
||||
if _, err := k.CancelMCPTask(ctx, "any", "id"); err == nil {
|
||||
t.Error("CancelMCPTask on nil Kit should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentPropagatesMCPTaskOptions(t *testing.T) {
|
||||
// Exercises the helper Kit.Subagent uses to copy MCP task options
|
||||
// onto child Options. Calling the real helper (rather than
|
||||
// duplicating its body in the test) means any new field added to
|
||||
// the propagation list is picked up automatically by the
|
||||
// equivalence assertion below.
|
||||
parent := &Options{
|
||||
MCPTaskMode: map[string]MCPTaskMode{
|
||||
"build": MCPTaskModeAlways,
|
||||
"chat": MCPTaskModeNever,
|
||||
},
|
||||
MCPTaskTimeout: 30 * time.Minute,
|
||||
MCPTaskTTL: 45 * time.Minute,
|
||||
MCPTaskPollInterval: 750 * time.Millisecond,
|
||||
MCPTaskMaxPollInterval: 4 * time.Second,
|
||||
MCPTaskProgress: func(MCPTaskProgress) {},
|
||||
}
|
||||
|
||||
child := &Options{}
|
||||
inheritMCPTaskOptions(child, parent)
|
||||
|
||||
if child.MCPTaskMode["build"] != MCPTaskModeAlways || child.MCPTaskMode["chat"] != MCPTaskModeNever {
|
||||
t.Errorf("MCPTaskMode not propagated: got %+v", child.MCPTaskMode)
|
||||
}
|
||||
if child.MCPTaskTimeout != 30*time.Minute {
|
||||
t.Errorf("MCPTaskTimeout = %v, want 30m", child.MCPTaskTimeout)
|
||||
}
|
||||
if child.MCPTaskTTL != 45*time.Minute {
|
||||
t.Errorf("MCPTaskTTL = %v, want 45m", child.MCPTaskTTL)
|
||||
}
|
||||
if child.MCPTaskPollInterval != 750*time.Millisecond {
|
||||
t.Errorf("MCPTaskPollInterval = %v, want 750ms", child.MCPTaskPollInterval)
|
||||
}
|
||||
if child.MCPTaskMaxPollInterval != 4*time.Second {
|
||||
t.Errorf("MCPTaskMaxPollInterval = %v, want 4s", child.MCPTaskMaxPollInterval)
|
||||
}
|
||||
if child.MCPTaskProgress == nil {
|
||||
t.Error("MCPTaskProgress not propagated")
|
||||
}
|
||||
|
||||
// Nil parent is a no-op rather than a panic.
|
||||
inheritMCPTaskOptions(&Options{}, nil)
|
||||
inheritMCPTaskOptions(nil, parent)
|
||||
}
|
||||
+52
-22
@@ -2,6 +2,7 @@ package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
@@ -52,6 +53,22 @@ func ErrorResult(content string) ToolOutput {
|
||||
return ToolOutput{Content: content, IsError: true}
|
||||
}
|
||||
|
||||
// ImageResult creates a [ToolOutput] that returns an image to the LLM.
|
||||
// The data is the raw image bytes and mediaType is the MIME type
|
||||
// (e.g. "image/png", "image/jpeg"). The optional text content accompanies
|
||||
// the image and is visible to the LLM alongside it.
|
||||
func ImageResult(content string, data []byte, mediaType string) ToolOutput {
|
||||
return ToolOutput{Content: content, Data: data, MediaType: mediaType}
|
||||
}
|
||||
|
||||
// MediaResult creates a [ToolOutput] that returns non-image binary media
|
||||
// (e.g. audio, video) to the LLM. The data is the raw bytes and mediaType
|
||||
// is the MIME type (e.g. "audio/wav", "video/mp4"). The optional text
|
||||
// content accompanies the media.
|
||||
func MediaResult(content string, data []byte, mediaType string) ToolOutput {
|
||||
return ToolOutput{Content: content, Data: data, MediaType: mediaType}
|
||||
}
|
||||
|
||||
// toolCallIDKey is the context key for the tool call ID.
|
||||
type toolCallIDKey struct{}
|
||||
|
||||
@@ -63,9 +80,35 @@ func ToolCallIDFromContext(ctx context.Context) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// toolOutputToResponse converts a [ToolOutput] into the underlying
|
||||
// framework's ToolResponse, inferring the response Type from Data/MediaType
|
||||
// so that binary content (images, audio, etc.) is forwarded to the LLM
|
||||
// instead of being silently dropped.
|
||||
func toolOutputToResponse(result ToolOutput) fantasy.ToolResponse {
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
// Infer response type from binary data so the downstream framework
|
||||
// creates a media content block instead of a plain-text one.
|
||||
if len(result.Data) > 0 && result.MediaType != "" {
|
||||
if strings.HasPrefix(result.MediaType, "image/") {
|
||||
resp.Type = "image"
|
||||
} else {
|
||||
resp.Type = "media"
|
||||
}
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
// NewTool creates a custom [Tool] with automatic JSON schema generation from
|
||||
// the TInput struct type. The handler receives a typed input (deserialized
|
||||
// from the LLM's JSON arguments) and returns a [ToolResult].
|
||||
// from the LLM's JSON arguments) and returns a [ToolOutput].
|
||||
//
|
||||
// Struct tags on TInput control the generated schema:
|
||||
//
|
||||
@@ -77,6 +120,11 @@ func ToolCallIDFromContext(ctx context.Context) string {
|
||||
// The tool call ID is injected into the context and can be retrieved with
|
||||
// [ToolCallIDFromContext].
|
||||
//
|
||||
// Binary results: When [ToolOutput.Data] and [ToolOutput.MediaType] are set,
|
||||
// the response type is automatically inferred so the LLM receives the binary
|
||||
// content (e.g. an image) instead of only the text. Use [ImageResult] or
|
||||
// [MediaResult] for convenience.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type WeatherInput struct {
|
||||
@@ -84,7 +132,7 @@ func ToolCallIDFromContext(ctx context.Context) string {
|
||||
// }
|
||||
//
|
||||
// tool := kit.NewTool("get_weather", "Get weather for a city",
|
||||
// func(ctx context.Context, input WeatherInput) (kit.ToolResult, error) {
|
||||
// func(ctx context.Context, input WeatherInput) (kit.ToolOutput, error) {
|
||||
// return kit.TextResult("72°F, sunny in " + input.City), nil
|
||||
// },
|
||||
// )
|
||||
@@ -96,16 +144,7 @@ func NewTool[TInput any](name, description string, fn func(ctx context.Context,
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
return toolOutputToResponse(result), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -121,16 +160,7 @@ func NewParallelTool[TInput any](name, description string, fn func(ctx context.C
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
resp := fantasy.ToolResponse{
|
||||
Content: result.Content,
|
||||
IsError: result.IsError,
|
||||
Data: result.Data,
|
||||
MediaType: result.MediaType,
|
||||
}
|
||||
if result.Metadata != nil {
|
||||
resp = fantasy.WithResponseMetadata(resp, result.Metadata)
|
||||
}
|
||||
return resp, nil
|
||||
return toolOutputToResponse(result), nil
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -117,3 +117,149 @@ func TestToolOutput_BinaryData(t *testing.T) {
|
||||
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
|
||||
}
|
||||
}
|
||||
|
||||
// TestImageResult verifies the ImageResult convenience constructor.
|
||||
func TestImageResult(t *testing.T) {
|
||||
data := []byte{0x89, 0x50, 0x4E, 0x47}
|
||||
r := kit.ImageResult("here is the image", data, "image/png")
|
||||
if r.Content != "here is the image" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "here is the image")
|
||||
}
|
||||
if len(r.Data) != 4 {
|
||||
t.Errorf("Data len = %d, want 4", len(r.Data))
|
||||
}
|
||||
if r.MediaType != "image/png" {
|
||||
t.Errorf("MediaType = %q, want %q", r.MediaType, "image/png")
|
||||
}
|
||||
if r.IsError {
|
||||
t.Error("ImageResult should not set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMediaResult verifies the MediaResult convenience constructor.
|
||||
func TestMediaResult(t *testing.T) {
|
||||
data := []byte{0xFF, 0xFB, 0x90, 0x00}
|
||||
r := kit.MediaResult("audio clip", data, "audio/mpeg")
|
||||
if r.Content != "audio clip" {
|
||||
t.Errorf("Content = %q, want %q", r.Content, "audio clip")
|
||||
}
|
||||
if len(r.Data) != 4 {
|
||||
t.Errorf("Data len = %d, want 4", len(r.Data))
|
||||
}
|
||||
if r.MediaType != "audio/mpeg" {
|
||||
t.Errorf("MediaType = %q, want %q", r.MediaType, "audio/mpeg")
|
||||
}
|
||||
if r.IsError {
|
||||
t.Error("MediaResult should not set IsError")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTool_BinaryImageResponse verifies that NewTool correctly infers the
|
||||
// response type for image data so binary content is forwarded to the LLM
|
||||
// (issue #17).
|
||||
func TestNewTool_BinaryImageResponse(t *testing.T) {
|
||||
type Input struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
imgData := []byte{0x89, 0x50, 0x4E, 0x47} // PNG magic bytes
|
||||
|
||||
tool := kit.NewTool("read_image", "Read an image file",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.ImageResult("Here is the image", imgData, "image/png"), nil
|
||||
},
|
||||
)
|
||||
|
||||
// Run the tool and inspect the raw ToolResponse via the AgentTool interface.
|
||||
resp, err := tool.Run(context.Background(), kit.LLMToolCall{
|
||||
ID: "call_1",
|
||||
Name: "read_image",
|
||||
Input: `{"path": "test.png"}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error: %v", err)
|
||||
}
|
||||
|
||||
// The Type field must be "image" so the downstream framework creates a
|
||||
// media content block instead of discarding the binary data.
|
||||
if resp.Type != "image" {
|
||||
t.Errorf("ToolResponse.Type = %q, want %q", resp.Type, "image")
|
||||
}
|
||||
if len(resp.Data) != 4 {
|
||||
t.Errorf("ToolResponse.Data len = %d, want 4", len(resp.Data))
|
||||
}
|
||||
if resp.MediaType != "image/png" {
|
||||
t.Errorf("ToolResponse.MediaType = %q, want %q", resp.MediaType, "image/png")
|
||||
}
|
||||
if resp.Content != "Here is the image" {
|
||||
t.Errorf("ToolResponse.Content = %q, want %q", resp.Content, "Here is the image")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTool_BinaryMediaResponse verifies type inference for non-image media.
|
||||
func TestNewTool_BinaryMediaResponse(t *testing.T) {
|
||||
type Input struct{}
|
||||
|
||||
tool := kit.NewTool("get_audio", "Get audio",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.MediaResult("audio clip", []byte{0xFF, 0xFB}, "audio/mpeg"), nil
|
||||
},
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), kit.LLMToolCall{
|
||||
ID: "call_2",
|
||||
Name: "get_audio",
|
||||
Input: `{}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error: %v", err)
|
||||
}
|
||||
if resp.Type != "media" {
|
||||
t.Errorf("ToolResponse.Type = %q, want %q", resp.Type, "media")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewTool_TextResponseTypeNotSet verifies that text-only responses do NOT
|
||||
// get an inferred type (preserving existing behavior).
|
||||
func TestNewTool_TextResponseTypeNotSet(t *testing.T) {
|
||||
type Input struct{}
|
||||
|
||||
tool := kit.NewTool("echo", "Echo",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("hello"), nil
|
||||
},
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), kit.LLMToolCall{
|
||||
ID: "call_3", Name: "echo", Input: `{}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error: %v", err)
|
||||
}
|
||||
// Text responses should not have Type set (the framework treats "" as text).
|
||||
if resp.Type != "" {
|
||||
t.Errorf("ToolResponse.Type = %q, want empty string for text responses", resp.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewParallelTool_BinaryImageResponse mirrors the NewTool binary test for
|
||||
// NewParallelTool.
|
||||
func TestNewParallelTool_BinaryImageResponse(t *testing.T) {
|
||||
type Input struct{}
|
||||
|
||||
tool := kit.NewParallelTool("snap", "Take a snapshot",
|
||||
func(ctx context.Context, input Input) (kit.ToolOutput, error) {
|
||||
return kit.ImageResult("snapshot", []byte{0xFF, 0xD8}, "image/jpeg"), nil
|
||||
},
|
||||
)
|
||||
|
||||
resp, err := tool.Run(context.Background(), kit.LLMToolCall{
|
||||
ID: "call_4", Name: "snap", Input: `{}`,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Run() error: %v", err)
|
||||
}
|
||||
if resp.Type != "image" {
|
||||
t.Errorf("ToolResponse.Type = %q, want %q", resp.Type, "image")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,6 +157,18 @@ type LLMTextPart = fantasy.TextPart
|
||||
// LLMReasoningPart is a reasoning/chain-of-thought content part.
|
||||
type LLMReasoningPart = fantasy.ReasoningPart
|
||||
|
||||
// LLMToolCall represents the raw tool invocation passed to a [Tool]'s Run
|
||||
// method. It carries the call ID, tool name, and the JSON-encoded input
|
||||
// arguments from the LLM. This is the execution-layer call object — distinct
|
||||
// from [ToolCall] (a message content part).
|
||||
type LLMToolCall = fantasy.ToolCall
|
||||
|
||||
// LLMToolResponse represents the raw response returned from a [Tool]'s Run
|
||||
// method. Most SDK consumers should use [ToolOutput] with [NewTool] /
|
||||
// [NewParallelTool] instead — this alias is provided for advanced use cases
|
||||
// that need to call Tool.Run() directly (e.g. testing).
|
||||
type LLMToolResponse = fantasy.ToolResponse
|
||||
|
||||
// LLMToolCallPart represents an LLM-initiated tool invocation within a message.
|
||||
type LLMToolCallPart = fantasy.ToolCallPart
|
||||
|
||||
@@ -172,6 +184,40 @@ type LLMToolResultOutputContentText = fantasy.ToolResultOutputContentText
|
||||
// LLMToolResultOutputContentError is an error-valued tool result output.
|
||||
type LLMToolResultOutputContentError = fantasy.ToolResultOutputContentError
|
||||
|
||||
// LLMToolResultOutputContentMedia is a media-valued tool result output
|
||||
// (images, audio, etc.) carrying base64-encoded data and a MIME type.
|
||||
type LLMToolResultOutputContentMedia = fantasy.ToolResultOutputContentMedia
|
||||
|
||||
// LLMToolResultContentType classifies the kind of a tool result output
|
||||
// ("text", "error", or "media").
|
||||
type LLMToolResultContentType = fantasy.ToolResultContentType
|
||||
|
||||
// Tool result content type constants.
|
||||
const (
|
||||
// LLMToolResultContentTypeText represents text output.
|
||||
LLMToolResultContentTypeText = fantasy.ToolResultContentTypeText
|
||||
// LLMToolResultContentTypeError represents error text output.
|
||||
LLMToolResultContentTypeError = fantasy.ToolResultContentTypeError
|
||||
// LLMToolResultContentTypeMedia represents media (binary) output.
|
||||
LLMToolResultContentTypeMedia = fantasy.ToolResultContentTypeMedia
|
||||
)
|
||||
|
||||
// LLMToolInfo describes a tool's name, description, and JSON-Schema parameters.
|
||||
type LLMToolInfo = fantasy.ToolInfo
|
||||
|
||||
// LLMProviderOptions carries provider-specific key/value option maps, keyed
|
||||
// by provider name (e.g. "anthropic"). Use this when configuring or
|
||||
// inspecting provider-specific tool behaviour.
|
||||
type LLMProviderOptions = fantasy.ProviderOptions
|
||||
|
||||
// LLMProviderMetadata carries provider-specific metadata returned alongside
|
||||
// LLM responses, keyed by provider name.
|
||||
type LLMProviderMetadata = fantasy.ProviderMetadata
|
||||
|
||||
// LLMPrompt is an ordered sequence of [LLMMessage] values forming a complete
|
||||
// prompt for the LLM.
|
||||
type LLMPrompt = fantasy.Prompt
|
||||
|
||||
// LLMMessageRole identifies the participant role in an LLM conversation.
|
||||
type LLMMessageRole = fantasy.MessageRole
|
||||
|
||||
@@ -198,6 +244,12 @@ var NewLLMUserMessage = fantasy.NewUserMessage
|
||||
// prompt strings.
|
||||
var NewLLMSystemMessage = fantasy.NewSystemMessage
|
||||
|
||||
// newLLMTextErrorResponse creates a tool-error response (internal helper).
|
||||
var newLLMTextErrorResponse = fantasy.NewTextErrorResponse
|
||||
|
||||
// newLLMTextResponse creates a plain-text tool response (internal helper).
|
||||
var newLLMTextResponse = fantasy.NewTextResponse
|
||||
|
||||
// ==== Compaction Types (internal/compaction/) ====
|
||||
|
||||
// CompactionResult contains statistics from a compaction operation.
|
||||
|
||||
+102
-3
@@ -281,7 +281,7 @@ host.OnToolOutput(func(e kit.ToolOutputEvent) {
|
||||
// Streaming bash output chunks
|
||||
})
|
||||
|
||||
host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk) // real-time text streaming
|
||||
})
|
||||
|
||||
@@ -296,8 +296,64 @@ host.OnTurnStart(func(e kit.TurnStartEvent) {
|
||||
host.OnTurnEnd(func(e kit.TurnEndEvent) {
|
||||
// e.Response, e.Error, e.StopReason
|
||||
})
|
||||
|
||||
host.OnStepStart(func(e kit.StepStartEvent) {
|
||||
// e.StepNumber — which LLM call step (1-based)
|
||||
})
|
||||
|
||||
host.OnStepFinish(func(e kit.StepFinishEvent) {
|
||||
// e.StepNumber, e.HasToolCalls, e.FinishReason, e.Usage (LLMUsage)
|
||||
})
|
||||
|
||||
host.OnWarnings(func(e kit.WarningsEvent) {
|
||||
for _, w := range e.Warnings {
|
||||
log.Printf("warning: %s", w)
|
||||
}
|
||||
})
|
||||
|
||||
host.OnError(func(e kit.ErrorEvent) {
|
||||
log.Printf("agent error: %v", e.Error)
|
||||
})
|
||||
|
||||
host.OnRetry(func(e kit.RetryEvent) {
|
||||
log.Printf("retrying (attempt %d): %v", e.Attempt, e.Error)
|
||||
})
|
||||
|
||||
host.OnTextStart(func(e kit.TextStartEvent) {
|
||||
// e.ID — content block ID
|
||||
})
|
||||
|
||||
host.OnTextEnd(func(e kit.TextEndEvent) {
|
||||
// e.ID — content block ID
|
||||
})
|
||||
|
||||
host.OnReasoningStart(func(e kit.ReasoningStartEvent) {
|
||||
// e.ID — reasoning block ID
|
||||
})
|
||||
|
||||
host.OnSource(func(e kit.SourceEvent) {
|
||||
// e.SourceType, e.ID, e.URL, e.Title
|
||||
})
|
||||
|
||||
host.OnStreamFinish(func(e kit.StreamFinishEvent) {
|
||||
// e.Usage (LLMUsage), e.FinishReason
|
||||
})
|
||||
|
||||
// Additional typed subscribers for previously generic-only events:
|
||||
host.OnMessageStart(func(e kit.MessageStartEvent) {})
|
||||
host.OnMessageEnd(func(e kit.MessageEndEvent) { /* e.Content */ })
|
||||
host.OnReasoningDelta(func(e kit.ReasoningDeltaEvent) { /* e.Delta */ })
|
||||
host.OnReasoningComplete(func(e kit.ReasoningCompleteEvent) {})
|
||||
host.OnToolExecutionStart(func(e kit.ToolExecutionStartEvent) { /* e.ToolCallID, e.ToolName, e.ToolKind, e.ToolArgs */ })
|
||||
host.OnToolExecutionEnd(func(e kit.ToolExecutionEndEvent) { /* e.ToolCallID, e.ToolName, e.ToolKind */ })
|
||||
host.OnToolCallContent(func(e kit.ToolCallContentEvent) { /* e.Content */ })
|
||||
host.OnStepUsage(func(e kit.StepUsageEvent) { /* e.InputTokens, e.OutputTokens, e.CacheReadTokens, e.CacheWriteTokens */ })
|
||||
host.OnCompaction(func(e kit.CompactionEvent) { /* e.Summary, e.OriginalTokens, e.CompactedTokens, ... */ })
|
||||
host.OnSteerConsumed(func(e kit.SteerConsumedEvent) { /* e.Count */ })
|
||||
```
|
||||
|
||||
> **Rename note:** `OnStreaming` has been renamed to `OnMessageUpdate`. The old `OnStreaming` name is kept as a deprecated alias for one release cycle.
|
||||
|
||||
### Generic subscriber (receives all events)
|
||||
|
||||
```go
|
||||
@@ -336,6 +392,16 @@ unsub := host.Subscribe(func(e kit.Event) {
|
||||
| `reasoning_delta` | `ReasoningDeltaEvent` | `Delta` |
|
||||
| `step_usage` | `StepUsageEvent` | `InputTokens`, `OutputTokens`, `CacheReadTokens`, `CacheWriteTokens` |
|
||||
| `steer_consumed` | `SteerConsumedEvent` | `Count` |
|
||||
| `step_start` | `StepStartEvent` | `StepNumber` |
|
||||
| `step_finish` | `StepFinishEvent` | `StepNumber`, `HasToolCalls`, `FinishReason`, `Usage` |
|
||||
| `text_start` | `TextStartEvent` | `ID` |
|
||||
| `text_end` | `TextEndEvent` | `ID` |
|
||||
| `reasoning_start` | `ReasoningStartEvent` | `ID` |
|
||||
| `warnings` | `WarningsEvent` | `Warnings` |
|
||||
| `source` | `SourceEvent` | `SourceType`, `ID`, `URL`, `Title` |
|
||||
| `stream_finish` | `StreamFinishEvent` | `Usage`, `FinishReason` |
|
||||
| `error` | `ErrorEvent` | `Error` |
|
||||
| `retry` | `RetryEvent` | `Attempt`, `Error` |
|
||||
| `password_prompt` | `PasswordPromptEvent` | `Prompt`, `ResponseCh` |
|
||||
|
||||
**Tool call streaming lifecycle**: `ToolCallStartEvent` → `ToolCallDeltaEvent` (repeated) → `ToolCallEndEvent` → `ToolCallEvent` → `ToolExecutionStartEvent` → `ToolOutputEvent` (optional, repeated) → `ToolExecutionEndEvent` → `ToolResultEvent`
|
||||
@@ -421,6 +487,20 @@ host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
})
|
||||
```
|
||||
|
||||
### PrepareStep — intercept/replace messages before each LLM call
|
||||
|
||||
```go
|
||||
host.OnPrepareStep(kit.HookPriorityNormal, func(h kit.PrepareStepHook) *kit.PrepareStepResult {
|
||||
// h.StepNumber — which step in the current turn (1-based)
|
||||
// h.Messages — []kit.LLMMessage being sent to the LLM
|
||||
// Return nil to pass through unchanged, or replace messages:
|
||||
modified := filterSensitiveMessages(h.Messages)
|
||||
return &kit.PrepareStepResult{Messages: modified}
|
||||
})
|
||||
```
|
||||
|
||||
`PrepareStep` fires before every LLM API call within a turn (including tool-call loop iterations). Unlike `ContextPrepare` (which operates on the full context window once per turn), `PrepareStep` runs per-step and sees the messages that include the latest tool results.
|
||||
|
||||
### ContextPrepare — filter/inject context window
|
||||
|
||||
```go
|
||||
@@ -493,6 +573,8 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
|----------|-------------|
|
||||
| `kit.TextResult(content)` | Successful text result |
|
||||
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
|
||||
| `kit.ImageResult(content, data, mediaType)` | Image result with binary data (e.g. `"image/png"`) |
|
||||
| `kit.MediaResult(content, data, mediaType)` | Non-image media result (e.g. `"audio/mpeg"`) |
|
||||
|
||||
**ToolOutput fields** (for advanced use):
|
||||
|
||||
@@ -1088,13 +1170,30 @@ kit.Config, kit.MCPServerConfig
|
||||
// Provider types
|
||||
kit.ProviderConfig, kit.ProviderResult, kit.ModelInfo, kit.ModelCost, kit.ModelLimit
|
||||
|
||||
// LLM types — concrete Kit-owned structs (no external library dependency)
|
||||
// LLM types — clean aliases (no external library dependency in consumer code)
|
||||
kit.LLMMessage // {Role LLMMessageRole, Content string}
|
||||
kit.LLMMessagePart // interface for message content parts
|
||||
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ReasoningTokens,
|
||||
// CacheCreationTokens, CacheReadTokens}
|
||||
kit.LLMResponse // {Content, FinishReason, Usage}
|
||||
kit.LLMFilePart // {Filename, Data []byte, MediaType}
|
||||
kit.LLMTextPart // plain-text content part
|
||||
kit.LLMReasoningPart // reasoning/chain-of-thought content part
|
||||
kit.LLMToolCall // {ID, Name, Input string} — execution-layer tool call (for Tool.Run)
|
||||
kit.LLMToolResponse // {Type, Content, Data, MediaType, IsError, ...} — raw tool response
|
||||
kit.LLMToolCallPart // LLM-initiated tool invocation within a message
|
||||
kit.LLMToolResultPart // tool result within a message
|
||||
kit.LLMToolResultOutputContent // interface for tool result output
|
||||
kit.LLMToolResultOutputContentText // text tool result
|
||||
kit.LLMToolResultOutputContentError // error tool result
|
||||
kit.LLMToolResultOutputContentMedia // media tool result {Data, MediaType, Text}
|
||||
kit.LLMToolResultContentType // "text" | "error" | "media"
|
||||
kit.LLMToolInfo // {Name, Description, Parameters, Required, Parallel}
|
||||
kit.LLMProviderOptions // provider-specific option maps (keyed by provider name)
|
||||
kit.LLMProviderMetadata // provider-specific response metadata
|
||||
kit.LLMPrompt // []LLMMessage — ordered prompt sequence
|
||||
kit.LLMFinishReason // "stop" | "length" | "tool-calls" | ...
|
||||
|
||||
// Compaction types
|
||||
kit.CompactionResult, kit.CompactionOptions
|
||||
@@ -1168,7 +1267,7 @@ for {
|
||||
### Pattern: Streaming output to terminal
|
||||
|
||||
```go
|
||||
host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
response, _ := host.Prompt(ctx, "Write a poem")
|
||||
|
||||
@@ -88,6 +88,11 @@ mcpServers:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
url: "https://builds.mcp.example.com"
|
||||
tasksMode: always # always run tools/call as async tasks (Phase 1 MVP)
|
||||
```
|
||||
|
||||
### MCP server fields
|
||||
@@ -101,9 +106,34 @@ mcpServers:
|
||||
| `allowedTools` | list | Whitelist of tool names to expose |
|
||||
| `excludedTools` | list | Blacklist of tool names to hide |
|
||||
| `noOAuth` | bool | Skip OAuth for this server (for public servers that don't require auth) |
|
||||
| `tasksMode` | string | When to augment `tools/call` with MCP task metadata: `auto` (default — only when the server advertises task support), `never`, or `always`. See [MCP tasks](#mcp-tasks-long-running-tools). |
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
### MCP tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize` so servers can respond to `tools/call` with a
|
||||
`CreateTaskResult` (a task ID + `working` status) instead of blocking until
|
||||
the operation finishes. Kit then polls `tasks/get` / `tasks/result` until the
|
||||
task reaches a terminal state, and best-effort `tasks/cancel`s on context
|
||||
cancellation.
|
||||
|
||||
This avoids HTTP/SSE proxy timeouts on long builds, deploys, and batch jobs,
|
||||
and lets the user/agent abort cleanly with Ctrl-C.
|
||||
|
||||
**Per-server `tasksMode`:**
|
||||
|
||||
| Value | Behaviour |
|
||||
|-------|-----------|
|
||||
| `auto` (default) | Augment `tools/call` with task metadata only when the server advertised `tasks/toolCalls` capability. Servers that don't advertise it run synchronously, exactly as before. |
|
||||
| `never` | Always issue `tools/call` synchronously, regardless of server capability. |
|
||||
| `always` | Always opt into task augmentation, even when the server didn't advertise the capability. The server may still respond synchronously — this just expresses client intent unconditionally. |
|
||||
|
||||
Defaults are safe: any existing MCP server keeps its previous behaviour
|
||||
bit-for-bit. SDK consumers can also override the mode programmatically and
|
||||
plug in a progress callback — see [SDK options](/sdk/options#mcp-tasks).
|
||||
|
||||
## Custom models
|
||||
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
+54
-14
@@ -20,7 +20,7 @@ unsub2 := host.OnToolResult(func(event kit.ToolResultEvent) {
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(event kit.MessageUpdateEvent) {
|
||||
unsub3 := host.OnMessageUpdate(func(event kit.MessageUpdateEvent) {
|
||||
fmt.Print(event.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
@@ -116,6 +116,24 @@ host.OnAfterTurn(kit.HookPriorityNormal, func(h kit.AfterTurnHook) {
|
||||
})
|
||||
```
|
||||
|
||||
### PrepareStep — intercept messages between steps
|
||||
|
||||
The most powerful hook — fires between steps within a multi-step agent turn, after any steering messages are injected and before messages are sent to the LLM. Can replace the entire context window.
|
||||
|
||||
```go
|
||||
host.OnPrepareStep(kit.HookPriorityNormal, func(h kit.PrepareStepHook) *kit.PrepareStepResult {
|
||||
// h.StepNumber — zero-based step index within the turn
|
||||
// h.Messages — current context window (includes any steering)
|
||||
|
||||
// Example: transform tool results with images into user messages
|
||||
modified := transformImageToolResults(h.Messages)
|
||||
return &kit.PrepareStepResult{Messages: modified}
|
||||
// Return nil to pass through unchanged
|
||||
})
|
||||
```
|
||||
|
||||
Use cases: transforming tool results (e.g., image data for vision models), dynamic tool filtering per step, mid-turn context injection, custom stop conditions.
|
||||
|
||||
### Hook priorities
|
||||
|
||||
```go
|
||||
@@ -128,19 +146,41 @@ Lower values run first. First non-nil result wins.
|
||||
|
||||
## All event types
|
||||
|
||||
| Event | Description |
|
||||
|-------|-------------|
|
||||
| `ToolCallStartEvent` | LLM began generating tool call arguments (tool name known, args streaming) |
|
||||
| `ToolCallDeltaEvent` | Streamed JSON fragment of tool call arguments |
|
||||
| `ToolCallEndEvent` | Tool argument streaming complete, before execution begins |
|
||||
| `ToolCallEvent` | Tool call fully parsed and about to execute |
|
||||
| `ToolResultEvent` | Tool execution completed with result |
|
||||
| `ToolOutputEvent` | Streaming output chunk from tool (e.g., bash stdout/stderr) |
|
||||
| `MessageUpdateEvent` | Streaming text chunk from LLM |
|
||||
| `ResponseEvent` | Final response received |
|
||||
| `TurnStartEvent` | Agent turn started |
|
||||
| `TurnEndEvent` | Agent turn completed |
|
||||
| `PasswordPromptEvent` | Sudo command needs password (respond via `ResponseCh`) |
|
||||
| Event | Typed Subscriber | Description |
|
||||
|-------|-----------------|-------------|
|
||||
| `TurnStartEvent` | `OnTurnStart` | Agent turn started |
|
||||
| `TurnEndEvent` | `OnTurnEnd` | Agent turn completed |
|
||||
| `MessageStartEvent` | `OnMessageStart` | New assistant message begins |
|
||||
| `MessageUpdateEvent` | `OnMessageUpdate` | Streaming text chunk from LLM |
|
||||
| `MessageEndEvent` | `OnMessageEnd` | Assistant message complete |
|
||||
| `ToolCallStartEvent` | `OnToolCallStart` | LLM began generating tool call arguments |
|
||||
| `ToolCallDeltaEvent` | `OnToolCallDelta` | Streamed JSON fragment of tool call arguments |
|
||||
| `ToolCallEndEvent` | `OnToolCallEnd` | Tool argument streaming complete |
|
||||
| `ToolCallEvent` | `OnToolCall` | Tool call fully parsed, about to execute |
|
||||
| `ToolExecutionStartEvent` | `OnToolExecutionStart` | Tool begins executing |
|
||||
| `ToolExecutionEndEvent` | `OnToolExecutionEnd` | Tool finishes executing |
|
||||
| `ToolResultEvent` | `OnToolResult` | Tool execution completed with result |
|
||||
| `ToolCallContentEvent` | `OnToolCallContent` | Text content alongside tool calls |
|
||||
| `ToolOutputEvent` | `OnToolOutput` | Streaming output chunk from tool (e.g., bash) |
|
||||
| `ResponseEvent` | `OnResponse` | Final response received |
|
||||
| `ReasoningStartEvent` | `OnReasoningStart` | LLM begins reasoning/thinking |
|
||||
| `ReasoningDeltaEvent` | `OnReasoningDelta` | Streaming reasoning/thinking chunk |
|
||||
| `ReasoningCompleteEvent` | `OnReasoningComplete` | Reasoning/thinking finished |
|
||||
| `StepStartEvent` | `OnStepStart` | New LLM call begins within a turn |
|
||||
| `StepFinishEvent` | `OnStepFinish` | Step completes (with usage, finish reason, tool call info) |
|
||||
| `StepUsageEvent` | `OnStepUsage` | Per-step token usage |
|
||||
| `StreamFinishEvent` | `OnStreamFinish` | Per-step stream completes (with usage + finish reason) |
|
||||
| `TextStartEvent` | `OnTextStart` | LLM begins text content generation |
|
||||
| `TextEndEvent` | `OnTextEnd` | LLM finishes text content generation |
|
||||
| `WarningsEvent` | `OnWarnings` | LLM provider returned warnings |
|
||||
| `SourceEvent` | `OnSource` | LLM referenced a source (e.g., web search) |
|
||||
| `ErrorEvent` | `OnError` | Agent-level error during streaming |
|
||||
| `RetryEvent` | `OnRetry` | LLM request retried after transient error |
|
||||
| `CompactionEvent` | `OnCompaction` | Conversation compacted |
|
||||
| `SteerConsumedEvent` | `OnSteerConsumed` | Steering messages injected into turn |
|
||||
| `PasswordPromptEvent` | — | Sudo command needs password (respond via `ResponseCh`) |
|
||||
|
||||
> **Note:** `OnStreaming` is a deprecated alias for `OnMessageUpdate` and will be removed in a future release.
|
||||
|
||||
## Subagent event monitoring
|
||||
|
||||
|
||||
@@ -169,6 +169,12 @@ when embedding Kit as a library.
|
||||
| `MCPAuthHandler` | `MCPAuthHandler` | — | OAuth handler for remote MCP servers. `nil` disables OAuth (servers returning 401 fail with the authorization-required error). See [MCP OAuth](#mcp-oauth-authorization) below. |
|
||||
| `MCPTokenStoreFactory` | `func` | — | Custom OAuth token storage for MCP servers (default: JSON file in `$XDG_CONFIG_HOME/.kit/mcp_tokens.json`). |
|
||||
| `InProcessMCPServers` | `map[string]*MCPServer` | — | In-process mcp-go servers (no subprocess) |
|
||||
| `MCPTaskMode` | `map[string]MCPTaskMode` | — | Per-server override for task-augmented `tools/call`. Keys are server names; missing entries fall back to the `tasksMode` field of the matching `MCPServerConfig`. See [MCP Tasks](#mcp-tasks). |
|
||||
| `MCPTaskTimeout` | `time.Duration` | `15m` | Maximum wall-clock to wait for a task to reach a terminal state. Independent of any per-call context deadline. |
|
||||
| `MCPTaskTTL` | `time.Duration` | — | TTL hint sent in `TaskParams` for every task-augmented call. Zero omits the field and lets the server pick. |
|
||||
| `MCPTaskPollInterval` | `time.Duration` | `1s` | Fallback interval between `tasks/get` requests when the server does not suggest one. |
|
||||
| `MCPTaskMaxPollInterval` | `time.Duration` | `5s` | Cap on the polling interval (a server-supplied `pollInterval` can otherwise grow without bound). |
|
||||
| `MCPTaskProgress` | `MCPTaskProgressHandler` | — | Optional callback invoked once when a task is accepted and on every observed status transition. The final invocation always carries a terminal status. |
|
||||
|
||||
## MCP OAuth Authorization
|
||||
|
||||
@@ -248,6 +254,79 @@ authorization URL and hang until the 2-minute callback timeout fires. Always
|
||||
set `OnAuthURL`, or use a higher-level wrapper like `CLIMCPAuthHandler`.
|
||||
:::
|
||||
|
||||
## MCP Tasks
|
||||
|
||||
The [MCP Tasks utility](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
turns a synchronous `tools/call` into a pollable async job: the server
|
||||
returns a `taskId` with status `working` immediately, and the client polls
|
||||
`tasks/get` / `tasks/result` until the task reaches a terminal state.
|
||||
|
||||
Kit advertises task support during `initialize` and, by default, augments
|
||||
`tools/call` with task metadata only when the server advertises
|
||||
`tasks/toolCalls` capability — so any existing MCP server keeps its previous
|
||||
synchronous behaviour bit-for-bit. Long-running tools (builds, deployments,
|
||||
batch jobs, sub-agent runs) get HTTP/SSE timeout-resistance and clean
|
||||
cancellation "for free" once both sides opt in.
|
||||
|
||||
### Per-server mode
|
||||
|
||||
```go
|
||||
import "time"
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways, // force task-augmented calls
|
||||
"chat-server": kit.MCPTaskModeNever, // force synchronous calls
|
||||
// any server not in the map honours its `tasksMode` config field
|
||||
// (default "auto")
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
| Mode | Behaviour |
|
||||
|---|---|
|
||||
| `MCPTaskModeAuto` (default) | Augment `tools/call` with `TaskParams` only when the server advertised `tasks/toolCalls`. |
|
||||
| `MCPTaskModeNever` | Always issue `tools/call` synchronously, ignoring server capability. |
|
||||
| `MCPTaskModeAlways` | Always opt in, even when the server didn't advertise the capability. The server may still respond synchronously. |
|
||||
|
||||
### Progress callbacks
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskTimeout: 15 * time.Minute, // total wall-clock cap
|
||||
MCPTaskTTL: 30 * time.Minute, // server retention hint
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s/%s: %s %s", p.Server, p.TaskID, p.Status, p.Message)
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
The handler fires once when a task is accepted and again on every observed
|
||||
status transition. The final call always carries a terminal status
|
||||
(`MCPTaskStatusCompleted`, `MCPTaskStatusFailed`, or `MCPTaskStatusCancelled`).
|
||||
Do not block in the handler — dispatch long work on a goroutine.
|
||||
|
||||
### Inspecting and cancelling tasks
|
||||
|
||||
```go
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
for _, t := range tasks {
|
||||
fmt.Printf("%s: %s (%s)\n", t.TaskID, t.Status, t.StatusMessage)
|
||||
}
|
||||
|
||||
t, _ := host.GetMCPTask(ctx, "build-server", taskID)
|
||||
if !t.Status.IsTerminal() {
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", taskID)
|
||||
}
|
||||
```
|
||||
|
||||
`Kit.ListMCPTasks`, `Kit.GetMCPTask`, and `Kit.CancelMCPTask` work against any
|
||||
loaded MCP server that advertises the corresponding capability.
|
||||
`MCPTaskStatus.IsTerminal()` is the canonical check for completion.
|
||||
|
||||
Context cancellation also works end-to-end: cancelling the `ctx` passed to a
|
||||
tool execution triggers a best-effort `tasks/cancel` before the call returns.
|
||||
|
||||
## Precedence
|
||||
|
||||
For any given generation or provider field, the effective value is resolved
|
||||
|
||||
@@ -101,8 +101,10 @@ Return values:
|
||||
|--------|-------------|
|
||||
| `kit.TextResult(s)` | Successful text result |
|
||||
| `kit.ErrorResult(s)` | Error result (LLM sees it as a tool error) |
|
||||
| `kit.ImageResult(s, data, mediaType)` | Image result with binary data (e.g. `"image/png"`) |
|
||||
| `kit.MediaResult(s, data, mediaType)` | Non-image media result (e.g. `"audio/mpeg"`) |
|
||||
|
||||
For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
|
||||
Binary data (images, audio, etc.) in `ToolOutput.Data` is automatically forwarded to the LLM when `MediaType` is set. For advanced use, return a `kit.ToolOutput` struct directly with `Data`, `MediaType`, and `Metadata` fields.
|
||||
|
||||
Use `kit.NewParallelTool` for tools that are safe to run concurrently. Use `kit.ToolCallIDFromContext(ctx)` to retrieve the LLM-assigned call ID for logging or tracing.
|
||||
|
||||
@@ -141,7 +143,7 @@ host.OnToolResult(func(event kit.ToolResultEvent) {
|
||||
fmt.Println("Tool result:", event.Name)
|
||||
})
|
||||
|
||||
host.OnStreaming(func(event kit.MessageUpdateEvent) {
|
||||
host.OnMessageUpdate(func(event kit.MessageUpdateEvent) {
|
||||
fmt.Print(event.Chunk)
|
||||
})
|
||||
```
|
||||
@@ -213,6 +215,33 @@ resources := host.ListMCPResources()
|
||||
content, _ := host.ReadMCPResource(ctx, "server", "file:///path")
|
||||
```
|
||||
|
||||
## MCP tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`, so cooperating servers can return a `taskId` immediately
|
||||
and let Kit poll `tasks/get` / `tasks/result` until the operation completes.
|
||||
This avoids HTTP/SSE proxy timeouts on long tools and gives you clean
|
||||
cancellation via context.
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s: %s", p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
// Inspect / cancel in-flight tasks
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
```
|
||||
|
||||
Defaults to `MCPTaskModeAuto` per server, so any existing MCP server keeps
|
||||
its previous synchronous behaviour. See [SDK options → MCP Tasks](/sdk/options#mcp-tasks)
|
||||
for the full surface.
|
||||
|
||||
## Context and compaction
|
||||
|
||||
Monitor and manage context usage:
|
||||
|
||||
@@ -1968,35 +1968,42 @@ a:hover { text-decoration: underline; }
|
||||
|
||||
function renderEditBody(input, result) {
|
||||
let html = '';
|
||||
const oldText = input.old_text || '';
|
||||
const newText = input.new_text || '';
|
||||
const edits = input.edits || [];
|
||||
|
||||
html += '<div class="tool-input">';
|
||||
html += '<div class="tool-input-label">Changes</div>';
|
||||
html += '<div class="diff-display">';
|
||||
|
||||
// Render old text lines as removed
|
||||
if (oldText) {
|
||||
const oldLines = oldText.split('\n');
|
||||
for (const line of oldLines) {
|
||||
html += `<div class="diff-line removed">- ${escapeHtml(line)}</div>`;
|
||||
for (const edit of edits) {
|
||||
const oldText = edit.old_text || '';
|
||||
const newText = edit.new_text || '';
|
||||
|
||||
html += '<div class="diff-display">';
|
||||
|
||||
// Render old text lines as removed
|
||||
if (oldText) {
|
||||
const oldLines = oldText.split('\n');
|
||||
for (const line of oldLines) {
|
||||
html += `<div class="diff-line removed">- ${escapeHtml(line)}</div>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Separator
|
||||
if (oldText && newText) {
|
||||
html += '<div class="diff-separator">───</div>';
|
||||
}
|
||||
|
||||
// Render new text lines as added
|
||||
if (newText) {
|
||||
const newLines = newText.split('\n');
|
||||
for (const line of newLines) {
|
||||
html += `<div class="diff-line added">+ ${escapeHtml(line)}</div>`;
|
||||
// Separator
|
||||
if (oldText && newText) {
|
||||
html += '<div class="diff-separator">───</div>';
|
||||
}
|
||||
|
||||
// Render new text lines as added
|
||||
if (newText) {
|
||||
const newLines = newText.split('\n');
|
||||
for (const line of newLines) {
|
||||
html += `<div class="diff-line added">+ ${escapeHtml(line)}</div>`;
|
||||
}
|
||||
}
|
||||
|
||||
html += '</div>';
|
||||
}
|
||||
|
||||
html += '</div></div>';
|
||||
html += '</div>';
|
||||
|
||||
if (result && result.is_error) {
|
||||
html += renderToolOutput(result.content, true);
|
||||
|
||||
Reference in New Issue
Block a user