mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
107 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 10abb29e4f | |||
| 7a04bdfeba | |||
| 7e4708f511 | |||
| 1e12102b92 | |||
| ab2a77c95e | |||
| 1e78153b50 | |||
| a613361969 | |||
| 67722b0c24 | |||
| 1a2f6da40f | |||
| 747f5be099 | |||
| d7c4565999 | |||
| bd24f3315c | |||
| 592f8dc84f | |||
| 66c4a1eb15 | |||
| 5104477631 | |||
| 394a4676a1 | |||
| 30f2bc243d | |||
| 922e246098 | |||
| 32b6376515 | |||
| cf194ff89a | |||
| 03006425fa | |||
| a322dfc59a | |||
| b1387d837e | |||
| f561f4cfd9 | |||
| 64caed57d4 | |||
| 975c30a773 | |||
| 35b9360d64 | |||
| 1b8373e133 | |||
| 1a5e4ce7c5 | |||
| 8823977612 | |||
| 24e2ea111c | |||
| 31ea80ec4f | |||
| 99f2680c2e | |||
| da7e05eb87 | |||
| a95714a22d | |||
| c4a2b0f1a3 | |||
| 2016570e2d | |||
| d557f4b870 | |||
| 65054fe3db | |||
| 97d2246375 | |||
| 1e12505741 | |||
| 6755597c9b | |||
| 45689cb30d | |||
| 78570d4188 | |||
| 7cf38b37ee | |||
| 4ef57eec4e | |||
| cbd828e190 | |||
| d304805106 | |||
| 6e36053856 | |||
| 92eaaf6a59 | |||
| e6084b7bd0 | |||
| 34d5abff9c | |||
| fc0ddd5f4f | |||
| 7aa6160c75 | |||
| e830bf87ca | |||
| 3881d1c28f | |||
| 53f6682bd0 | |||
| 996b15c9b9 | |||
| aeb704367c | |||
| d2e23295b6 | |||
| e5a13e2e12 | |||
| 558fb5214f | |||
| 61408ed490 | |||
| 3cfb6437f9 | |||
| d33ad4028b | |||
| 307dcd1734 | |||
| 81240b075e | |||
| 9a662d440c | |||
| 4ba9d6fab3 | |||
| aec0e7cc01 | |||
| bac04636bf | |||
| 5f851fd08e | |||
| f8371836d8 | |||
| 74f00244be | |||
| b5d7fd4f3e | |||
| 5857d40978 | |||
| 3ff701054a | |||
| c1dee3ceba | |||
| 2d9783a44d | |||
| 88dd216e15 | |||
| 9e5806ade8 | |||
| 50f586ec8f | |||
| 8a8e684dff | |||
| 7ef99ac60f | |||
| a67f514560 | |||
| b6bb35cb71 | |||
| 4e82fac442 | |||
| 5ec2217b0f | |||
| 8a851723ba | |||
| 53b628c5f8 | |||
| e1c94cb362 | |||
| ecf95b52e1 | |||
| 0641c92acc | |||
| 3bb20f5283 | |||
| 633fa38b2b | |||
| f905cee48c | |||
| 182c10ea1a | |||
| fcaa52bf1c | |||
| 7e6455732c | |||
| 71301a9035 | |||
| 0974d37ab2 | |||
| 398e825df8 | |||
| 3c51c20be7 | |||
| 25410af440 | |||
| 26c9f009f9 | |||
| e068487ff7 | |||
| 0ffb0ba788 |
@@ -1,268 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
diagnosticsTimeout = 20 * time.Second
|
||||
maxOutputBytes = 12_000
|
||||
)
|
||||
|
||||
type toolPathInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type lintResult struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
absPath, ok := resolveGoFilePath(e.Input, ctx.CWD)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
borderColor = "#f38ba8" // red
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: strings.Join(msgLines, "\n"),
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isEditOrWrite(toolName string) bool {
|
||||
return strings.EqualFold(toolName, "edit") || strings.EqualFold(toolName, "write")
|
||||
}
|
||||
|
||||
func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
var args toolPathInput
|
||||
if err := json.Unmarshal([]byte(inputJSON), &args); err != nil || args.Path == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
absPath := args.Path
|
||||
if !filepath.IsAbs(absPath) {
|
||||
absPath = filepath.Join(cwd, absPath)
|
||||
}
|
||||
|
||||
if strings.ToLower(filepath.Ext(absPath)) != ".go" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "gopls", "check", absPath)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes), Err: fmt.Errorf("failed to run gopls check: %w", err)}
|
||||
}
|
||||
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes)}
|
||||
}
|
||||
|
||||
func runGolangCILint(cwd, target string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []string{
|
||||
"run",
|
||||
target,
|
||||
"--show-stats=false",
|
||||
"--output.text.path", "stdout",
|
||||
"--output.text.colors=false",
|
||||
"--output.text.print-issued-lines=false",
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, "golangci-lint", args...)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
trimmed := truncate(string(out), maxOutputBytes)
|
||||
if err == nil {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
exitErr, ok := err.(*exec.ExitError)
|
||||
if ok && exitErr.ExitCode() == 1 {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
return lintResult{Output: trimmed, Err: fmt.Errorf("failed to run golangci-lint: %w", err)}
|
||||
}
|
||||
|
||||
func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
var lines []string
|
||||
if res.Err != nil {
|
||||
lines = append(lines, "ERROR: "+res.Err.Error())
|
||||
}
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return emptyFallback
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "\n... output truncated ..."
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
lintCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return goplsCount, lintCount
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
---
|
||||
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
|
||||
---
|
||||
|
||||
Perform a comprehensive **read-only** audit of this repository and report
|
||||
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
|
||||
hints from the user: $@
|
||||
|
||||
## Scope
|
||||
|
||||
If the user supplied focus hints above (a package path, a subsystem name, a
|
||||
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
|
||||
audit the whole repo, prioritising the highest-traffic packages first
|
||||
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Map the repo first**:
|
||||
- `ls` / `find` the top-level layout and list every Go package
|
||||
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
|
||||
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
|
||||
extension surface)
|
||||
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
|
||||
(e.g. "no dependency name leakage", "UI never imports extensions
|
||||
directly") — these define what counts as a violation
|
||||
|
||||
2. **Hunt for dead code**:
|
||||
- Run `go vet ./...` and capture warnings
|
||||
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
|
||||
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
|
||||
with zero non-test references inside the module are suspects
|
||||
- Check for unreferenced files, `// TODO: remove` markers, commented-out
|
||||
blocks, and `_ = x` discard patterns
|
||||
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
|
||||
them and include their output verbatim
|
||||
- **Do not delete anything** — list candidates with file:line and a
|
||||
confidence level (high / medium / low)
|
||||
|
||||
3. **Find unnecessary duplication**:
|
||||
- Look for near-identical function bodies, struct shapes, or switch
|
||||
statements across packages — `grep` for repeated function signatures
|
||||
and copy-pasted string literals / error messages is a fast first pass
|
||||
- Distinguish *coincidental* duplication (two things that happen to look
|
||||
alike but evolve independently) from *unnecessary* duplication (same
|
||||
intent, drifting in lockstep) — only flag the latter
|
||||
- For each cluster, propose where the extracted helper should live
|
||||
(which package, which file) and whether it crosses a boundary
|
||||
|
||||
4. **Check concerns / boundary violations**:
|
||||
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
|
||||
in exported signatures, and for dependency-name leakage in exported
|
||||
names / godoc (e.g. library jargon appearing in `LLM*` types)
|
||||
- **UI ↔ extensions**: grep `internal/ui/` for any import of
|
||||
`internal/extensions/` — per AGENTS.md the UI must not import
|
||||
extensions directly; converters in `cmd/root.go` should bridge them
|
||||
- **cmd vs internal**: business logic living in `cmd/` that should be
|
||||
in `internal/` (and vice versa)
|
||||
- **Cyclic risk**: packages that import each other transitively or that
|
||||
reach across sibling boundaries unexpectedly
|
||||
- For each violation, cite the offending import / signature with
|
||||
file:line
|
||||
|
||||
5. **Spot refactor opportunities**:
|
||||
- Long functions (>80 lines) doing multiple unrelated things
|
||||
- Deeply nested conditionals that flatten well with early returns
|
||||
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
|
||||
that could become helpers — but only where the wrapping context is
|
||||
genuinely uniform
|
||||
- Structs with too many fields that hint at split responsibilities
|
||||
- Exported APIs that would be cleaner with options structs / functional
|
||||
options
|
||||
- Tests that share setup boilerplate ripe for a helper
|
||||
- Flag each with: location, current shape (1-2 lines), proposed shape
|
||||
(1-2 lines), and estimated risk (low / medium / high)
|
||||
|
||||
6. **Cross-check against project rules**:
|
||||
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
|
||||
findings contradicts the documented gotchas (Yaegi interface ban,
|
||||
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
|
||||
"refactor" would reintroduce a known pitfall, drop it from the report
|
||||
and note why
|
||||
|
||||
7. **Write the report** as your final message (do not write it to disk)
|
||||
structured as:
|
||||
|
||||
```
|
||||
# Code Audit Report
|
||||
|
||||
## Summary
|
||||
- N dead-code candidates
|
||||
- N duplication clusters
|
||||
- N boundary violations
|
||||
- N refactor opportunities
|
||||
|
||||
## Dead Code
|
||||
### High confidence
|
||||
- path/to/file.go:LINE — symbol — reason
|
||||
|
||||
### Medium confidence
|
||||
...
|
||||
|
||||
## Duplication
|
||||
### Cluster: <short name>
|
||||
- Sites: file:line, file:line, …
|
||||
- Suggested home: package/path
|
||||
- Notes: …
|
||||
|
||||
## Boundary Violations
|
||||
- Rule: <which rule from AGENTS.md / project convention>
|
||||
- Offender: file:line
|
||||
- Fix sketch: …
|
||||
|
||||
## Refactor Opportunities
|
||||
- Location: file:line
|
||||
- Current: …
|
||||
- Proposed: …
|
||||
- Risk: low/medium/high
|
||||
- Why it's worth it: …
|
||||
|
||||
## Suggested Next Steps
|
||||
1. …
|
||||
2. …
|
||||
```
|
||||
|
||||
8. **End the report with an explicit reminder** that no files were modified,
|
||||
and recommend the user pick the highest-leverage items to act on
|
||||
manually (or via a follow-up `/fix-issue` style prompt) rather than
|
||||
running a sweeping refactor.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
|
||||
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
|
||||
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
|
||||
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
|
||||
straight to it
|
||||
- **Be honest about confidence**: false positives in a code audit are
|
||||
expensive — prefer "medium confidence, worth a look" over confidently
|
||||
wrong claims
|
||||
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
|
||||
anything that's purely stylistic unless it directly causes one of the
|
||||
four issue categories above
|
||||
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
|
||||
`vendor/`) and obvious third-party copies
|
||||
- **Don't propose architectural rewrites** — stay within the existing
|
||||
shape of the repo and recommend incremental, reviewable changes
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
description: Open a GitHub PR for the current branch using the repo's PR template
|
||||
---
|
||||
|
||||
Open a GitHub pull request for the current branch, filling out the repository's PR template with a description grounded in the actual commits and diff.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Verify the branch is pushed**:
|
||||
- `git status -sb` and `git log @{u}..HEAD --oneline 2>/dev/null` — if there is no upstream or unpushed commits, run `git push -u origin "$(git branch --show-current)"` first
|
||||
- If the working tree is dirty, stop and tell the user to commit first (suggest `/commit-push`)
|
||||
2. **Gather context**:
|
||||
- `git log origin/main..HEAD --oneline` — list of commits going into the PR
|
||||
- `git diff origin/main...HEAD --stat` then `git diff origin/main...HEAD` — read the actual changes
|
||||
- Identify the linked issue (from commit messages, branch name, or extra user input: $@) — capture as `Fixes #N` if applicable
|
||||
3. **Locate the PR template**:
|
||||
- Check `.github/pull_request_template.md`, `.github/PULL_REQUEST_TEMPLATE.md`, or `docs/pull_request_template.md`
|
||||
- If none exists, use a minimal `## Description` / `## Type of Change` / `## Checklist` structure
|
||||
4. **Draft the PR body** by filling out the template:
|
||||
- **Description**: 1–3 short paragraphs explaining *what* changed and *why*, grounded in the diff. Include a brief before/after example for new APIs when useful.
|
||||
- **Fixes #N**: only if there is a real linked issue
|
||||
- **Type of Change**: tick the single most accurate box with `[x]` (leave others as `[ ]`)
|
||||
- **Checklist**: tick items that are genuinely true (style, self-review, tests added, docs updated)
|
||||
- **Additional Information**: bullet list of added / modified files and any backward-compatibility notes
|
||||
- Remove template sections explicitly marked "remove if not applicable" (e.g. MCP Spec Compliance) when they don't apply
|
||||
5. **Write the body to a temp file**: `/tmp/pr-body-<branch-or-issue>.md` — never inline a long body via `--body`, always use `--body-file`
|
||||
6. **Choose the title**: prefer the subject of the primary commit if it already follows Conventional Commits; otherwise craft one in the same style (`<type>(<scope>): <imperative summary>`, ≤72 chars)
|
||||
7. **Create the PR**:
|
||||
```
|
||||
gh pr create \
|
||||
--title "<title>" \
|
||||
--body-file /tmp/pr-body-<...>.md \
|
||||
--base main \
|
||||
--head "$(git branch --show-current)"
|
||||
```
|
||||
Use the repo's actual default branch if it isn't `main` (`gh repo view --json defaultBranchRef -q .defaultBranchRef.name`)
|
||||
8. **Report the PR URL** returned by `gh` and stop
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff and commit messages — do **not** invent features that aren't in the code
|
||||
- One PR per logical change; if the branch contains unrelated commits, surface that and ask before continuing
|
||||
- Keep the description focused on reviewer-relevant information (what / why), not a replay of the diff
|
||||
- Only check checklist boxes that are actually satisfied; leave the rest unchecked rather than lying
|
||||
- If `gh` is not authenticated (`gh auth status` fails), stop and tell the user
|
||||
|
||||
$@
|
||||
@@ -2,7 +2,7 @@
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
Create a feature request for the Kit repository. The user wants to request: $@
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
@@ -16,7 +16,7 @@ This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$+`
|
||||
1. **Understand the request** from the user input: $@
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
@@ -16,7 +16,7 @@ This repository has structured issue templates. You MUST use the appropriate tem
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$+`:
|
||||
1. **Determine the issue type** from the user input: $@
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
---
|
||||
description: Implement the fix/feature/docs change requested by a GitHub issue
|
||||
---
|
||||
|
||||
Resolve GitHub issue #$1 by reading it, classifying it, and producing the appropriate code or doc change. **Stop once the working tree contains the change** — committing, pushing, and opening a PR are handled by `/commit-push` and `/create-pr`.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch the issue**:
|
||||
- Run: gh issue view $1 --json number,title,body,labels,state,author,comments
|
||||
- If the issue is closed, stop and ask the user whether to proceed
|
||||
- Read the **entire** thread including comments — the latest comment often refines the ask
|
||||
|
||||
2. **Classify the issue** from labels, title prefix, and body content:
|
||||
- `bug` / `fix:` → reproduce, then fix
|
||||
- `enhancement` / `feature` / `feat:` → design, then implement
|
||||
- `documentation` / `docs:` → locate and update docs
|
||||
- `question` / `discussion` → answer in a comment, do **not** write code
|
||||
- Anything else → ask the user how to proceed
|
||||
|
||||
3. **Create a working branch** off the default branch:
|
||||
- `git checkout main && git pull --ff-only`
|
||||
- Branch name: <type>/$1-<slug> (e.g. `fix/42-borderColor-ignored`, `feat/57-keyboard-clear`, `docs/63-widget-lifecycle`)
|
||||
|
||||
4. **Do the work** based on type:
|
||||
|
||||
### Bug (`bug` label / `fix:` title)
|
||||
- Reproduce the failure first (write a failing test if feasible) — if you cannot reproduce, comment on the issue asking for clarification and stop
|
||||
- Locate the root cause; do not patch symptoms
|
||||
- Add or extend a regression test that fails before and passes after the fix
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Feature (`enhancement` / `feature` label / `feat:` title)
|
||||
- Re-read the motivation and proposed implementation in the issue body
|
||||
- For large, ambiguous, or breaking changes, sketch the design in a comment on the issue and wait for sign-off before writing code
|
||||
- Implement behind sensible defaults; add godoc on every exported symbol
|
||||
- Add unit tests covering the new behaviour and edge cases
|
||||
- Update `README.md` / `docs/` if the public surface changed
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Documentation (`documentation` label / `docs:` title)
|
||||
- Open the file/URL referenced in the issue's "Documentation Location"
|
||||
- Apply the suggested improvement; verify code samples compile (`go build ./...`)
|
||||
- No tests required, but run `golangci-lint run` if Go files were touched
|
||||
|
||||
5. **Report**:
|
||||
- Branch name (`git branch --show-current`)
|
||||
- Summary of files changed (`git status -s`) and the diff highlights
|
||||
- Test/lint results (pass/fail with key output)
|
||||
- Suggest the next step explicitly:
|
||||
- `/commit-push` to commit with a Conventional Commit subject (the message should reference `(#$1)` and include `Fixes #$1` so merge auto-closes)
|
||||
- then `/create-pr $1` to open the pull request
|
||||
|
||||
## Guidelines
|
||||
|
||||
- This prompt **stops at a clean working tree with the change applied** — do not run `git commit`, `git push`, or `gh pr create`
|
||||
- If the issue is unclear, post a clarifying comment on the issue and stop; do not guess
|
||||
- Keep the change scoped to the issue; surface unrelated cleanups separately
|
||||
- For breaking changes or architecture shifts, propose the design on the issue first and wait for maintainer sign-off
|
||||
- If the issue is a duplicate or already fixed on `main`, comment with the reference and stop
|
||||
- Do not close the issue manually — the eventual PR's `Fixes #$1` handles that on merge
|
||||
+48
-13
@@ -2,7 +2,7 @@
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
@@ -16,30 +16,64 @@ It becomes a `/slug` slash command in the kit input box — typed as `/filename`
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
Body text of the prompt. Reference user-supplied arguments
|
||||
with positional placeholders (see "Argument placeholders" below).
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
- **Required args**: kit infers required positional args from the highest `$N` it finds *outside* backtick/tilde code fences — a stray `$2` in active prose means kit will refuse to run without 2 arguments
|
||||
|
||||
## Argument placeholders
|
||||
|
||||
kit performs shell-style substitution before sending the prompt to the model:
|
||||
|
||||
- `$1`, `$2`, … — positional arguments (1-indexed)
|
||||
- `${1}`, `${2}`, … — same, brace form (use when followed by digits/letters: `${1}_suffix`)
|
||||
- `$@` — all arguments joined by spaces (zero or more, optional)
|
||||
- `$+` — all arguments, **at least one required**
|
||||
- `$ARGUMENTS` / `${ARGUMENTS}` — alias for `$@`
|
||||
- `${@:N}` — args from the Nth onwards (1-indexed, bash-style)
|
||||
- `${@:N:L}` — `L` args starting from the Nth
|
||||
|
||||
### ⚠️ Critical: code fences and inline code preserve placeholders verbatim
|
||||
|
||||
Anything inside triple-backtick fences, `~~~` fences, or single-backtick `inline` code spans is **left untouched** so example code samples don't get corrupted. That means:
|
||||
|
||||
- An inline-coded `gh issue view $1` stays literal `$1` in the model's input ❌
|
||||
- The same command without backticks: gh issue view $1 → expands to `gh issue view 42` ✓
|
||||
|
||||
**Rule of thumb:** if you want a placeholder to substitute, keep it outside backticks and fences. If you want a literal `$1` in the output (e.g. teaching the user shell syntax), put it inside backticks.
|
||||
|
||||
### Workarounds for "I want it to look like code AND substitute"
|
||||
|
||||
1. **Drop the backticks** around just the placeholder portion — the rest can still read as a command line in prose
|
||||
2. **Use a 4-space-indented code block** instead of a triple-backtick fence — kit only skips backtick/tilde fences, so indentation-style code blocks still get substitution:
|
||||
|
||||
git push -u origin "$(git branch --show-current)"
|
||||
gh pr create --title "fix: ... (#$1)" --base main
|
||||
|
||||
3. **Bind once, reference loosely**: put `Issue: $1` at the top in prose, then leave the backticked examples literal — the model will substitute mentally
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
1. **Understand the workflow** the user described in $@ — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
4. **Decide on arguments**:
|
||||
- No args needed → omit placeholders entirely
|
||||
- One required value (issue number, PR url, file path) → use `$1`
|
||||
- Free-form trailing context → end with a single `$@` line
|
||||
- Multiple distinct values → use `$1`, `$2`, … and document each at the top
|
||||
5. **Draft the body**:
|
||||
- Open with a single sentence stating the goal, weaving in `$1`/`$@` where the value belongs
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
- **Audit every backtick and code fence**: any `$N` or `$@` inside them will not expand — was that intentional? If not, apply one of the workarounds above
|
||||
6. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
7. **Verify substitution** by mentally (or actually) replacing `$1`/`$@` with a sample value and confirming every reference resolves — and that the prompt's *own* example snippets don't accidentally bump the required-arg count (wrap illustrative `$N` examples in triple-backtick fences, not 4-space indentation, so `RequiredArgs()` ignores them)
|
||||
8. **Confirm** by showing the final file content and the slash command that activates it (e.g. `/code-review 42`)
|
||||
|
||||
## Guidelines
|
||||
|
||||
@@ -47,3 +81,4 @@ $1 $2 etc. for positional arguments.
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
- When in doubt about substitution behaviour, write the file and run `/<slug> testvalue` once to confirm — wrong placement of backticks is the #1 failure mode
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
---
|
||||
description: Audit and update project documentation (README and docs site) for a recent change
|
||||
---
|
||||
|
||||
Review recent code changes, identify all documentation surfaces that should
|
||||
mention them, and update each one — grounded in the actual diff, not guesses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Identify the change**:
|
||||
- If the user input ($@) names a commit / PR / branch / topic, use that as the focus
|
||||
- Otherwise inspect `git log origin/main..HEAD --oneline` and `git diff origin/main...HEAD --stat` to discover what shipped on the current branch
|
||||
- Read the actual diff (`git diff origin/main...HEAD`) — never document features that aren't in the code
|
||||
|
||||
2. **Inventory the doc surfaces**:
|
||||
- `README.md` at the repo root
|
||||
- Any docs site (commonly `www/`, `docs/`, `site/`) — list its pages and identify the one(s) most thematically related to the change
|
||||
- Inline godoc / API reference comments on the new exported symbols
|
||||
- `CHANGELOG.md` if the project keeps one
|
||||
- Any `examples/` directory entries that demonstrate the affected area
|
||||
|
||||
3. **Audit each surface** with `grep`:
|
||||
- Search for the names of related existing APIs (e.g. if you added `IterTools`, grep for `ListTools`) to find every page that already discusses the area
|
||||
- Decide for each hit: does it need a cross-reference, a side-by-side comparison, or to stay untouched?
|
||||
|
||||
4. **Decide where new content lives**:
|
||||
- Prefer extending an existing page over creating a new one
|
||||
- For a docs site, place new sections near related content (check the page's `## Heading` outline first)
|
||||
- Skip surfaces that genuinely don't apply (e.g. a server-focused README for a client-only change) and say so explicitly
|
||||
|
||||
5. **Draft the updates**:
|
||||
- Lead with a one-sentence statement of what's new and why
|
||||
- Show concrete code examples copied from real signatures — verify against the source files
|
||||
- Include a comparison / "when to use which" table when adding an alternative to an existing API
|
||||
- Note backwards-compatibility behaviour if relevant
|
||||
|
||||
6. **Verify the docs build** before committing:
|
||||
- For vocs / docusaurus / mkdocs sites, run the local build command (e.g. `npx vocs build`, `mkdocs build`) and fix any MDX/markdown errors
|
||||
- For godoc, run `go vet ./...` and `go doc <pkg> <Symbol>` to sanity-check rendering
|
||||
|
||||
7. **Report**:
|
||||
- List every file changed and every file deliberately left alone (with a one-line reason)
|
||||
- Suggest the next step (typically `/commit-push`) — do not auto-commit unless asked
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff before writing anything — invented API names erode trust faster than missing docs
|
||||
- One change per doc commit; keep doc updates separate from code changes when possible
|
||||
- Match the existing voice and formatting of each surface (headings, code-fence languages, table styles)
|
||||
- Prefer linking between pages over duplicating content
|
||||
|
||||
$@
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"permission": {
|
||||
"external_directory": {
|
||||
"~/go/**": "deny"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
# Autoscroll Fix - Final Summary
|
||||
|
||||
## Root Cause
|
||||
|
||||
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
|
||||
|
||||
### The Problem
|
||||
|
||||
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
|
||||
2. The `Height()` method returns `0` when `cachedRender == ""`
|
||||
3. `GotoBottom()` was calling:
|
||||
```go
|
||||
itemHeight := item.Height() // Returns 0 for reasoning
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width) // Renders but doesn't cache (reasoning)
|
||||
itemHeight = item.Height() // Still returns 0!
|
||||
}
|
||||
```
|
||||
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
|
||||
|
||||
## The Solution
|
||||
|
||||
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
|
||||
|
||||
```go
|
||||
// OLD: item.Height() which checks cached render
|
||||
itemHeight := item.Height()
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width)
|
||||
itemHeight = item.Height() // Still might be 0!
|
||||
}
|
||||
|
||||
// NEW: Calculate from rendered string directly
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
```
|
||||
|
||||
This works for **all** items regardless of whether they cache their render or not.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### `internal/ui/scrolllist.go`
|
||||
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
|
||||
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
|
||||
|
||||
### `internal/ui/model.go`
|
||||
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
|
||||
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
|
||||
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
|
||||
|
||||
## Testing Results
|
||||
|
||||
✅ **Test prompt**: "explore this repo"
|
||||
|
||||
**Before fix**:
|
||||
- Autoscroll stopped after reasoning block completed
|
||||
- Viewport stuck showing end of reasoning ("Thought for 203ms")
|
||||
- Assistant response streamed off-screen below
|
||||
|
||||
**After fix**:
|
||||
- Autoscroll works throughout reasoning block
|
||||
- Autoscroll continues during reasoning → assistant transition
|
||||
- Viewport stays at bottom showing latest assistant content
|
||||
- Final position shows end of response (build commands section)
|
||||
|
||||
## Behavior Verified
|
||||
|
||||
1. ✅ Streaming text auto-scrolls to bottom
|
||||
2. ✅ Works across reasoning → assistant transition
|
||||
3. ✅ Manual scroll up (PgUp) disables autoscroll
|
||||
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
|
||||
5. ✅ Accurate positioning with no offset errors
|
||||
|
||||
## Performance Note
|
||||
|
||||
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
|
||||
- `Render()` is already optimized with caching for non-reasoning items
|
||||
- `GotoBottom()` is only called during content updates (not every frame)
|
||||
- Reasoning blocks need to render anyway for live duration updates
|
||||
- This matches iteratr's approach of ensuring items are rendered before height calculations
|
||||
@@ -18,7 +18,8 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash (with interactive sudo password prompt), read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **Smart @ Attachments**: Binary files auto-detected via MIME type, MCP resources via `@mcp:server:uri`
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -28,7 +29,7 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
- **Session Management**: Tree-based conversation history with branching support
|
||||
- **Non-Interactive Mode**: Script-friendly positional args with JSON output
|
||||
- **ACP Server**: Run Kit as an [Agent Client Protocol](https://agentclientprotocol.com) agent over stdio
|
||||
- **Go SDK**: Embed Kit in your own applications
|
||||
- **Go SDK**: Embed Kit in your own applications with full agent lifecycle events (30+ event types) and behavior-modifying hooks
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -125,8 +126,14 @@ model: anthropic/claude-sonnet-latest
|
||||
max-tokens: 4096
|
||||
temperature: 0.7
|
||||
stream: true
|
||||
thinking-level: off # off, none, minimal, low, medium, high
|
||||
no-core-tools: false # set to true to disable all built-in core tools
|
||||
```
|
||||
|
||||
All of the above keys can also be set programmatically via the SDK
|
||||
(`kit.Options.MaxTokens`, `Options.Temperature`, `Options.ThinkingLevel`, etc.)
|
||||
without touching config files — see [SDK options](#with-options).
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
@@ -151,6 +158,16 @@ mcpServers:
|
||||
search:
|
||||
type: remote
|
||||
url: "https://mcp.example.com/search"
|
||||
|
||||
pubmed:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers that don't require auth
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
url: "https://builds.mcp.example.com"
|
||||
tasksMode: always # async task execution — see MCP Tasks below
|
||||
```
|
||||
|
||||
## CLI Reference
|
||||
@@ -179,19 +196,22 @@ mcpServers:
|
||||
--compact Enable compact output mode
|
||||
--auto-compact Auto-compact conversation near context limit
|
||||
|
||||
# Extensions
|
||||
# Extensions and tools
|
||||
--extension, -e Load additional extension file(s) (repeatable)
|
||||
--no-extensions Disable all extensions
|
||||
--no-core-tools Disable all built-in core tools (bash, read, write, edit, grep, find, ls, subagent)
|
||||
--prompt-template Load a specific prompt template by name
|
||||
--no-prompt-templates Disable prompt template loading
|
||||
|
||||
# Generation parameters
|
||||
--max-tokens Maximum tokens in response (default: 4096)
|
||||
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
|
||||
--temperature Randomness 0.0-1.0 (default: 0.7)
|
||||
--top-p Nucleus sampling 0.0-1.0 (default: 0.95)
|
||||
--top-k Limit top K tokens (default: 40)
|
||||
--stop-sequences Custom stop sequences (comma-separated)
|
||||
--thinking-level Extended thinking level: off, minimal, low, medium, high (default: off)
|
||||
--frequency-penalty Penalize frequent tokens 0.0-2.0 (default: 0.0)
|
||||
--presence-penalty Penalize present tokens 0.0-2.0 (default: 0.0)
|
||||
--thinking-level Extended thinking level: off, none, minimal, low, medium, high (default: off)
|
||||
|
||||
# System
|
||||
--config Config file path (default: ~/.kit.yml)
|
||||
@@ -203,9 +223,10 @@ mcpServers:
|
||||
|
||||
```bash
|
||||
# Authentication (for OAuth-enabled providers)
|
||||
kit auth login [provider] # Start OAuth flow (e.g., anthropic)
|
||||
kit auth logout [provider] # Remove credentials for provider
|
||||
kit auth status # Check authentication status
|
||||
kit auth login [provider] # Start OAuth flow (e.g., anthropic)
|
||||
kit auth login [provider] --set-default # Set provider's default model as system default
|
||||
kit auth logout [provider] # Remove credentials for provider
|
||||
kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
@@ -287,7 +308,7 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
@@ -321,6 +342,7 @@ See the `examples/extensions/` directory:
|
||||
- [`auto-commit.go`](examples/extensions/auto-commit.go) - Auto-commit on shutdown
|
||||
- [`bookmark.go`](examples/extensions/bookmark.go) - Bookmark conversations
|
||||
- [`branded-output.go`](examples/extensions/branded-output.go) - Branded output rendering
|
||||
- [`bridge-demo.go`](examples/extensions/bridge_demo.go) - Bridged SDK API demo (tree navigation, skills, templates, model resolution)
|
||||
- [`compact-notify.go`](examples/extensions/compact-notify.go) - Notification on compaction
|
||||
- [`confirm-destructive.go`](examples/extensions/confirm-destructive.go) - Confirm destructive operations
|
||||
- [`context-inject.go`](examples/extensions/context-inject.go) - Inject context into conversations
|
||||
@@ -428,10 +450,13 @@ Focus on $1 specifically.
|
||||
|
||||
**Argument placeholders:**
|
||||
- `$1`, `$2`, etc. — Individual arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments (zero or more)
|
||||
- `$+` — All arguments (one or more required; error if none given)
|
||||
- `${@:2}` — Arguments from position 2 onwards
|
||||
- `${@:1:3}` — 3 arguments starting at position 1
|
||||
|
||||
Placeholders inside fenced code blocks (```) and inline code spans are ignored.
|
||||
|
||||
Disable templates with `--no-prompt-templates` or load a specific template with `--prompt-template <name>`.
|
||||
|
||||
## Session Management
|
||||
@@ -480,6 +505,15 @@ During an interactive session, use these slash commands:
|
||||
| `/fork` | Fork to new session from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
|
||||
### Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Description |
|
||||
|----------|-------------|
|
||||
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
|
||||
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
|
||||
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
|
||||
| `↑` / `↓` | Navigate prompt history |
|
||||
|
||||
## Go SDK
|
||||
|
||||
Embed Kit in your Go applications:
|
||||
@@ -522,9 +556,23 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: "You are a helpful bot",
|
||||
ConfigFile: "/path/to/config.yml",
|
||||
MaxSteps: 10,
|
||||
Streaming: true,
|
||||
Streaming: ptr(true), // *bool: nil = unset (default true), &false = off
|
||||
Quiet: true,
|
||||
|
||||
// Generation parameters (override env/config/per-model defaults)
|
||||
MaxTokens: 16384, // 0 = auto-resolve (env → config → per-model → 8192 floor)
|
||||
ThinkingLevel: "medium", // "off", "none", "minimal", "low", "medium", "high"
|
||||
Temperature: ptr(float32(0.2)), // pointer so 0.0 != unset; nil = provider default
|
||||
TopP: nil, // nil = leave provider/per-model default
|
||||
TopK: nil,
|
||||
FrequencyPenalty: nil,
|
||||
PresencePenalty: nil,
|
||||
|
||||
// Provider configuration (override env/config without reaching into viper)
|
||||
ProviderAPIKey: "sk-...", // "" = use config / provider env var
|
||||
ProviderURL: "https://proxy.internal/v1", // "" = provider default
|
||||
TLSSkipVerify: false, // only takes effect when true
|
||||
|
||||
// Session options
|
||||
SessionPath: "./session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
@@ -533,7 +581,9 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
// Tool options
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
DisableCoreTools: true, // Disable all built-in core tools; also controllable via
|
||||
// --no-core-tools flag, KIT_NO_CORE_TOOLS env var,
|
||||
// or no-core-tools: true in .kit.yml
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
@@ -545,6 +595,108 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
**Generation & provider fields** (added in v0.55+) let SDK consumers configure
|
||||
Kit entirely in-code without `viper.Set()` workarounds or shipping a `.kit.yml`.
|
||||
Precedence is `Options` > `KIT_*` env vars > `.kit.yml` > per-model defaults
|
||||
(`modelSettings` / `customModels`) > provider-level defaults. Sampling params
|
||||
are pointer types so explicit `0.0` is distinguishable from "leave alone"; a
|
||||
non-zero `MaxTokens` suppresses automatic right-sizing the same way `--max-tokens`
|
||||
does on the CLI.
|
||||
|
||||
### Functional options (`NewAgent`)
|
||||
|
||||
For simple programmatic setups, `kit.NewAgent` offers an ergonomic
|
||||
functional-options front door over `kit.New`. Streaming is **enabled by
|
||||
default**; pass `kit.WithStreaming(false)` to opt out.
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.WithMaxTokens(8192),
|
||||
kit.WithThinkingLevel("medium"),
|
||||
kit.Ephemeral(), // in-memory session, no persistence
|
||||
)
|
||||
```
|
||||
|
||||
Available options: `WithModel`, `WithSystemPrompt`, `WithStreaming`,
|
||||
`WithMaxTokens`, `WithThinkingLevel`, `WithTools`, `WithExtraTools`,
|
||||
`WithProviderAPIKey`, `WithProviderURL`, `WithConfigFile`, `WithDebug`, and
|
||||
`Ephemeral`. For advanced configuration not covered by the helpers (custom MCP
|
||||
config, in-process MCP servers, session backends, MCP task tuning) construct an
|
||||
`Options` value explicitly and call `kit.New`.
|
||||
|
||||
### Per-instance config isolation
|
||||
|
||||
Each `kit.New` / `kit.NewAgent` call owns an **isolated configuration store**,
|
||||
so constructing multiple Kit instances in the same process is safe: setting the
|
||||
model, thinking level, or generation parameters on one never affects another,
|
||||
and runtime mutators (`SetModel`, `SetThinkingLevel`) only touch the owning
|
||||
instance. This makes subagent spawning and multi-Kit embedding race-free with
|
||||
no external synchronization required.
|
||||
|
||||
### MCP OAuth (remote MCP servers)
|
||||
|
||||
When a remote MCP server returns 401, Kit runs the full OAuth flow (dynamic
|
||||
client registration → PKCE → token exchange → persistence) but delegates the
|
||||
user-facing step — showing the authorization URL and receiving the callback —
|
||||
to an `MCPAuthHandler` that you pass explicitly via `Options.MCPAuthHandler`.
|
||||
If nil, OAuth is disabled and the authorization-required error surfaces to the
|
||||
caller; the SDK never auto-opens a browser or binds a localhost port.
|
||||
|
||||
```go
|
||||
// CLI/TUI apps: opens the system browser + prints status to stderr.
|
||||
authHandler, _ := kit.NewCLIMCPAuthHandler()
|
||||
defer authHandler.Close()
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPAuthHandler: authHandler,
|
||||
})
|
||||
|
||||
// Custom UX: reuse the SDK's port + callback server, supply your own
|
||||
// presentation via OnAuthURL (TUI modal, QR code, web redirect, etc.).
|
||||
// h, _ := kit.NewDefaultMCPAuthHandler()
|
||||
// h.OnAuthURL = func(server, authURL string) { myUI.Show(server, authURL) }
|
||||
//
|
||||
// Full control (web apps, daemons): implement kit.MCPAuthHandler yourself —
|
||||
// no localhost binding, no side effects.
|
||||
```
|
||||
|
||||
Tokens are persisted to `$XDG_CONFIG_HOME/.kit/mcp_tokens.json` by default; swap
|
||||
in a custom `MCPTokenStoreFactory` for encrypted, DB-backed, or in-memory
|
||||
storage. See the [SDK options docs](/sdk/options#mcp-oauth-authorization) for
|
||||
the full matrix.
|
||||
|
||||
### MCP Tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`, so cooperating MCP servers can respond to `tools/call`
|
||||
with a `taskId` instead of blocking the connection. Kit then polls
|
||||
`tasks/get` / `tasks/result` until the task reaches a terminal state, and
|
||||
best-effort `tasks/cancel`s on context cancellation.
|
||||
|
||||
Defaults are safe — a server that doesn't advertise task capability runs
|
||||
synchronously, exactly as before. Opt in per server via `tasksMode` in
|
||||
`.kit.yml` (`auto` | `never` | `always`) or programmatically through the SDK:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskTimeout: 15 * time.Minute,
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s: %s", p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
```
|
||||
|
||||
See the [configuration docs](/configuration#mcp-tasks-long-running-tools) and
|
||||
[SDK options → MCP Tasks](/sdk/options#mcp-tasks) for the full surface.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
@@ -565,7 +717,28 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. Binary data (images, audio, etc.) in `ToolOutput.Data` is automatically forwarded to the LLM when `MediaType` is set. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
#### Return Helpers
|
||||
|
||||
| Helper | Description |
|
||||
| --- | --- |
|
||||
| `kit.TextResult(content)` | Successful text result |
|
||||
| `kit.ErrorResult(content)` | Error result (LLM sees it as a tool error) |
|
||||
| `kit.ImageResult(content, data, mediaType)` | Image result with binary data (e.g. `"image/png"`) |
|
||||
| `kit.MediaResult(content, data, mediaType)` | Non-image media result (e.g. `"audio/mpeg"`) |
|
||||
|
||||
#### ToolOutput Fields
|
||||
|
||||
```go
|
||||
kit.ToolOutput{
|
||||
Content: "result text", // text returned to the LLM
|
||||
IsError: false, // true = LLM sees this as an error
|
||||
Data: pngBytes, // optional binary data (images, audio)
|
||||
MediaType: "image/png", // MIME type for binary Data
|
||||
Metadata: map[string]any{}, // opaque metadata for hooks/UI (not sent to LLM)
|
||||
}
|
||||
```
|
||||
|
||||
### With Callbacks
|
||||
|
||||
@@ -582,7 +755,7 @@ unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
unsub3 := host.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
@@ -619,6 +792,45 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Runtime Skills & Context Files
|
||||
|
||||
For multi-tenant hosts (chatbots, per-user agents, web services), the SDK
|
||||
lets you swap skills and `AGENTS.md`-style context files **after** Kit
|
||||
construction. Every mutation recomposes the system prompt and applies it to
|
||||
the agent so the next turn picks up the new instructions — no restart needed.
|
||||
|
||||
```go
|
||||
// Programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Per-user AGENTS.md content pulled from a database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state on logout.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set atomically.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Skills dedupe by `Name`, context files dedupe by `Path` (which can be any
|
||||
opaque identifier — it doesn't have to be a real filesystem path). All
|
||||
mutators and readers (`GetSkills`, `GetContextFiles`) are safe to call
|
||||
concurrently from multiple goroutines. See the [SDK overview docs](/sdk/overview#runtime-skills-and-context-files)
|
||||
for the full reference.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Subagent Pattern
|
||||
@@ -760,6 +972,31 @@ This automatically defaults to `custom/custom` without needing to specify a mode
|
||||
- Reasoning and temperature support
|
||||
- Optional `CUSTOM_API_KEY` environment variable or `--provider-api-key` flag
|
||||
|
||||
### Auto-routed Providers
|
||||
|
||||
Any provider in the [models.dev](https://models.dev) database can be used as
|
||||
`provider/model` without a dedicated native integration. Kit auto-routes the
|
||||
request through the matching **wire protocol** based on the provider's npm package
|
||||
(or per-model override), using its `api` URL as the base:
|
||||
|
||||
| npm package | Wire protocol |
|
||||
|-------------|---------------|
|
||||
| `@ai-sdk/openai` | OpenAI (Responses API) |
|
||||
| `@ai-sdk/openai-compatible` | OpenAI (chat completions) |
|
||||
| `@ai-sdk/anthropic` | Anthropic |
|
||||
| `@ai-sdk/google` | Google Gemini |
|
||||
|
||||
Providers with an `api` URL but an unrecognized npm package fall back to the
|
||||
OpenAI-compatible wire. Because routing follows the wire protocol, aggregator/proxy
|
||||
providers work across all of their models — including Claude, GPT, *and* Gemini
|
||||
routes:
|
||||
|
||||
```bash
|
||||
kit --model opencode/claude-haiku-4-5 "Hello" # → Anthropic wire
|
||||
kit --model opencode/gpt-5 "Hello" # → OpenAI wire
|
||||
kit --model opencode/gemini-3.5-flash "Hello" # → Google wire
|
||||
```
|
||||
|
||||
### Model String Format
|
||||
|
||||
```bash
|
||||
|
||||
+64
-4
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"charm.land/huh/v2"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -54,9 +55,13 @@ Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
|
||||
Example:
|
||||
Flags:
|
||||
--set-default Set this provider's default model as the system default
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai`,
|
||||
kit auth login openai
|
||||
kit auth login openai --set-default`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -99,10 +104,43 @@ Example:
|
||||
RunE: runAuthStatus,
|
||||
}
|
||||
|
||||
var (
|
||||
loginSetDefault bool
|
||||
)
|
||||
|
||||
// defaultModels maps providers to their recommended default models.
|
||||
// These are used when --set-default flag is passed to auth login.
|
||||
var defaultModels = map[string]string{
|
||||
"anthropic": "anthropic/claude-sonnet-4-5-20250929",
|
||||
"openai": "openai/gpt-5.4",
|
||||
}
|
||||
|
||||
// setDefaultModelIfRequested sets the default model for the given provider
|
||||
// if the --set-default flag was provided.
|
||||
func setDefaultModelIfRequested(provider string) error {
|
||||
if !loginSetDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
model, ok := defaultModels[provider]
|
||||
if !ok {
|
||||
return fmt.Errorf("no default model configured for provider: %s", provider)
|
||||
}
|
||||
|
||||
if err := ui.SaveModelPreference(model); err != nil {
|
||||
return fmt.Errorf("failed to save model preference: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("\n✓ Set default model to: %s\n", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
authCmd.AddCommand(authLoginCmd)
|
||||
authCmd.AddCommand(authLogoutCmd)
|
||||
authCmd.AddCommand(authStatusCmd)
|
||||
|
||||
authLoginCmd.Flags().BoolVar(&loginSetDefault, "set-default", false, "Set this provider's default model as the system default after login")
|
||||
}
|
||||
|
||||
func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
@@ -288,6 +326,17 @@ func loginAnthropic() error {
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for Anthropic API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
// Set default model if requested
|
||||
if err := setDefaultModelIfRequested("anthropic"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remind users how to set this as default if they didn't use --set-default
|
||||
if !loginSetDefault {
|
||||
fmt.Println("\n💡 To set Anthropic as your default model, run:")
|
||||
fmt.Println(" kit auth login anthropic --set-default")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -454,6 +503,17 @@ func loginOpenAI() error {
|
||||
fmt.Println("\n🎉 Your OAuth credentials will now be used for OpenAI API calls.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
// Set default model if requested
|
||||
if err := setDefaultModelIfRequested("openai"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remind users how to set this as default if they didn't use --set-default
|
||||
if !loginSetDefault {
|
||||
fmt.Println("\n💡 To set OpenAI as your default model, run:")
|
||||
fmt.Println(" kit auth login openai --set-default")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -504,13 +564,13 @@ func startOpenAICallbackServer(expectedState string) (*callbackServer, error) {
|
||||
}
|
||||
|
||||
// Return success page
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = fmt.Fprintf(w, `<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Authentication Successful</title></head>
|
||||
<body style="font-family: sans-serif; text-align: center; padding: 50px;">
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<h1>✓ Authentication Successful</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>`)
|
||||
|
||||
@@ -0,0 +1,473 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// extensionContextDeps groups the runtime dependencies needed to wire up
|
||||
// an extensions.Context for the interactive TUI mode.
|
||||
type extensionContextDeps struct {
|
||||
ctx context.Context
|
||||
cwd string
|
||||
modelName string
|
||||
interactive bool
|
||||
kitInstance *kit.Kit
|
||||
appInstance *app.App
|
||||
usageTracker *ui.UsageTracker
|
||||
}
|
||||
|
||||
// buildInteractiveExtensionContext returns an extensions.Context with every
|
||||
// field except Print / PrintInfo / PrintError populated. Callers must set
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
+241
-833
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@ import (
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -26,7 +26,7 @@ func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -84,7 +84,7 @@ func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -130,11 +130,63 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_ConcurrentSubagents verifies no panics when multiple
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 5 subagents concurrently
|
||||
done := make(chan struct{}, 5)
|
||||
for i := range 5 {
|
||||
go func(idx int) {
|
||||
defer func() { done <- struct{}{} }()
|
||||
|
||||
callID := fmt.Sprintf("concurrent-%d", idx)
|
||||
task := fmt.Sprintf("concurrent task %d", idx)
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
})
|
||||
|
||||
// Emit many chunks rapidly
|
||||
for j := range 20 {
|
||||
_, _ = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("agent %d chunk %d", idx, j),
|
||||
})
|
||||
}
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
Response: "done",
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 5 {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Allow any final processing
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
//go:build ignore
|
||||
|
||||
// sudo-handler.go - Extension to handle sudo password prompts securely
|
||||
//
|
||||
// This extension intercepts bash commands containing "sudo" and:
|
||||
// 1. Checks if sudo credentials are already cached (via sudo -n)
|
||||
// 2. If not cached, prompts the user for their password (with masking)
|
||||
// 3. Temporarily sets SUDO_PASSWORD environment variable for execution
|
||||
// 4. The bash tool automatically uses sudo -S -p '' to pipe the password
|
||||
//
|
||||
// Usage: kit -e examples/extensions/sudo-handler.go
|
||||
//
|
||||
// Security notes:
|
||||
// - Password is only stored in memory for the duration of the session
|
||||
// - Password is never logged or displayed
|
||||
// - Each session requires re-authentication (sudo -k is used)
|
||||
// - The SUDO_PASSWORD env var is set only during tool execution
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
// cachedPassword stores the sudo password for the session
|
||||
cachedPassword string
|
||||
// hasCachedPassword tracks if we have a valid cached password
|
||||
hasCachedPassword bool
|
||||
// mu protects cached password access
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Init sets up the sudo handler extension
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName != "bash" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the command from tool input
|
||||
var input struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.Input), &input); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if command contains sudo
|
||||
if !containsSudo(input.Command) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we already have cached credentials
|
||||
mu.RLock()
|
||||
password := cachedPassword
|
||||
hasCached := hasCachedPassword
|
||||
mu.RUnlock()
|
||||
|
||||
if hasCached {
|
||||
// Use cached password
|
||||
os.Setenv("SUDO_PASSWORD", password)
|
||||
return nil
|
||||
}
|
||||
|
||||
// No cached password - prompt user
|
||||
result := ctx.PromptInput(ext.PromptInputConfig{
|
||||
Message: "🔐 Sudo password required for:\n " + truncateCommand(input.Command, 60),
|
||||
Placeholder: "Enter your password",
|
||||
})
|
||||
|
||||
if result.Cancelled {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "Sudo password prompt cancelled by user",
|
||||
}
|
||||
}
|
||||
|
||||
if result.Value == "" {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "No password provided",
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the password for this session
|
||||
mu.Lock()
|
||||
cachedPassword = result.Value
|
||||
hasCachedPassword = true
|
||||
mu.Unlock()
|
||||
|
||||
// Set environment variable for the bash tool to use
|
||||
os.Setenv("SUDO_PASSWORD", result.Value)
|
||||
|
||||
// Show confirmation (without revealing password)
|
||||
ctx.PrintInfo("Sudo password cached for this session")
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Clear cached password when session ends
|
||||
api.OnSessionShutdown(func(event ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
mu.Lock()
|
||||
cachedPassword = ""
|
||||
hasCachedPassword = false
|
||||
mu.Unlock()
|
||||
os.Unsetenv("SUDO_PASSWORD")
|
||||
})
|
||||
}
|
||||
|
||||
// containsSudo checks if the command contains sudo as a command (not in a string)
|
||||
func containsSudo(command string) bool {
|
||||
// Simple check for sudo as a word, not inside quotes or as part of another word
|
||||
lower := strings.ToLower(command)
|
||||
|
||||
// Check for sudo at start or after separators
|
||||
patterns := []string{
|
||||
"sudo ",
|
||||
"sudo\t",
|
||||
";sudo ",
|
||||
"&& sudo ",
|
||||
"|| sudo ",
|
||||
"| sudo ",
|
||||
"$(sudo ",
|
||||
"`sudo ",
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if command starts with sudo
|
||||
if strings.HasPrefix(lower, "sudo ") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// truncateCommand truncates a long command for display
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
@@ -42,4 +42,14 @@ defer host.Close()
|
||||
response, err := host.Prompt(ctx, "Hello!")
|
||||
```
|
||||
|
||||
Or use the functional-options constructor for quick setups (streaming defaults on):
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.Ephemeral(),
|
||||
)
|
||||
```
|
||||
|
||||
See the [SDK README](../../pkg/kit/README.md) for the full API reference.
|
||||
|
||||
@@ -62,7 +62,7 @@ func main() {
|
||||
}
|
||||
})
|
||||
// Subscribe to streaming chunks.
|
||||
host3.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
host3.OnMessageUpdate(func(e kit.MessageUpdateEvent) {
|
||||
fmt.Print(e.Chunk)
|
||||
})
|
||||
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.5
|
||||
charm.land/fantasy v0.17.2
|
||||
charm.land/bubbletea/v2 v2.0.6
|
||||
charm.land/fantasy v0.25.0
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/alecthomas/chroma/v2 v2.26.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/coder/acp-go-sdk v0.13.0
|
||||
github.com/fsnotify/fsnotify v1.10.1
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.48.0
|
||||
github.com/mark3labs/mcp-go v0.54.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/term v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -35,23 +35,23 @@ require (
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 // indirect
|
||||
github.com/aws/smithy-go v1.26.0 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -59,44 +59,46 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dlclark/regexp2 v1.12.0 // indirect
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.0 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.5 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.13 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/gjson v1.19.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
@@ -104,21 +106,21 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 // indirect
|
||||
go.opentelemetry.io/otel v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.44.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/crypto v0.52.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/api v0.282.0 // indirect
|
||||
google.golang.org/genai v1.58.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
|
||||
google.golang.org/grpc v1.81.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -129,13 +131,13 @@ require (
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/spf13/pflag v1.0.10
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0
|
||||
)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.5 h1:TQlLFqxo39AAHSVuOhJ5D3nH7O9Nk8JGinsfWQ4y1U4=
|
||||
charm.land/bubbletea/v2 v2.0.5/go.mod h1:dvbsYZD+MHkdIZl+Z67D212hEvB+GII2tfH8f9SnoDw=
|
||||
charm.land/fantasy v0.17.2 h1:ojTMufMxY/PVH7TzYUxht2SVkvD90iCTJfmPR6c8BR8=
|
||||
charm.land/fantasy v0.17.2/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
|
||||
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
|
||||
charm.land/fantasy v0.25.0 h1:oXOWY1ivmTSnhYGzAolscF8zKtavWZyBWv0LHRSwN5Q=
|
||||
charm.land/fantasy v0.25.0/go.mod h1:8QrWUzIcKwZQP+aAnC9vLu3iID6hu9/Jt+rPMiieBkc=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
@@ -16,8 +16,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||
@@ -28,42 +28,42 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1 h1:2X21EdxGZNv5GF9mG5u+uzc02GCFyGxbcBm3Grd9A78=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1/go.mod h1:lxhRRa9H4hPmRLOOdYga4zkQIQjq3dtrrdwQeCfu78Y=
|
||||
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
|
||||
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 h1:sRs7nG6/RiEBZ/K5UO2sNw0w40U02Nmz1VtARloTZXk=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 h1:qRhIJMbevHUvIE7X4TK8N8zye5+5AhapcslPrvB+qKE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19/go.mod h1:RbJ24nfoya63+Mf5VI+CGCGk9vEdv28xPeii+gojRYs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 h1:GcXQz2M/0ZvMo0v5DakUqbDBeBM1ZNaivkolEF4Esgw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18/go.mod h1:sHJ06tMGcD3ZpmMyJqV+VBsGilhSIZPIN+ZFy5Dg0C4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 h1:FQm5ApnyzkuJdXLGskPce83CK1CQKC4RUnIHKVe4BU4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24/go.mod h1:JsC7dqQc55MlZ5mvNsDMMge71u8pVcSzU3RNz2h/5yQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 h1:u6kJU2i0va1AgtJsH3RdWKWqHULlTh7zHwb35Womf74=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24/go.mod h1:7GY+xLcXOFUpCkNwDReft9qOAVg54A4/AnjHIU7sSAY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 h1:Xhbcf3KugX6vX7SDyUK205Oicyfg7EGuvoVNyP5L6DM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24/go.mod h1:rwDgb2HNOGZsnTHylOUedM7Vnl+bCfnXDqUNPsFWYfk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 h1:54CTMmlJ71Rk2dYvM9qZOob+39wjlVja2zDLxCu69Ew=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25/go.mod h1:BZaHqxsS9vN1fvV5EfEl0OBLOk5+AajWsMu6MjqnZB4=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 h1:CQW2FTrflfoslYWLf3fv7vG28Q219+v8YJS5QTQb2+Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24/go.mod h1:Xfx13T+u3nH6EEzgl9fBSO6nDRmze1FvnZNYkctQ2zw=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 h1:yQo3eZ5qFaL1sJWqs1nL6j3yPHA2/R7c6tQ4T+0IO10=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0/go.mod h1:3Zzou41Qt/ueXfIzHvTEjDNuR5IjCUBVF01SNhrt1e8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 h1:ApLTFdAZfDhZSiY5uskwECKHkSNNF83y2Ru2r7SezWA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18/go.mod h1:A9K9qx2l6nK89hp+a350FdGfRkrkH5HdiEjHbiy/Q/c=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 h1:4VD7TIZOGzehrgQ8vDE+1c6BQW4ErZPGY8ohZT5LXEE=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1/go.mod h1:er0SFJfdV89Rit5hIJu/EXtv+qC2XMnxoksLmcUFkqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 h1:XKnxlM4KZH1gktcsh3zSWc7GW4KivEv/OkifmHOhCUY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2/go.mod h1:KJYmkQaFB3SUW2j3aBkPsxNmAb4ZsSOvbvCpuxzHJA0=
|
||||
github.com/aws/smithy-go v1.26.0 h1:9ouqbi+NyKP7fV3Te7UElCwdAb6Y8uk7LGwPE5tVe/s=
|
||||
github.com/aws/smithy-go v1.26.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
@@ -86,8 +86,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e h1:O5hZFj55wZQWxMiRtQLa3uLKhZGZGS/j8M3OXinQlrw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654 h1:FpSYhY28ucg9ZRr+2wj67FAQ0Ey5yiK0072PmRDJNek=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654/go.mod h1:hFpumms29Smx3LStRfku8vcCTBe1Kq8aCXtHUJa3mjY=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -98,14 +98,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 h1:6F/6bu5nBLjodsvaU5xAszTaxtHrDU5UiJarpMPZj48=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d h1:sMilwx1YIYTrQva6jsB522AoRYAerNaDIKP4ZPtUq0A=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 h1:RiZFY92Ug9iz1CenzxSSQla2Z3WflsR7bIuXq40JlpU=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d h1:RxcAR+vJCoD8QqT1cqLtkQKw+1cqvjqnu5IpPqYzPco=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -124,15 +124,17 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.6.3 h1:LsXQytehdjKIYJnoVWON/nf7mqbiarnyuyE3rrjBsXQ=
|
||||
github.com/coder/acp-go-sdk v0.6.3/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/coder/acp-go-sdk v0.13.0 h1:IAKBDIbe/iBfKAGikeIndzb8fowt4ioD+gCtSU4HwMA=
|
||||
github.com/coder/acp-go-sdk v0.13.0/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
|
||||
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
|
||||
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 h1:LCUGyd9Wf+r+VVOl8Ny38JTpWJcAsdVnCIuhhtthmKw=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
@@ -146,10 +148,10 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 h1:NZBJxCpbHS1gzS6xAmyxbJznosZIIPk9IB42v62UvKA=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
|
||||
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -167,16 +169,16 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 h1:F/VPrx0YPBdksZJQdCAp0WUsqnNmZpUZszzfYt0M5Dw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16/go.mod h1:9Yb0eAkH/Xqhvv3zbeKf/+wMJqCeocWc6KIhDvEAuYE=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
@@ -187,14 +189,14 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.0 h1:i7L3U2yurg+xhokITtJ0k+mjHnXqkoyz8ju5Wb7W8Oc=
|
||||
github.com/kaptinlin/go-i18n v0.4.0/go.mod h1:njA6x0+4MWGcLWT0KLrwekhRPmze1Hnstf2+VJFzwpM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 h1:a0ufTd5liiUubIGeGxpSTnNS8ZSrN4DV01/wGFmfzMs=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20/go.mod h1:FqdEPfQLkqVBX7OBRMPgYwUPvKYJohFD9Ok1BMzCfIo=
|
||||
github.com/kaptinlin/go-i18n v0.4.5 h1:9tIlo5A0RXth+yZJO2MG7Bhpu/X9PlzQnGz/qyYWNoY=
|
||||
github.com/kaptinlin/go-i18n v0.4.5/go.mod h1:mU/7BH4molY5lGZYBwBRKAaiJ70dWRHuqmQ0/pFLGno=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 h1:iJ197e8n+WwqaqBsa53FqG3rPJCg5oijyFXEXNWWC3E=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
|
||||
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 h1:D6jiXFsKW4/JG2CMddv/F6Rev9KVbCRKEzzV5QOAcpc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -203,10 +205,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.48.0 h1:o+MXuGW/HCeR2ny5LcAcZQn2bo6I2xaZMEHnpRG+dtw=
|
||||
github.com/mark3labs/mcp-go v0.48.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mark3labs/mcp-go v0.54.1 h1:Ap/ptEB9FtWzFKM8NDsTA7QDxerQOC06eZigrTldVj0=
|
||||
github.com/mark3labs/mcp-go v0.54.1/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
@@ -223,8 +225,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -238,6 +240,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -254,8 +258,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
@@ -274,54 +278,54 @@ github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 h1:2yEATaop1/a1I4psnSLgWVPLWwCzkqWakgJy7xTDVy0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0/go.mod h1:D7J12YRapIekYyPWgGPlA/23pRmpSEZC5xJC/TTLI9U=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 h1:8tvICD4vSTOOsNrsI4Ljf6C+6UKvpTEH5XY3JMoyPoo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0/go.mod h1:z9+yiacE0IHRqM4qFfkbt/JYlmYXgss8GY/jXoNuPJI=
|
||||
go.opentelemetry.io/otel v1.44.0 h1:JjwHmHpA4iZ3wBxluu2fbbE7j4kqlE8jXyAyPXH7HqU=
|
||||
go.opentelemetry.io/otel v1.44.0/go.mod h1:BMgjTHL9WPRlRjL2oZCBTL4whCGtXch2H4BhOPIAyYc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSYCdDdmc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0 h1:nHYwb9lK+fJPU/dnT6s7W7Z8itMWyqrnVfbheVYrZ58=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0/go.mod h1:Osuydd3Se74nqjAKxid74N5eC+jfEqfTegHRnq58oK0=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0 h1:3LlKgI+VjbVsjNRFZJZAJ30WjXC5VkNRks6si09iEfI=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0/go.mod h1:5B5pMARnXxKhltooO4xUuCBorl65a4EpnTalObqOigA=
|
||||
go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk=
|
||||
go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 h1:cHpkPjp4TILjdZxz/O4ykwCpeS+dDqNuDGse4zgQDCk=
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 h1:RmoJA1ujG+/lRGNfUnOMfhCy5EipVMyvUE+KNbPbTlw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/api v0.282.0 h1:WmJiSVqUnKqJCpJOx7YADbXaC+9DDsnGSfllFSj7R2I=
|
||||
google.golang.org/api v0.282.0/go.mod h1:6Wssta4c5n9qHq5CBhmlai5h/PUa1djdDAIhYEHyvcM=
|
||||
google.golang.org/genai v1.58.0 h1:MNA3ZkRyr7MnRwZ9RNZ60p4+UMKV3yYRw6pyHq4pp0U=
|
||||
google.golang.org/genai v1.58.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348 h1:JjVGDZYWkJWZcxveJGzfkXC5myDVWAd4dZdgbzrDUv8=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348/go.mod h1:95PqD4xM+AdOcBGsmgfaofXsiA37uXDtDufVbntT3TU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348 h1:U8orV30l6KpDsi9dxU0CoJZGbjS8EEpw+6ba+XwGPQA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348/go.mod h1:Yzdzr5OOZFgSsEV2D/Xi9NL3bszpXFAg0hFJiRohcD8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -177,22 +177,75 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// SetSessionModel changes the active model for a session.
|
||||
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
|
||||
// ListSessions returns an empty session list. Kit doesn't persist sessions
|
||||
// across restarts in ACP mode, so this is effectively a no-op.
|
||||
func (a *Agent) ListSessions(_ context.Context, _ acp.ListSessionsRequest) (acp.ListSessionsResponse, error) {
|
||||
return acp.ListSessionsResponse{
|
||||
Sessions: []acp.SessionInfo{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CloseSession cancels any ongoing work for the session and frees its resources.
|
||||
func (a *Agent) CloseSession(_ context.Context, params acp.CloseSessionRequest) (acp.CloseSessionResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
modelID := string(params.ModelId)
|
||||
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
|
||||
log.Debug("acp: close session", "session", sessionID)
|
||||
sess.cancelPrompt()
|
||||
a.registry.remove(sessionID)
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
if err := sess.kit.SetModel(ctx, modelID); err != nil {
|
||||
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
|
||||
// ResumeSession is not supported — Kit doesn't persist sessions across
|
||||
// restarts in ACP mode. Clients should use NewSession instead.
|
||||
func (a *Agent) ResumeSession(_ context.Context, _ acp.ResumeSessionRequest) (acp.ResumeSessionResponse, error) {
|
||||
return acp.ResumeSessionResponse{}, fmt.Errorf("resume session not supported")
|
||||
}
|
||||
|
||||
// SetSessionConfigOption handles session configuration changes. Currently
|
||||
// supports the "model" config option to change the active model for a session.
|
||||
func (a *Agent) SetSessionConfigOption(ctx context.Context, params acp.SetSessionConfigOptionRequest) (acp.SetSessionConfigOptionResponse, error) {
|
||||
// Extract session ID and config ID from whichever variant is present.
|
||||
var sessionID string
|
||||
var configID string
|
||||
var value string
|
||||
|
||||
switch {
|
||||
case params.ValueId != nil:
|
||||
sessionID = string(params.ValueId.SessionId)
|
||||
configID = string(params.ValueId.ConfigId)
|
||||
value = string(params.ValueId.Value)
|
||||
case params.Boolean != nil:
|
||||
sessionID = string(params.Boolean.SessionId)
|
||||
configID = string(params.Boolean.ConfigId)
|
||||
// Boolean config options are not used for model selection.
|
||||
log.Debug("acp: set_session_config_option (boolean)", "session", sessionID, "config", configID, "value", params.Boolean.Value)
|
||||
return acp.SetSessionConfigOptionResponse{}, nil
|
||||
default:
|
||||
return acp.SetSessionConfigOptionResponse{}, acp.NewInvalidParams("unsupported config option variant")
|
||||
}
|
||||
|
||||
return acp.SetSessionModelResponse{}, nil
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionConfigOptionResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
|
||||
log.Debug("acp: set_session_config_option", "session", sessionID, "config", configID, "value", value)
|
||||
|
||||
// Handle known config options.
|
||||
switch configID {
|
||||
case "model":
|
||||
if err := sess.kit.SetModel(ctx, value); err != nil {
|
||||
return acp.SetSessionConfigOptionResponse{}, fmt.Errorf("set model: %w", err)
|
||||
}
|
||||
default:
|
||||
log.Debug("acp: unknown config option", "config", configID)
|
||||
}
|
||||
|
||||
return acp.SetSessionConfigOptionResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -37,10 +39,21 @@ func newSessionRegistry() *sessionRegistry {
|
||||
// given working directory. The Kit-generated session ID is used as the ACP
|
||||
// session ID so the mapping is 1:1.
|
||||
func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, error) {
|
||||
// Each ACP session gets its own isolated config store (CLI is left nil) so
|
||||
// per-session SetModel / SetThinkingLevel calls cannot race or bleed across
|
||||
// the sessionRegistry. We seed the relevant root-command flag values from
|
||||
// the process-global store (which cobra populated from flags) so launching
|
||||
// `kit acp -m <model> [--thinking-level ...] [--provider-url ...]` is still
|
||||
// honored; .kit.yml and KIT_* env vars are loaded per session by kit.New.
|
||||
streamOn := true
|
||||
kitInstance, err := kit.New(ctx, &kit.Options{
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: &streamOn,
|
||||
Model: viper.GetString("model"),
|
||||
ThinkingLevel: viper.GetString("thinking-level"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
})
|
||||
if err != nil {
|
||||
// Provide actionable guidance for provider auth errors, which are
|
||||
@@ -152,38 +165,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
@@ -232,6 +214,20 @@ func (r *sessionRegistry) closeAll() {
|
||||
}
|
||||
}
|
||||
|
||||
// remove closes and removes a single session by ID.
|
||||
func (r *sessionRegistry) remove(sessionID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
sess, ok := r.sessions[sessionID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if sess.kit != nil {
|
||||
_ = sess.kit.Close()
|
||||
}
|
||||
delete(r.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cancelPrompt cancels the current prompt for a session, if any.
|
||||
func (s *acpSession) cancelPrompt() {
|
||||
s.cancelMu.Lock()
|
||||
@@ -255,40 +251,3 @@ func (s *acpSession) clearCancel() {
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
+461
-73
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
@@ -58,6 +60,11 @@ type AgentConfig struct {
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour for any
|
||||
// server that didn't advertise task support during initialize.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -87,6 +94,19 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Called when the last reasoning token has been processed, before text streaming starts.
|
||||
type ReasoningCompleteHandler func()
|
||||
|
||||
// ToolCallStartHandler is a function type for handling the moment when the LLM
|
||||
// begins generating tool call arguments. The tool name is known but the full
|
||||
// argument JSON is still streaming.
|
||||
type ToolCallStartHandler func(toolCallID, toolName string)
|
||||
|
||||
// ToolCallDeltaHandler is a function type for handling streamed fragments of
|
||||
// tool call arguments as they arrive from the LLM.
|
||||
type ToolCallDeltaHandler func(toolCallID, delta string)
|
||||
|
||||
// ToolCallEndHandler is a function type for handling the end of tool argument
|
||||
// streaming, before the tool call is parsed and execution begins.
|
||||
type ToolCallEndHandler func(toolCallID string)
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
@@ -94,6 +114,12 @@ type ReasoningCompleteHandler func()
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// PasswordPromptHandler is a function type for password prompts.
|
||||
// Used by the bash tool when sudo requires a password. The handler receives
|
||||
// a prompt message and returns the password and whether it was cancelled.
|
||||
// Note: This is an alias for core.PasswordPromptCallback.
|
||||
type PasswordPromptHandler = core.PasswordPromptCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
@@ -107,6 +133,76 @@ type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// StepStartHandler is called when a new LLM step begins within a turn.
|
||||
type StepStartHandler func(stepNumber int)
|
||||
|
||||
// StepFinishHandler is called when a step completes with full context.
|
||||
type StepFinishHandler func(stepNumber int, hasToolCalls bool, finishReason string, usage fantasy.Usage)
|
||||
|
||||
// TextStartHandler is called when the LLM begins generating text content.
|
||||
type TextStartHandler func(id string)
|
||||
|
||||
// TextEndHandler is called when the LLM finishes generating text content.
|
||||
type TextEndHandler func(id string)
|
||||
|
||||
// ReasoningStartHandler is called when the LLM begins reasoning/thinking.
|
||||
type ReasoningStartHandler func(id string)
|
||||
|
||||
// WarningsHandler is called when the LLM provider returns warnings.
|
||||
type WarningsHandler func(warnings []string)
|
||||
|
||||
// SourceHandler is called when the LLM references a source.
|
||||
type SourceHandler func(sourceType, id, url, title string)
|
||||
|
||||
// StreamFinishHandler is called when a per-step LLM stream completes.
|
||||
type StreamFinishHandler func(usage fantasy.Usage, finishReason string)
|
||||
|
||||
// ErrorHandler is called when an agent-level error occurs.
|
||||
type ErrorHandler func(err error)
|
||||
|
||||
// RetryHandler is called when the LLM request is retried.
|
||||
type RetryHandler func(attempt int, err error)
|
||||
|
||||
// PrepareStepHandler is called between steps to allow message modification.
|
||||
// It receives the step number and current messages, and returns replacement
|
||||
// messages (or nil to keep unchanged).
|
||||
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
|
||||
|
||||
// GenerateCallbacks consolidates all callback functions for
|
||||
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
|
||||
// 16+ positional callback parameters, making it easier to add new callbacks
|
||||
// without breaking existing callers (new fields default to nil).
|
||||
type GenerateCallbacks struct {
|
||||
OnToolCall ToolCallHandler
|
||||
OnToolExecution ToolExecutionHandler
|
||||
OnToolResult ToolResultHandler
|
||||
OnResponse ResponseHandler
|
||||
OnToolCallContent ToolCallContentHandler
|
||||
OnStreamingResponse StreamingResponseHandler
|
||||
OnReasoningDelta ReasoningDeltaHandler
|
||||
OnReasoningComplete ReasoningCompleteHandler
|
||||
OnToolOutput ToolOutputHandler
|
||||
OnStepMessages StepMessagesHandler
|
||||
OnStepUsage StepUsageHandler
|
||||
OnPasswordPrompt PasswordPromptHandler
|
||||
OnToolCallStart ToolCallStartHandler
|
||||
OnToolCallDelta ToolCallDeltaHandler
|
||||
OnToolCallEnd ToolCallEndHandler
|
||||
|
||||
// New callbacks for previously unwired Fantasy lifecycle events.
|
||||
OnStepStart StepStartHandler
|
||||
OnStepFinish StepFinishHandler
|
||||
OnTextStart TextStartHandler
|
||||
OnTextEnd TextEndHandler
|
||||
OnReasoningStart ReasoningStartHandler
|
||||
OnWarnings WarningsHandler
|
||||
OnSource SourceHandler
|
||||
OnStreamFinish StreamFinishHandler
|
||||
OnError ErrorHandler
|
||||
OnRetry RetryHandler
|
||||
OnPrepareStep PrepareStepHandler
|
||||
}
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
@@ -141,11 +237,21 @@ type Agent struct {
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpTaskConfig is stored from AgentConfig so AddMCPServer() can
|
||||
// propagate it to a lazily-created MCPToolManager.
|
||||
mcpTaskConfig tools.MCPTaskConfig
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
|
||||
// promptMu serializes runtime updates to systemPrompt and the
|
||||
// accompanying fantasy agent rebuild so concurrent SetSystemPrompt
|
||||
// callers (e.g. Kit.applyComposedSystemPrompt invoked from multiple
|
||||
// goroutines) don't race on a.systemPrompt / a.fantasyAgent.
|
||||
promptMu sync.Mutex
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -239,13 +345,13 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
mcpTaskConfig: agentConfig.MCPTaskConfig,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
// The mcpReady channel is closed when loading completes (success or failure).
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager := tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
@@ -259,6 +365,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
// Apply task-augmented tool execution config (zero value = no-op).
|
||||
toolManager.SetTaskConfig(agentConfig.MCPTaskConfig)
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
@@ -325,7 +433,7 @@ func (a *Agent) rebuildFantasyAgent() {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
@@ -405,13 +513,20 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil)
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
// The agent handles the tool call loop internally.
|
||||
//
|
||||
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
|
||||
// struct and is easier to extend with new callbacks.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
@@ -421,6 +536,35 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
onPasswordPrompt PasswordPromptHandler,
|
||||
onToolCallStart ToolCallStartHandler,
|
||||
onToolCallDelta ToolCallDeltaHandler,
|
||||
onToolCallEnd ToolCallEndHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
OnStreamingResponse: onStreamingResponse,
|
||||
OnReasoningDelta: onReasoningDelta,
|
||||
OnReasoningComplete: onReasoningComplete,
|
||||
OnToolOutput: onToolOutput,
|
||||
OnStepMessages: onStepMessages,
|
||||
OnStepUsage: onStepUsage,
|
||||
OnPasswordPrompt: onPasswordPrompt,
|
||||
OnToolCallStart: onToolCallStart,
|
||||
OnToolCallDelta: onToolCallDelta,
|
||||
OnToolCallEnd: onToolCallEnd,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Message,
|
||||
cb GenerateCallbacks,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Wait for background MCP tool loading to complete and rebuild the
|
||||
@@ -429,8 +573,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
a.ensureMCPTools()
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
if cb.OnToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, cb.OnToolOutput)
|
||||
}
|
||||
|
||||
// Inject password prompt handler into context for use by bash tool.
|
||||
if cb.OnPasswordPrompt != nil {
|
||||
ctx = core.ContextWithPasswordPrompt(ctx, cb.OnPasswordPrompt)
|
||||
}
|
||||
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
@@ -443,15 +592,25 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
// Track tool call args per-ToolCallID so parallel tool calls in a single
|
||||
// step don't clobber each other. Without this, OnToolResult callbacks would
|
||||
// all see the args of the last OnToolCall in the step. The mutex guards
|
||||
// against the possibility that the underlying streaming layer dispatches
|
||||
// callbacks from multiple goroutines.
|
||||
toolCallArgs := make(map[string]string)
|
||||
var toolCallArgsMu sync.Mutex
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
onToolCallContent != nil || onStreamingResponse != nil || onReasoningDelta != nil
|
||||
hasCallbacks := cb.OnToolCall != nil || cb.OnToolExecution != nil || cb.OnToolResult != nil ||
|
||||
cb.OnToolCallContent != nil || cb.OnStreamingResponse != nil || cb.OnReasoningDelta != nil ||
|
||||
cb.OnToolCallStart != nil || cb.OnToolCallDelta != nil || cb.OnToolCallEnd != nil ||
|
||||
cb.OnStepStart != nil || cb.OnStepFinish != nil || cb.OnTextStart != nil ||
|
||||
cb.OnTextEnd != nil || cb.OnReasoningStart != nil || cb.OnWarnings != nil ||
|
||||
cb.OnSource != nil || cb.OnStreamFinish != nil || cb.OnError != nil ||
|
||||
cb.OnRetry != nil || cb.OnPrepareStep != nil
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
@@ -460,9 +619,11 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// input) were persisted incrementally via cb.OnStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
// stepCounter tracks the current step number for StepStart/StepFinish events.
|
||||
var stepCounter int
|
||||
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
@@ -470,13 +631,73 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
Files: files,
|
||||
Messages: history,
|
||||
|
||||
// Tool input streaming callbacks — fire during tool argument generation
|
||||
OnToolInputStart: func(id, toolName string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnToolCallStart != nil {
|
||||
cb.OnToolCallStart(id, toolName)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnToolInputDelta: func(id, delta string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnToolCallDelta != nil {
|
||||
cb.OnToolCallDelta(id, delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnToolInputEnd: func(id string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnToolCallEnd != nil {
|
||||
cb.OnToolCallEnd(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text start/end callbacks
|
||||
OnTextStart: func(id string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnTextStart != nil {
|
||||
cb.OnTextStart(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
OnTextEnd: func(id string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnTextEnd != nil {
|
||||
cb.OnTextEnd(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning start callback
|
||||
OnReasoningStart: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnReasoningStart != nil {
|
||||
cb.OnReasoningStart(id)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning/thinking streaming callback
|
||||
OnReasoningDelta: func(id, delta string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningDelta != nil {
|
||||
onReasoningDelta(delta)
|
||||
if cb.OnReasoningDelta != nil {
|
||||
cb.OnReasoningDelta(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -486,8 +707,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
if cb.OnReasoningComplete != nil {
|
||||
cb.OnReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -497,8 +718,64 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onStreamingResponse != nil {
|
||||
onStreamingResponse(text)
|
||||
if cb.OnStreamingResponse != nil {
|
||||
cb.OnStreamingResponse(text)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Warnings callback
|
||||
OnWarnings: func(warnings []fantasy.CallWarning) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnWarnings != nil {
|
||||
strs := make([]string, len(warnings))
|
||||
for i, w := range warnings {
|
||||
strs[i] = w.Message
|
||||
}
|
||||
cb.OnWarnings(strs)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Source callback
|
||||
OnSource: func(source fantasy.SourceContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnSource != nil {
|
||||
cb.OnSource(string(source.SourceType), source.ID, source.URL, source.Title)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Stream finish callback (per-step stream completion)
|
||||
OnStreamFinish: func(usage fantasy.Usage, finishReason fantasy.FinishReason, _ fantasy.ProviderMetadata) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if cb.OnStreamFinish != nil {
|
||||
cb.OnStreamFinish(usage, string(finishReason))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Error callback
|
||||
OnError: func(err error) {
|
||||
if cb.OnError != nil {
|
||||
cb.OnError(err)
|
||||
}
|
||||
},
|
||||
|
||||
// Step start callback
|
||||
OnStepStart: func(stepNumber int) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
stepCounter = stepNumber
|
||||
if cb.OnStepStart != nil {
|
||||
cb.OnStepStart(stepNumber)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -508,16 +785,18 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolArgs = tc.Input
|
||||
toolCallArgsMu.Lock()
|
||||
toolCallArgs[tc.ToolCallID] = tc.Input
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify about the tool call
|
||||
if onToolCall != nil {
|
||||
onToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
|
||||
if cb.OnToolCall != nil {
|
||||
cb.OnToolCall(tc.ToolCallID, tc.ToolName, tc.Input)
|
||||
}
|
||||
|
||||
// Notify tool execution starting
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tc.ToolCallID, tc.ToolName, tc.Input, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -528,15 +807,22 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Look up the args recorded for this specific tool call. Delete
|
||||
// the entry so the map doesn't accumulate across steps.
|
||||
toolCallArgsMu.Lock()
|
||||
args := toolCallArgs[tr.ToolCallID]
|
||||
delete(toolCallArgs, tr.ToolCallID)
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify tool execution finished
|
||||
if onToolExecution != nil {
|
||||
onToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, args, false)
|
||||
}
|
||||
|
||||
if onToolResult != nil {
|
||||
if cb.OnToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
onToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, args, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -550,8 +836,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
if cb.OnStepMessages != nil && len(step.Messages) > 0 {
|
||||
cb.OnStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
@@ -561,65 +847,88 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// Check if step has text content alongside tool calls
|
||||
text := step.Content.Text()
|
||||
toolCalls := step.Content.ToolCalls()
|
||||
if text != "" && len(toolCalls) > 0 && onToolCallContent != nil {
|
||||
onToolCallContent(text)
|
||||
if text != "" && len(toolCalls) > 0 && cb.OnToolCallContent != nil {
|
||||
cb.OnToolCallContent(text)
|
||||
}
|
||||
// Emit step usage for real-time cost tracking
|
||||
if onStepUsage != nil {
|
||||
onStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
if cb.OnStepUsage != nil {
|
||||
cb.OnStepUsage(step.Usage.InputTokens, step.Usage.OutputTokens,
|
||||
step.Usage.CacheReadTokens, step.Usage.CacheCreationTokens)
|
||||
}
|
||||
// Emit unified step finish event
|
||||
if cb.OnStepFinish != nil {
|
||||
cb.OnStepFinish(stepCounter, len(toolCalls) > 0, string(step.FinishReason), step.Usage)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// If a steer channel is attached to the context, wire up a
|
||||
// PrepareStep function that drains the channel between steps
|
||||
// and injects pending steer messages as user messages before
|
||||
// the next LLM call. This enables graceful mid-turn steering
|
||||
// without cancelling in-progress tool execution.
|
||||
if steerCh := steerChFromContext(ctx); steerCh != nil {
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
// Always wire up PrepareStep to handle both steering and the
|
||||
// OnPrepareStep hook. Steering drains its channel first, then
|
||||
// OnPrepareStep hooks run against the (possibly already steered)
|
||||
// messages.
|
||||
steerCh := steerChFromContext(ctx)
|
||||
onConsumed := steerConsumedFromContext(ctx)
|
||||
hasSteering := steerCh != nil
|
||||
hasPrepareStepHook := cb.OnPrepareStep != nil
|
||||
|
||||
if hasSteering || hasPrepareStepHook {
|
||||
streamCall.PrepareStep = func(
|
||||
stepCtx context.Context,
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
done:
|
||||
result := fantasy.PrepareStepResult{
|
||||
Model: opts.Model,
|
||||
Messages: opts.Messages,
|
||||
}
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
|
||||
// Phase 1: Drain steering channel (if present).
|
||||
if hasSteering {
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
steered = append(steered, msg)
|
||||
default:
|
||||
goto done
|
||||
}
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
done:
|
||||
if len(steered) > 0 {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Run OnPrepareStep hook (if registered).
|
||||
if hasPrepareStepHook {
|
||||
if replacement := cb.OnPrepareStep(opts.StepNumber, result.Messages); replacement != nil {
|
||||
result.Messages = replacement
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Wire OnRetry callback if provided.
|
||||
if cb.OnRetry != nil {
|
||||
streamCall.OnRetry = func(err *fantasy.ProviderError, _ time.Duration) {
|
||||
// Use the retry number from the error if available; Fantasy
|
||||
// doesn't pass a counter directly, so we approximate with a
|
||||
// counter incremented on each call.
|
||||
cb.OnRetry(0, err)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := a.fantasyAgent.Stream(ctx, streamCall)
|
||||
if err != nil {
|
||||
// On cancellation (or any error), return a partial result
|
||||
@@ -645,8 +954,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// empty (e.g. reasoning-only responses) so the UI properly resets
|
||||
// the stream component and avoids duplicate content on the next
|
||||
// flush.
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
if cb.OnResponse != nil {
|
||||
cb.OnResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
r := convertAgentResult(result, messages)
|
||||
@@ -666,8 +975,8 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
// For non-streaming, fire the response callback so callers can reset
|
||||
// streaming state (see streaming path comment above).
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
if cb.OnResponse != nil {
|
||||
cb.OnResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
@@ -808,7 +1117,7 @@ func (a *Agent) GetTools() []fantasy.AgentTool {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
@@ -852,13 +1161,13 @@ func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPSer
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
a.toolManager.SetModel(a.model)
|
||||
if a.authHandler != nil {
|
||||
a.toolManager.SetAuthHandler(a.authHandler)
|
||||
}
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetTaskConfig(a.mcpTaskConfig)
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
@@ -914,6 +1223,56 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// GetMCPPrompts returns all prompts discovered from connected MCP servers.
|
||||
// Returns nil if no MCP servers are configured or no prompts were found.
|
||||
func (a *Agent) GetMCPPrompts() []tools.MCPPrompt {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetPrompts()
|
||||
}
|
||||
|
||||
// GetMCPPrompt retrieves and expands a specific prompt from an MCP server.
|
||||
// This is a lazy call — the server is contacted each time.
|
||||
func (a *Agent) GetMCPPrompt(ctx context.Context, serverName, promptName string, args map[string]string) (*tools.MCPPromptResult, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.GetPrompt(ctx, serverName, promptName, args)
|
||||
}
|
||||
|
||||
// GetMCPResources returns all resources discovered from connected MCP servers.
|
||||
func (a *Agent) GetMCPResources() []tools.MCPResource {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetResources()
|
||||
}
|
||||
|
||||
// ReadMCPResource reads a specific resource from an MCP server by URI.
|
||||
func (a *Agent) ReadMCPResource(ctx context.Context, serverName, uri string) (*tools.MCPResourceContent, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.ReadResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SubscribeMCPResource subscribes to change notifications for a resource.
|
||||
func (a *Agent) SubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.SubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// UnsubscribeMCPResource cancels change notifications for a resource.
|
||||
func (a *Agent) UnsubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.UnsubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
@@ -933,11 +1292,6 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
_ = a.providerCloser.Close()
|
||||
}
|
||||
|
||||
// Update model info on MCP tool manager.
|
||||
if a.toolManager != nil {
|
||||
a.toolManager.SetModel(providerResult.Model)
|
||||
}
|
||||
|
||||
// Swap fields.
|
||||
a.model = providerResult.Model
|
||||
a.providerCloser = providerResult.Closer
|
||||
@@ -970,6 +1324,40 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// SetSystemPrompt updates the agent's system prompt and rebuilds the underlying
|
||||
// fantasy agent so subsequent turns use the new prompt. Safe to call while the
|
||||
// agent is idle; if invoked during an in-flight turn the new prompt takes
|
||||
// effect on the next LLM call.
|
||||
func (a *Agent) SetSystemPrompt(prompt string) {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
a.systemPrompt = prompt
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the agent's current system prompt.
|
||||
func (a *Agent) GetSystemPrompt() string {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
return a.systemPrompt
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the effective max output tokens the agent currently
|
||||
// sends to the LLM provider, after per-model defaults, right-sizing, and any
|
||||
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
|
||||
// attached (e.g. early init) or when the provider suppresses the parameter
|
||||
// (e.g. Codex OAuth), which allows callers to differentiate "default" from
|
||||
// "explicitly capped".
|
||||
func (a *Agent) GetMaxTokens() int {
|
||||
if a.skipMaxOutputTokens {
|
||||
return 0
|
||||
}
|
||||
if a.modelConfig == nil {
|
||||
return 0
|
||||
}
|
||||
return a.modelConfig.MaxTokens
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources.
|
||||
// If MCP tools are still loading in the background, Close waits for them
|
||||
// to finish before closing connections to avoid resource leaks.
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// fakeParallelAgent simulates a provider that emits two parallel tool_use
|
||||
// blocks in a single step. It invokes the streaming callbacks in the order:
|
||||
//
|
||||
// OnToolCall(A) -> OnToolCall(B) -> OnToolResult(A) -> OnToolResult(B)
|
||||
//
|
||||
// Before the fix in #33 the agent-layer wrapper recorded a single
|
||||
// `currentToolArgs` variable that was clobbered by the second OnToolCall, so
|
||||
// both OnToolResult callbacks received B's args instead of their own.
|
||||
type fakeParallelAgent struct {
|
||||
calls []fantasy.ToolCallContent
|
||||
results []fantasy.ToolResultContent
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Stream(_ context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
|
||||
for _, tc := range f.calls {
|
||||
if opts.OnToolCall != nil {
|
||||
if err := opts.OnToolCall(tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tr := range f.results {
|
||||
if opts.OnToolResult != nil {
|
||||
if err := opts.OnToolResult(tr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
// TestGenerateWithCallbacks_ParallelToolArgs is the regression test for #33.
|
||||
// It drives the streaming-callback wiring inside GenerateWithCallbacks with a
|
||||
// fake fantasy.Agent that emits two parallel tool calls before either result.
|
||||
// Each OnToolResult must receive the args of its own tool call (matched by
|
||||
// ToolCallID), not the args of the last OnToolCall in the step.
|
||||
func TestGenerateWithCallbacks_ParallelToolArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
argsA := `{"name":"scheduled_jobs"}`
|
||||
argsB := `{"name":"gmail_trigger"}`
|
||||
|
||||
fake := &fakeParallelAgent{
|
||||
calls: []fantasy.ToolCallContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Input: argsA},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Input: argsB},
|
||||
},
|
||||
results: []fantasy.ToolResultContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-A"}},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-B"}},
|
||||
},
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fake,
|
||||
streamingEnabled: false, // exercise the "hasCallbacks" branch
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
resultArgs := map[string]string{}
|
||||
executionArgs := map[string]string{} // captured when running == false
|
||||
|
||||
cb := GenerateCallbacks{
|
||||
OnToolExecution: func(id, _, args string, running bool) {
|
||||
if running {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
executionArgs[id] = args
|
||||
},
|
||||
OnToolResult: func(id, _, args, _, _ string, _ bool) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
resultArgs[id] = args
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := a.GenerateWithCallbacks(context.Background(), nil, cb); err != nil {
|
||||
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
|
||||
}
|
||||
|
||||
if got, want := resultArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolResult for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := resultArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolResult for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -56,6 +56,8 @@ type AgentCreationOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -76,6 +78,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpExecutor is the subset of *tools.MCPToolManager that the adapter
|
||||
// actually uses. Extracted as an interface so the adapter is unit-testable
|
||||
// without constructing a full manager + connection pool.
|
||||
type mcpExecutor interface {
|
||||
ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error)
|
||||
}
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
exec mcpExecutor
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: t.tool.Name,
|
||||
Description: t.tool.Description,
|
||||
Parameters: t.tool.Parameters,
|
||||
Required: t.tool.Required,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
//
|
||||
// MCP-side failures (JSON-RPC protocol errors, transport failures, schema
|
||||
// validation rejections from the server) are surfaced to the model as soft
|
||||
// tool errors rather than escalated to a critical agent error. This matches
|
||||
// the contract that native Kit tools follow via kit.ErrorResult(...) and
|
||||
// lets the model self-correct (e.g. retry with a fixed argument shape) or
|
||||
// give up gracefully rather than aborting the turn mid-run.
|
||||
//
|
||||
// Context cancellation is the one exception: if the caller cancelled the
|
||||
// context the turn was aborted intentionally, so we propagate the ctx error
|
||||
// to let the agent loop unwind cleanly.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.exec.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return fantasy.ToolResponse{}, ctxErr
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("MCP tool %q failed: %s", t.tool.Name, err.Error()),
|
||||
), nil
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(result.Content), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(result.Content), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
|
||||
// mcpToolsToAgentTools converts a slice of MCPTool to fantasy.AgentTool
|
||||
// implementations that route execution through the MCPToolManager.
|
||||
func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManager) []fantasy.AgentTool {
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
exec: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// stubExecutor lets each test script the (result, err) pair returned by
|
||||
// ExecuteTool. The adapter holds an mcpExecutor interface, so this is the
|
||||
// only seam the tests need.
|
||||
type stubExecutor struct {
|
||||
result *tools.MCPToolResult
|
||||
err error
|
||||
// called records the last invocation for assertion.
|
||||
called bool
|
||||
name string
|
||||
input string
|
||||
}
|
||||
|
||||
func (s *stubExecutor) ExecuteTool(_ context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error) {
|
||||
s.called = true
|
||||
s.name = prefixedName
|
||||
s.input = inputJSON
|
||||
return s.result, s.err
|
||||
}
|
||||
|
||||
func newMCPAgentTool(exec mcpExecutor, name string) *mcpAgentTool {
|
||||
return &mcpAgentTool{
|
||||
tool: tools.MCPTool{Name: name},
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager-side Go errors (JSON-RPC protocol errors, transport failures,
|
||||
// schema validation rejections from the MCP server) must be surfaced to
|
||||
// the model as soft tool errors so the agent loop can keep going. Aborting
|
||||
// the turn would discard all prior tool results — see issue #N.
|
||||
func TestMCPAgentTool_RPCErrorBecomesSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
err: errors.New("MCP error -32602: Invalid params: missing field \"task\""),
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "pubmed__search",
|
||||
Input: `{"query":"foo"}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft), got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "pubmed__search") {
|
||||
t.Errorf("expected tool name in error content, got %q", resp.Content)
|
||||
}
|
||||
if !strings.Contains(resp.Content, "-32602") {
|
||||
t.Errorf("expected underlying error text in content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Context cancellation is the one error that must remain critical: it
|
||||
// means the caller intentionally aborted, and the agent loop needs to
|
||||
// unwind cleanly rather than burning more steps.
|
||||
func TestMCPAgentTool_CtxCancelStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
// Real managers typically return ctx.Err() (or a wrapper) when the
|
||||
// context is cancelled mid-call.
|
||||
err: context.Canceled,
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Deadline-exceeded behaves the same as cancellation: ctx.Err() is
|
||||
// non-nil, so the adapter must propagate the critical error rather than
|
||||
// converting the executor's error into a soft response.
|
||||
func TestMCPAgentTool_CtxDeadlineStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{err: context.DeadlineExceeded}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side soft errors (CallToolResult{ isError: true }) must continue
|
||||
// to flow through as soft errors — this was the existing behavior and
|
||||
// must not regress.
|
||||
func TestMCPAgentTool_ServerIsErrorRemainsSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: true,
|
||||
Content: "search service is rate limited; try again in 30s",
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if resp.Content != "search service is rate limited; try again in 30s" {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Happy path: ordinary successful tool result is passed through unchanged.
|
||||
func TestMCPAgentTool_SuccessIsPassthrough(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: false,
|
||||
Content: `{"hits":3}`,
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected IsError=false")
|
||||
}
|
||||
if resp.Content != `{"hits":3}` {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
+259
-43
@@ -70,14 +70,24 @@ type App struct {
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
// widgetUpdatePending is set to true while a WidgetUpdateEvent burst is
|
||||
// being coalesced. The leading edge fires immediately; subsequent calls
|
||||
// within the debounce window set widgetUpdateTrailing so a final event
|
||||
// is delivered with the latest runner state at the end of the window.
|
||||
// Without the trailing send, a rapid SetWidget→RemoveWidget pair (e.g.
|
||||
// SubagentEnd pushing a final frame then removing the widget) would let
|
||||
// the second call get silently dropped, leaving the TUI's layout stuck
|
||||
// on the pre-removal widget height — visible as empty rows below the
|
||||
// status bar after the widget disappears.
|
||||
widgetUpdatePending atomic.Bool
|
||||
widgetUpdateTrailing atomic.Bool
|
||||
|
||||
// steerDrainFn is the test seam used by releaseBusyAfterCompact to pull
|
||||
// any steer messages that arrived during compaction. In production it is
|
||||
// nil and the helper falls back to a.opts.Kit.DrainSteer(); tests that
|
||||
// need to exercise the steer-drain path without standing up a full
|
||||
// *kit.Kit can set this field directly to inject fake items.
|
||||
steerDrainFn func() []queueItem
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
@@ -356,6 +366,10 @@ func (a *App) AddContextMessage(text string) {
|
||||
// tea.Program. customInstructions is optional text appended to the summary
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
//
|
||||
// Any prompts queued via Run/RunWithFiles or steering messages injected via
|
||||
// Steer/SteerWithFiles while compaction is running are flushed automatically
|
||||
// once compaction completes (see releaseBusyAfterCompact).
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Lock()
|
||||
@@ -377,11 +391,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -420,6 +430,9 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
//
|
||||
// Like CompactConversation, any prompts/steer messages received during
|
||||
// compaction are flushed automatically once compaction finishes.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
@@ -440,11 +453,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -489,6 +498,81 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseBusyAfterCompact is the deferred tail that runs at the end of every
|
||||
// compaction goroutine (success, error, or panic-after-recover paths). It
|
||||
// flips a.busy back to false, but before doing so it checks whether any
|
||||
// prompts piled up while compaction was running:
|
||||
//
|
||||
// - Run/RunWithFiles append to a.queue when a.busy is set.
|
||||
// - Steer/SteerWithFiles deposit messages into the SDK steer channel via
|
||||
// Kit.InjectSteerWithFiles when a.busy is set.
|
||||
//
|
||||
// Without this hand-off the queue would sit idle until the user submits
|
||||
// another prompt — see issue #27. If we find anything pending we keep busy
|
||||
// set, splice the steer messages to the front of the queue, and start a
|
||||
// fresh drainQueue goroutine to deliver them as a single batched turn.
|
||||
func (a *App) releaseBusyAfterCompact() {
|
||||
// Pull steer messages outside the app mutex; DrainSteer takes its own
|
||||
// internal lock and we don't want to nest the two. The test seam
|
||||
// (a.steerDrainFn) takes precedence so unit tests can inject fake
|
||||
// steer items without a real *kit.Kit.
|
||||
var steerItems []queueItem
|
||||
switch {
|
||||
case a.steerDrainFn != nil:
|
||||
steerItems = a.steerDrainFn()
|
||||
case a.opts.Kit != nil:
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
steerItems = make([]queueItem, len(leftover))
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
// If the app was closed while compaction was running, drop everything
|
||||
// and just clear busy. Run/Steer would have rejected new items already
|
||||
// after Close(), but this guards against in-flight items that slipped
|
||||
// in just before closed was set.
|
||||
if a.closed {
|
||||
a.queue = a.queue[:0]
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Combine steer-channel items (front) with the in-memory queue (back).
|
||||
// Steer messages are placed first so they retain their "act now"
|
||||
// semantics relative to ordinary queued prompts that arrived later.
|
||||
pending := append(steerItems, a.queue...)
|
||||
a.queue = a.queue[:0]
|
||||
|
||||
if len(pending) == 0 {
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hand off to drainQueue: it will pick up the first item directly and
|
||||
// scoop the rest from a.queue on its first iteration.
|
||||
first := pending[0]
|
||||
if len(pending) > 1 {
|
||||
a.queue = append(a.queue, pending[1:]...)
|
||||
}
|
||||
// Stay busy across the goroutine swap.
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Notify the UI that steer-channel messages were consumed so the
|
||||
// steering badge can clear; ordinary queued prompts will be reflected
|
||||
// by the QueueUpdatedEvent that drainQueue emits as it picks them up.
|
||||
if len(steerItems) > 0 {
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
|
||||
go a.drainQueue(first)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -497,6 +581,12 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
// response text to stdout. No intermediate events are emitted. Blocks until
|
||||
// the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
return a.RunOnceWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithFiles executes a single agent step synchronously with optional
|
||||
// multimodal file attachments. Prints the response to stdout and returns.
|
||||
func (a *App) RunOnceWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -504,7 +594,7 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -519,6 +609,12 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
// full TurnResult without printing anything. This is used by --json mode to
|
||||
// capture structured output for serialization.
|
||||
func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
return a.RunOnceResultWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceResultWithFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments and returns the full TurnResult.
|
||||
func (a *App) RunOnceResultWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -526,7 +622,7 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
return a.executeStep(stepCtx, prompt, nil, nil)
|
||||
return a.executeStep(stepCtx, prompt, nil, files)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplay executes a single agent step synchronously, sending
|
||||
@@ -540,6 +636,12 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
//
|
||||
// Blocks until the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn func(tea.Msg)) error {
|
||||
return a.RunOnceWithDisplayAndFiles(ctx, prompt, eventFn, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplayAndFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments, sending intermediate display events.
|
||||
func (a *App) RunOnceWithDisplayAndFiles(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -547,7 +649,7 @@ func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn fun
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -870,6 +972,12 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
switch ev := e.(type) {
|
||||
case kit.ToolCallEvent:
|
||||
sendFn(ToolCallStartedEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs})
|
||||
case kit.ToolCallStartEvent:
|
||||
sendFn(ToolCallInputStartEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolKind: ev.ToolKind})
|
||||
case kit.ToolCallDeltaEvent:
|
||||
sendFn(ToolCallInputDeltaEvent{ToolCallID: ev.ToolCallID, Delta: ev.Delta})
|
||||
case kit.ToolCallEndEvent:
|
||||
sendFn(ToolCallInputEndEvent{ToolCallID: ev.ToolCallID})
|
||||
case kit.ToolExecutionStartEvent:
|
||||
sendFn(ToolExecutionEvent{ToolCallID: ev.ToolCallID, ToolName: ev.ToolName, ToolArgs: ev.ToolArgs, IsStarting: true})
|
||||
case kit.ToolExecutionEndEvent:
|
||||
@@ -899,7 +1007,23 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
a.recordStepUsage(ev, stepUsageSeen, sendFn)
|
||||
case kit.PasswordPromptEvent:
|
||||
// Convert SDK PasswordPromptEvent to app PasswordPromptEvent
|
||||
// The TUI will handle this and send the response back
|
||||
responseCh := make(chan PasswordPromptResponse, 1)
|
||||
sendFn(PasswordPromptEvent{
|
||||
Prompt: ev.Prompt,
|
||||
ResponseCh: responseCh,
|
||||
})
|
||||
// Wait for TUI response and forward to SDK
|
||||
resp := <-responseCh
|
||||
ev.ResponseCh <- kit.PasswordPromptResponse{
|
||||
Password: resp.Password,
|
||||
Cancelled: resp.Cancelled,
|
||||
}
|
||||
case kit.TurnEndEvent:
|
||||
a.handleTurnEnd(ev, sendFn)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -910,6 +1034,64 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
}
|
||||
}
|
||||
|
||||
// handleTurnEnd inspects a turn's final StopReason and surfaces actionable
|
||||
// feedback to the user when the turn ended in a state they can act on.
|
||||
//
|
||||
// Today the only surfaced case is FinishReasonLength — the model hit its
|
||||
// configured max_output_tokens budget and the reply was truncated. Without
|
||||
// this banner the TUI used to swallow the truncation silently, leading to
|
||||
// "ghost" cut-offs with no indication of why.
|
||||
//
|
||||
// Separated from subscribeSDKEvents so tests can exercise it directly via a
|
||||
// stubbed sendFn without standing up a full Kit.
|
||||
func (a *App) handleTurnEnd(ev kit.TurnEndEvent, sendFn func(tea.Msg)) {
|
||||
if sendFn == nil {
|
||||
return
|
||||
}
|
||||
if ev.StopReason != kit.FinishReasonLength {
|
||||
return
|
||||
}
|
||||
sendFn(ExtensionPrintEvent{
|
||||
Level: "info",
|
||||
Text: a.formatMaxTokensTruncatedMessage(),
|
||||
})
|
||||
}
|
||||
|
||||
// formatMaxTokensTruncatedMessage builds the user-facing explanation for a
|
||||
// truncated turn. It reports the active max_output_tokens budget and, when
|
||||
// known, the model's catalog output ceiling so the user can judge how much
|
||||
// headroom is available.
|
||||
func (a *App) formatMaxTokensTruncatedMessage() string {
|
||||
k := a.opts.Kit
|
||||
if k == nil {
|
||||
// Extremely early / test-stub case: still emit a useful generic hint.
|
||||
return "⚠ Response truncated: the model hit the configured max_output_tokens limit. " +
|
||||
"Raise it with --max-tokens N, KIT_MAX_TOKENS=N, or per-model " +
|
||||
"modelSettings[provider/model].maxTokens in config."
|
||||
}
|
||||
current := k.MaxTokens()
|
||||
ceiling := k.MaxOutputLimit()
|
||||
model := k.GetModelString()
|
||||
|
||||
msg := "⚠ Response truncated: "
|
||||
if model != "" {
|
||||
msg += fmt.Sprintf("%s hit the configured max_output_tokens limit", model)
|
||||
} else {
|
||||
msg += "the model hit the configured max_output_tokens limit"
|
||||
}
|
||||
if current > 0 {
|
||||
msg += fmt.Sprintf(" (%d)", current)
|
||||
}
|
||||
msg += "."
|
||||
if ceiling > 0 && current > 0 && ceiling > current {
|
||||
msg += fmt.Sprintf(" This model supports up to %d output tokens.", ceiling)
|
||||
}
|
||||
msg += "\n\nRaise it with --max-tokens N, KIT_MAX_TOKENS=N, " +
|
||||
"or per-model modelSettings[provider/model].maxTokens in your config. " +
|
||||
"Re-run the last prompt after raising it to get the full response."
|
||||
return msg
|
||||
}
|
||||
|
||||
// QuitFromExtension triggers a graceful shutdown. In interactive mode it
|
||||
// sends a tea.QuitMsg to the program so the TUI exits cleanly. In
|
||||
// non-interactive mode it cancels the root context, stopping any in-flight
|
||||
@@ -978,32 +1160,47 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
// Coalescing (leading + trailing edge): the first call in an idle period
|
||||
// fires immediately for responsiveness. Subsequent calls within a ~16 ms
|
||||
// debounce window are batched into a single trailing event delivered at
|
||||
// the end of the window. The trailing send is essential for correctness:
|
||||
// extensions routinely make tight SetWidget→RemoveWidget pairs (e.g. on
|
||||
// SubagentEnd) and silently dropping the second call would leave the TUI's
|
||||
// layout stuck on stale widget dimensions until some other event happens
|
||||
// to trigger a re-render.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
// A leading-edge event is already in flight — mark that the runner
|
||||
// state has changed again so the trailing send below picks it up.
|
||||
a.widgetUpdateTrailing.Store(true)
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
if prog == nil {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
return
|
||||
}
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
// If any extra calls came in during the debounce window, deliver
|
||||
// one trailing event so the TUI sees the latest widget state. We
|
||||
// swap-and-test instead of plain-load so concurrent calls after
|
||||
// the trailing send still race correctly with the pending reset.
|
||||
if a.widgetUpdateTrailing.Swap(false) {
|
||||
a.mu.Lock()
|
||||
p := a.program
|
||||
a.mu.Unlock()
|
||||
if p != nil {
|
||||
p.Send(WidgetUpdateEvent{})
|
||||
}
|
||||
}
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
@@ -1143,7 +1340,16 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
//
|
||||
// Both session totals (cost, token counts) and the context window fill level
|
||||
// are updated here so the status bar reflects progress after every LLM call,
|
||||
// not just at the end of the full turn. Context fill monotonically increases
|
||||
// across steps because each step re-sends the entire conversation plus any
|
||||
// new tool results, so the numbers only go up.
|
||||
//
|
||||
// sendFn is called with a UsageUpdatedEvent to trigger a TUI re-render so
|
||||
// the updated values are visible immediately.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool, sendFn func(tea.Msg)) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
@@ -1164,11 +1370,21 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
// Update context window fill from this step's usage. Each step sends
|
||||
// the full conversation to the LLM, so the reported token counts
|
||||
// represent the actual context utilization at that point.
|
||||
contextFill := int(ev.InputTokens) + int(ev.CacheReadTokens) + int(ev.CacheWriteTokens) + int(ev.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheWrite=%d + Output=%d)",
|
||||
contextFill, ev.InputTokens, ev.CacheReadTokens, ev.CacheWriteTokens, ev.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
// Notify the TUI so it re-renders the status bar with updated values.
|
||||
if sendFn != nil {
|
||||
sendFn(UsageUpdatedEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
|
||||
+310
-7
@@ -3,10 +3,12 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -532,9 +534,9 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
// recorded immediately for cost tracking. Context tokens are also updated so
|
||||
// the status bar reflects context fill after every LLM call in a multi-step
|
||||
// turn, not just at the end.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
@@ -545,7 +547,7 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
}, nil, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
@@ -557,9 +559,13 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
// Context tokens should now be updated per-step (Input + CacheRead + CacheWrite + Output).
|
||||
if usage.contextCalls != 1 {
|
||||
t.Fatalf("expected 1 context token update from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
expectedContext := 120 + 45 + 5 + 2
|
||||
if usage.lastContextTokens != expectedContext {
|
||||
t.Fatalf("expected context tokens %d, got %d", expectedContext, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -666,3 +672,300 @@ func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T)
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_LengthEmitsWarning verifies that when the SDK reports a
|
||||
// FinishReasonLength (max_output_tokens hit), the app surfaces a user-visible
|
||||
// ExtensionPrintEvent with Level="info" so the TUI can render a banner
|
||||
// instead of silently showing a truncated reply.
|
||||
func TestHandleTurnEnd_LengthEmitsWarning(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
var mu sync.Mutex
|
||||
var received []tea.Msg
|
||||
sendFn := func(m tea.Msg) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
received = append(received, m)
|
||||
}
|
||||
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, sendFn)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 event on length stop, got %d", len(received))
|
||||
}
|
||||
ev, ok := received[0].(ExtensionPrintEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ExtensionPrintEvent, got %T", received[0])
|
||||
}
|
||||
if ev.Level != "info" {
|
||||
t.Errorf("expected Level=info, got %q", ev.Level)
|
||||
}
|
||||
if ev.Text == "" {
|
||||
t.Error("expected non-empty warning text")
|
||||
}
|
||||
if !strings.Contains(ev.Text, "max_output_tokens") {
|
||||
t.Errorf("warning text should mention max_output_tokens, got: %s", ev.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_NonLengthIgnored verifies that ordinary stop reasons
|
||||
// (stop, tool-calls, error, unknown, "") do not produce a warning banner.
|
||||
func TestHandleTurnEnd_NonLengthIgnored(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
reasons := []string{
|
||||
kit.FinishReasonStop,
|
||||
kit.FinishReasonToolCalls,
|
||||
kit.FinishReasonError,
|
||||
kit.FinishReasonContentFilter,
|
||||
kit.FinishReasonOther,
|
||||
kit.FinishReasonUnknown,
|
||||
"",
|
||||
}
|
||||
for _, r := range reasons {
|
||||
var called bool
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: r}, func(m tea.Msg) {
|
||||
called = true
|
||||
})
|
||||
if called {
|
||||
t.Errorf("stop reason %q unexpectedly emitted a warning", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_NilSendFn guards against panics when no TUI listener is
|
||||
// attached (e.g. early init or headless teardown).
|
||||
func TestHandleTurnEnd_NilSendFn(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Should not panic with a nil sendFn.
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, nil)
|
||||
}
|
||||
|
||||
// TestFormatMaxTokensTruncatedMessage_NoKit verifies the fallback message
|
||||
// when Options.Kit is nil (test/stub path).
|
||||
func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
msg := app.formatMaxTokensTruncatedMessage()
|
||||
if msg == "" {
|
||||
t.Fatal("expected non-empty fallback message")
|
||||
}
|
||||
for _, needle := range []string{"max_output_tokens", "--max-tokens", "KIT_MAX_TOKENS", "modelSettings"} {
|
||||
if !strings.Contains(msg, needle) {
|
||||
t.Errorf("fallback message missing %q:\n%s", needle, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// releaseBusyAfterCompact (issue #27)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestReleaseBusyAfterCompact_flushesQueuedMessages is a regression test for
|
||||
// issue #27: messages queued via Run() while /compact is running used to sit
|
||||
// in a.queue indefinitely until the user typed another prompt. After the fix
|
||||
// the deferred releaseBusyAfterCompact tail picks up any pending items and
|
||||
// dispatches drainQueue automatically.
|
||||
//
|
||||
// We simulate the compaction completion path directly (bypassing the SDK)
|
||||
// by toggling busy=true, populating the queue exactly as Run() would have
|
||||
// during compaction, and then invoking releaseBusyAfterCompact.
|
||||
func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("compacted then drained"), nil
|
||||
},
|
||||
)
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate the state at the start of the compaction tail: busy is set
|
||||
// and a couple of prompts have piled up in the queue while we were
|
||||
// summarising. (Run() would have appended them and returned a queue
|
||||
// length > 0 to the caller.)
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued during compact #1"},
|
||||
queueItem{Prompt: "queued during compact #2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
// Invoke the deferred tail directly. It should kick off drainQueue.
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// drainQueue runs in a goroutine. Wait for the app to come back to idle.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after releaseBusyAfterCompact: queue not drained")
|
||||
}
|
||||
|
||||
// Wait for any in-flight goroutine to finish before reading state.
|
||||
app.wg.Wait()
|
||||
|
||||
if got := app.QueueLength(); got != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d", got)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatalf("expected stub PromptFunc to fire at least once after compact, got %d calls", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_idleWhenQueueEmpty verifies that with no
|
||||
// pending messages the helper just clears busy and does NOT spawn a
|
||||
// drainQueue goroutine (no spurious agent turn).
|
||||
func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false after releaseBusyAfterCompact with empty queue")
|
||||
}
|
||||
|
||||
// Give any rogue goroutine a moment to (incorrectly) call PromptFunc.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls when queue empty, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue exercises the SDK
|
||||
// steer-drain branch of releaseBusyAfterCompact (issue #27 follow-up).
|
||||
//
|
||||
// Production wires a.opts.Kit.DrainSteer() to pull messages that arrived via
|
||||
// Steer/SteerWithFiles during compaction, but Options.Kit is *kit.Kit (a
|
||||
// concrete struct) so unit tests cannot stand up a real instance without a
|
||||
// full LLM backend. The test uses the unexported steerDrainFn seam to inject
|
||||
// fake steer items, then asserts that:
|
||||
//
|
||||
// - Steer items are dispatched ahead of any prompts that piled up in
|
||||
// a.queue (steer retains "act now" priority over ordinary queued
|
||||
// prompts), and
|
||||
// - the helper still hands off to drainQueue so the steer item actually
|
||||
// fires (the previous behaviour left them stranded — see #27).
|
||||
func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
|
||||
var pmu sync.Mutex
|
||||
var firstPrompt string
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("steer dispatched"), nil
|
||||
},
|
||||
)
|
||||
// Wrap PromptFunc so we can capture the prompt text the stub receives
|
||||
// (newStubWithFuncs's fns ignore prompt; we need it to verify ordering).
|
||||
capturingPrompt := func(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
pmu.Lock()
|
||||
if firstPrompt == "" {
|
||||
firstPrompt = prompt
|
||||
}
|
||||
pmu.Unlock()
|
||||
return stub.fn(ctx, prompt)
|
||||
}
|
||||
app := New(Options{PromptFunc: capturingPrompt}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Inject fake steer items via the test seam. In production the same
|
||||
// items would have been delivered through Kit.InjectSteerWithFiles
|
||||
// during /compact and pulled by DrainSteer here.
|
||||
app.steerDrainFn = func() []queueItem {
|
||||
return []queueItem{
|
||||
{Prompt: "steer-1"},
|
||||
{Prompt: "steer-2"},
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate the state at the end of compaction: busy is set and a couple
|
||||
// of regular Run() prompts have piled up after the steer messages.
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued-1"},
|
||||
queueItem{Prompt: "queued-2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// Wait for the dispatched batch to complete.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after steer-spliced releaseBusyAfterCompact")
|
||||
}
|
||||
app.wg.Wait()
|
||||
|
||||
// drainQueue picks up `first` directly and batches the rest. With
|
||||
// PromptFunc set, executeBatch invokes us with items[0] only — that
|
||||
// item must be the first steer message, proving steer items were
|
||||
// spliced ahead of the previously queued prompts.
|
||||
pmu.Lock()
|
||||
got := firstPrompt
|
||||
pmu.Unlock()
|
||||
if got != "steer-1" {
|
||||
t.Fatalf("expected first dispatched prompt to be steer item %q (steer items must come before queued prompts), got %q",
|
||||
"steer-1", got)
|
||||
}
|
||||
|
||||
// Queue should be fully drained and PromptFunc must have actually fired.
|
||||
if n := app.QueueLength(); n != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d entries", n)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatal("expected stub PromptFunc to fire at least once after splice")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_dropsQueueWhenClosed verifies that if the app
|
||||
// was closed during compaction the helper discards any pending items rather
|
||||
// than spawning drainQueue against a torn-down App.
|
||||
func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
|
||||
app.closed = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
qLen := len(app.queue)
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false even when closed")
|
||||
}
|
||||
if qLen != 0 {
|
||||
t.Fatalf("expected queue cleared on closed app, got %d entries", qLen)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,36 @@ type ToolCallStartedEvent struct {
|
||||
ToolArgs string
|
||||
}
|
||||
|
||||
// ToolCallInputStartEvent is sent when the LLM begins generating tool call
|
||||
// arguments. The tool name is known but the full argument JSON is still being
|
||||
// streamed. UIs can use this to show a "running" indicator immediately instead
|
||||
// of waiting for the full argument JSON to finish streaming.
|
||||
type ToolCallInputStartEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
// ToolName is the name of the tool being called.
|
||||
ToolName string
|
||||
// ToolKind classifies the tool: "execute", "edit", "read", "search", "agent".
|
||||
ToolKind string
|
||||
}
|
||||
|
||||
// ToolCallInputDeltaEvent is sent for each streamed fragment of tool call
|
||||
// arguments as they arrive from the LLM. Useful for live-previewing content
|
||||
// or showing a progress indicator with byte count.
|
||||
type ToolCallInputDeltaEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
// Delta is a JSON fragment of tool call arguments.
|
||||
Delta string
|
||||
}
|
||||
|
||||
// ToolCallInputEndEvent is sent when tool argument streaming is complete,
|
||||
// before the tool call is parsed and execution begins.
|
||||
type ToolCallInputEndEvent struct {
|
||||
// ToolCallID is the stable identifier for correlating tool lifecycle events.
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// ToolExecutionEvent is sent when a tool starts or finishes executing.
|
||||
// The IsStarting flag distinguishes between the start and end of execution.
|
||||
type ToolExecutionEvent struct {
|
||||
@@ -79,6 +109,24 @@ type ToolCallContentEvent struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
// PasswordPromptEvent is sent when a sudo command needs a password.
|
||||
// The TUI should display a password prompt overlay and send the result back.
|
||||
type PasswordPromptEvent struct {
|
||||
// Prompt is the message to display to the user.
|
||||
Prompt string
|
||||
// ResponseCh receives the password from the TUI.
|
||||
// The TUI must send exactly one value.
|
||||
ResponseCh chan<- PasswordPromptResponse
|
||||
}
|
||||
|
||||
// PasswordPromptResponse carries the user's password input.
|
||||
type PasswordPromptResponse struct {
|
||||
// Password is the entered password.
|
||||
Password string
|
||||
// Cancelled is true if the user cancelled the prompt.
|
||||
Cancelled bool
|
||||
}
|
||||
|
||||
// ResponseCompleteEvent is sent when the LLM produces a final (non-streaming) response.
|
||||
// In streaming mode, this may be empty if all content was delivered via StreamChunkEvents.
|
||||
type ResponseCompleteEvent struct {
|
||||
@@ -162,6 +210,12 @@ type ModelChangedEvent struct {
|
||||
ModelName string
|
||||
}
|
||||
|
||||
// UsageUpdatedEvent is sent after each completed LLM step to notify the TUI
|
||||
// that token counts and costs have changed. The UsageTracker is updated
|
||||
// in-place before this event is sent; the TUI just needs to re-render to
|
||||
// reflect the new values in the status bar.
|
||||
type UsageUpdatedEvent struct{}
|
||||
|
||||
// WidgetUpdateEvent is sent when an extension adds, updates, or removes a
|
||||
// widget via ctx.SetWidget or ctx.RemoveWidget. The TUI re-reads widget state
|
||||
// from its WidgetProvider on the next render cycle.
|
||||
|
||||
@@ -3,24 +3,21 @@ package app
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// makeTextMsg builds a minimal kit.LLMMessage using fantasy.NewUserMessage
|
||||
// or constructing with the given role.
|
||||
// makeTextMsg builds a minimal kit.LLMMessage with the given role and text.
|
||||
func makeTextMsg(role, text string) kit.LLMMessage {
|
||||
return kit.LLMMessage{
|
||||
Role: kit.LLMMessageRole(role),
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
|
||||
Content: []kit.LLMMessagePart{kit.LLMTextPart{Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
// textOf extracts the plain text from an LLMMessage for assertions.
|
||||
func textOf(msg kit.LLMMessage) string {
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(fantasy.TextPart); ok {
|
||||
if tp, ok := part.(kit.LLMTextPart); ok {
|
||||
return tp.Text
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,29 +255,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
@@ -417,26 +394,6 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
@@ -471,5 +428,13 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
|
||||
return envKey, "ANTHROPIC_API_KEY environment variable", nil
|
||||
}
|
||||
|
||||
// Check if OpenAI credentials exist to provide a helpful suggestion
|
||||
if cm != nil {
|
||||
hasOpenAI, _ := cm.HasOpenAICredentials()
|
||||
if hasOpenAI {
|
||||
return "", "", fmt.Errorf("no Anthropic API key found. Use 'kit auth login anthropic', set ANTHROPIC_API_KEY environment variable, or use --provider-api-key flag\n\nNote: OpenAI credentials were detected. To use OpenAI, run with --model openai/gpt-5.4 or set it as default:\n kit auth login openai --set-default")
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("no Anthropic API key found. Use 'kit auth login anthropic', set ANTHROPIC_API_KEY environment variable, or use --provider-api-key flag")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,37 @@ type MCPServerConfig struct {
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
|
||||
// NoOAuth disables OAuth transport configuration for this server, even
|
||||
// when the connection pool has an auth handler. Use this for public MCP
|
||||
// servers (e.g. PubMed) that don't require authentication. Without this
|
||||
// flag, the pool would attach OAuth transport to every remote server,
|
||||
// causing proactive dynamic-client-registration attempts that fail on
|
||||
// servers that don't support it.
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
|
||||
// TasksMode controls when this server's tools/call requests are augmented
|
||||
// with MCP task metadata (turning a synchronous call into an asynchronous,
|
||||
// pollable job — see https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks).
|
||||
//
|
||||
// Valid values:
|
||||
// - "" or "auto": (default) augment requests with task metadata only
|
||||
// when the server advertises tasks/toolCalls capability during initialize.
|
||||
// - "never": never augment — every tool call is synchronous, regardless
|
||||
// of server capability.
|
||||
// - "always": always augment, even when the server didn't advertise
|
||||
// task support. The server may still respond synchronously; this just
|
||||
// opts in unconditionally on the client side.
|
||||
//
|
||||
// In all modes, when the server returns a CreateTaskResult the client polls
|
||||
// tasks/get / tasks/result until the task reaches a terminal state.
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
|
||||
// InProcessServer holds a live *server.MCPServer for in-process transport.
|
||||
// When set (and Type is "inprocess"), the connection pool creates an
|
||||
// in-process client instead of spawning a subprocess or making HTTP calls.
|
||||
// This field is never serialized — it is only used programmatically via the SDK.
|
||||
InProcessServer any `json:"-" yaml:"-"`
|
||||
|
||||
// Legacy fields for backward compatibility
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
@@ -53,6 +84,8 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -65,6 +98,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Try new format first
|
||||
@@ -80,6 +114,8 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.OAuthClientID = newConfig.OAuthClientID
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
s.NoOAuth = newConfig.NoOAuth
|
||||
s.TasksMode = newConfig.TasksMode
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -100,6 +136,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = legacyConfig.Headers
|
||||
s.AllowedTools = legacyConfig.AllowedTools
|
||||
s.ExcludedTools = legacyConfig.ExcludedTools
|
||||
s.TasksMode = legacyConfig.TasksMode
|
||||
|
||||
// Infer type from legacy format for better compatibility
|
||||
// Only set Type when it doesn't change existing transport behavior
|
||||
@@ -277,11 +314,18 @@ func (s *MCPServerConfig) GetTransportType() string {
|
||||
return "stdio"
|
||||
case "remote":
|
||||
return "streamable"
|
||||
case "inprocess":
|
||||
return "inprocess"
|
||||
default:
|
||||
return s.Type
|
||||
}
|
||||
}
|
||||
|
||||
// Programmatic in-process server detection.
|
||||
if s.InProcessServer != nil {
|
||||
return "inprocess"
|
||||
}
|
||||
|
||||
// Backward compatibility: infer transport type
|
||||
if len(s.Command) > 0 {
|
||||
return "stdio"
|
||||
@@ -301,6 +345,17 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("server %s: allowedTools and excludedTools are mutually exclusive", serverName)
|
||||
}
|
||||
|
||||
// Reject unknown tasksMode values up front so a typo (e.g. "alwasy")
|
||||
// fails loud here instead of being silently downgraded to "auto" by
|
||||
// the runtime parser. Comparison is case-insensitive to match
|
||||
// tools.ParseTaskMode.
|
||||
switch strings.ToLower(strings.TrimSpace(serverConfig.TasksMode)) {
|
||||
case "", "auto", "never", "always":
|
||||
// ok
|
||||
default:
|
||||
return fmt.Errorf("server %s: invalid tasksMode %q (expected one of: auto, never, always)", serverName, serverConfig.TasksMode)
|
||||
}
|
||||
|
||||
transport := serverConfig.GetTransportType()
|
||||
switch transport {
|
||||
case "stdio":
|
||||
@@ -312,8 +367,12 @@ func (c *Config) Validate() error {
|
||||
if serverConfig.URL == "" {
|
||||
return fmt.Errorf("server %s: url is required for %s transport", serverName, transport)
|
||||
}
|
||||
case "inprocess":
|
||||
if serverConfig.InProcessServer == nil {
|
||||
return fmt.Errorf("server %s: InProcessServer is required for inprocess transport", serverName)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable", serverName, transport)
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable, inprocess", serverName, transport)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -627,3 +627,92 @@ func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_NewFormat(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://my-mcp-server.com",
|
||||
"tasksMode": "always"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "always" {
|
||||
t.Errorf("expected TasksMode 'always', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_LegacyFormat(t *testing.T) {
|
||||
// tasksMode also recognised in the legacy unmarshal path so users on
|
||||
// the older command/args shape can opt in without migrating.
|
||||
jsonData := `{
|
||||
"command": "npx",
|
||||
"args": ["@modelcontextprotocol/server-filesystem", "/path"],
|
||||
"tasksMode": "never"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "never" {
|
||||
t.Errorf("expected TasksMode 'never', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_DefaultEmpty(t *testing.T) {
|
||||
// When tasksMode is not set the field stays empty, which downstream
|
||||
// resolves to "auto" via tools.ParseTaskMode.
|
||||
jsonData := `{"type":"remote","url":"https://x.example"}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "" {
|
||||
t.Errorf("expected default TasksMode to be empty, got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_TasksMode(t *testing.T) {
|
||||
t.Run("empty is valid", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example"},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("empty TasksMode should validate, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("known values are valid", func(t *testing.T) {
|
||||
for _, mode := range []string{"auto", "never", "always", "AUTO", " always "} {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example", TasksMode: mode},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("TasksMode=%q should validate, got %v", mode, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("typo is rejected with a clear error", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"buildbot": {Type: "remote", URL: "https://x.example", TasksMode: "alwasy"},
|
||||
},
|
||||
}
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for invalid TasksMode")
|
||||
}
|
||||
// Error must mention the server name AND the bad value so the
|
||||
// user knows where to look.
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "buildbot") || !strings.Contains(msg, `"alwasy"`) {
|
||||
t.Errorf("error %q should mention both server name and bad value", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,32 +7,48 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// LoadAndValidateConfig loads configuration from viper, fixes environment variable
|
||||
// casing issues, and validates the configuration. Returns an error if loading or
|
||||
// validation fails.
|
||||
// LoadAndValidateConfig loads configuration from the process-global viper
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. Returns an error if loading or validation fails.
|
||||
//
|
||||
// This is a convenience wrapper around [LoadAndValidateConfigFrom] using the
|
||||
// shared global store; it is retained for the CLI and other callers that rely
|
||||
// on viper's process-global state.
|
||||
func LoadAndValidateConfig() (*Config, error) {
|
||||
return LoadAndValidateConfigFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// LoadAndValidateConfigFrom loads configuration from the supplied per-instance
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. When v is nil, the process-global store is used. Threading an
|
||||
// explicit store lets each Kit instance own an isolated configuration without
|
||||
// clobbering other instances in the same process.
|
||||
func LoadAndValidateConfigFrom(v *viper.Viper) (*Config, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
config := &Config{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}
|
||||
if err := viper.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
if err := v.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Fix environment variable case sensitivity issue
|
||||
// Viper lowercases all keys, but we need to preserve the original case for environment variables
|
||||
fixEnvironmentCase(config)
|
||||
fixEnvironmentCase(v, config)
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// fixEnvironmentCase fixes the case of environment variable keys that were lowercased by Viper
|
||||
func fixEnvironmentCase(config *Config) {
|
||||
func fixEnvironmentCase(v *viper.Viper, config *Config) {
|
||||
// Get the raw config data from viper
|
||||
rawConfig := viper.AllSettings()
|
||||
rawConfig := v.AllSettings()
|
||||
|
||||
// Check if we have mcpServers in the raw config
|
||||
if mcpServersRaw, ok := rawConfig["mcpservers"]; ok {
|
||||
|
||||
+167
-6
@@ -19,10 +19,18 @@ import (
|
||||
// It receives tool call ID, tool name, output chunk, and whether it's stderr.
|
||||
type ToolOutputCallback func(toolCallID, toolName, chunk string, isStderr bool)
|
||||
|
||||
// PasswordPromptCallback is the signature for password prompts.
|
||||
// It receives a prompt message and returns the password and whether it was cancelled.
|
||||
type PasswordPromptCallback func(prompt string) (password string, cancelled bool)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions.
|
||||
type contextKey string
|
||||
|
||||
const toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
const (
|
||||
toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
sudoPasswordKey contextKey = "sudoPassword"
|
||||
passwordPromptKey contextKey = "passwordPrompt"
|
||||
)
|
||||
|
||||
// ContextWithToolOutputCallback returns a new context with the tool output callback set.
|
||||
func ContextWithToolOutputCallback(ctx context.Context, callback ToolOutputCallback) context.Context {
|
||||
@@ -37,6 +45,34 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithPasswordPrompt returns a new context with the password prompt callback set.
|
||||
// This allows the TUI to show a modal password prompt when sudo needs a password.
|
||||
func ContextWithPasswordPrompt(ctx context.Context, callback PasswordPromptCallback) context.Context {
|
||||
return context.WithValue(ctx, passwordPromptKey, callback)
|
||||
}
|
||||
|
||||
// passwordPromptFromContext retrieves the password prompt callback from context.
|
||||
func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
if cb, ok := ctx.Value(passwordPromptKey).(PasswordPromptCallback); ok {
|
||||
return cb
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
return pw
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
@@ -73,6 +109,57 @@ func NewBashTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// sudoCommandRe matches sudo commands that need to be rewritten for -S mode.
|
||||
// It matches "sudo" as a word boundary, optionally preceded by environment variables.
|
||||
var sudoCommandRe = regexp.MustCompile(`(?i)(^|[&|;|]|\|\||&&)\s*(\w+=\S+\s+)?\bsudo\b`)
|
||||
|
||||
// truncateCommand truncates a long command for display.
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// rewriteSudoForStdin rewrites sudo commands to use -S -p ” for stdin password input.
|
||||
// It transforms: sudo cmd → sudo -S -p ” cmd
|
||||
func rewriteSudoForStdin(command string) string {
|
||||
// Find all matches and their positions
|
||||
matches := sudoCommandRe.FindAllStringIndex(command, -1)
|
||||
if matches == nil {
|
||||
return command
|
||||
}
|
||||
|
||||
// Build result from end to start to preserve indices
|
||||
result := command
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
match := matches[i]
|
||||
start, end := match[0], match[1]
|
||||
matchedText := result[start:end]
|
||||
|
||||
// Extract just the "sudo" part (after any prefix)
|
||||
sudoIdx := strings.Index(strings.ToLower(matchedText), "sudo")
|
||||
if sudoIdx == -1 {
|
||||
continue
|
||||
}
|
||||
prefix := matchedText[:sudoIdx]
|
||||
sudoPart := matchedText[sudoIdx:]
|
||||
|
||||
// Check if the text immediately after "sudo" in the result contains -S
|
||||
afterSudo := result[end:]
|
||||
if strings.HasPrefix(strings.TrimLeft(afterSudo, " \t"), "-S") {
|
||||
// Already has -S flag, skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert -S -p '' after "sudo"
|
||||
newSudo := strings.Replace(sudoPart, "sudo", "sudo -S -p ''", 1)
|
||||
result = result[:start] + prefix + newSudo + result[end:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
@@ -97,7 +184,47 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", args.Command)
|
||||
// Check for sudo password in context or environment
|
||||
sudoPassword := sudoPasswordFromContext(ctx)
|
||||
if sudoPassword == "" {
|
||||
sudoPassword = os.Getenv("SUDO_PASSWORD")
|
||||
}
|
||||
command := args.Command
|
||||
|
||||
// If command contains sudo and we don't have a password, check if sudo needs one
|
||||
if sudoPassword == "" && sudoCommandRe.MatchString(command) {
|
||||
// Check if sudo credentials are cached using sudo -n (non-interactive)
|
||||
testCmd := exec.CommandContext(cmdCtx, "sudo", "-n", "true")
|
||||
testCmd.Dir = workDir
|
||||
if err := testCmd.Run(); err != nil {
|
||||
// Sudo needs a password - try to prompt via callback
|
||||
if promptCallback := passwordPromptFromContext(ctx); promptCallback != nil {
|
||||
pw, cancelled := promptCallback("Sudo password required for: " + truncateCommand(args.Command, 60))
|
||||
if cancelled {
|
||||
return fantasy.NewTextErrorResponse("sudo password prompt cancelled"), nil
|
||||
}
|
||||
if pw == "" {
|
||||
return fantasy.NewTextErrorResponse("no sudo password provided"), nil
|
||||
}
|
||||
sudoPassword = pw
|
||||
command = rewriteSudoForStdin(command)
|
||||
} else {
|
||||
// No callback available - return error with helpful message
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"This command requires sudo access. " +
|
||||
"Please run 'sudo -v' in your terminal first to cache credentials, " +
|
||||
"or set the SUDO_PASSWORD environment variable."), nil
|
||||
}
|
||||
}
|
||||
// Credentials are cached or password was provided, proceed
|
||||
}
|
||||
|
||||
// If we have a sudo password, rewrite the command to use sudo -S
|
||||
if sudoPassword != "" && sudoCommandRe.MatchString(command) {
|
||||
command = rewriteSudoForStdin(command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", command)
|
||||
if workDir != "" {
|
||||
cmd.Dir = workDir
|
||||
}
|
||||
@@ -115,18 +242,18 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
if outputCallback != nil {
|
||||
// Streaming mode: use pipes to capture output as it arrives
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback)
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback, sudoPassword)
|
||||
}
|
||||
|
||||
// Non-streaming mode: collect all output at once (original behavior)
|
||||
return executeBashBuffered(cmdCtx, call, cmd)
|
||||
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd) (fantasy.ToolResponse, error) {
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -136,10 +263,27 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe and write the password
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
@@ -181,7 +325,7 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
}
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback) (fantasy.ToolResponse, error) {
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -191,11 +335,28 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
@@ -127,3 +127,72 @@ func TestBash_EmptyCommand(t *testing.T) {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteSudoForStdin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple sudo",
|
||||
input: "sudo apt update",
|
||||
expected: "sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo with env var",
|
||||
input: "DEBIAN_FRONTEND=noninteractive sudo apt update",
|
||||
expected: "DEBIAN_FRONTEND=noninteractive sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo in pipeline",
|
||||
input: "echo test | sudo tee /etc/test.conf",
|
||||
expected: "echo test | sudo -S -p '' tee /etc/test.conf",
|
||||
},
|
||||
{
|
||||
name: "sudo after &&",
|
||||
input: "apt update && sudo apt upgrade",
|
||||
expected: "apt update && sudo -S -p '' apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "already has -S flag",
|
||||
input: "sudo -S apt update",
|
||||
expected: "sudo -S apt update",
|
||||
},
|
||||
{
|
||||
name: "no sudo",
|
||||
input: "apt update && apt upgrade",
|
||||
expected: "apt update && apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "sudo in string (should not match)",
|
||||
input: "echo 'use sudo carefully'",
|
||||
expected: "echo 'use sudo carefully'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := rewriteSudoForStdin(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("rewriteSudoForStdin(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
}
|
||||
|
||||
// Test without password
|
||||
ctx = context.Background()
|
||||
pw = sudoPasswordFromContext(ctx)
|
||||
if pw != "" {
|
||||
t.Errorf("expected empty password, got %q", pw)
|
||||
}
|
||||
}
|
||||
|
||||
+6
-42
@@ -21,12 +21,9 @@ type Edit struct {
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
Path string `json:"path"`
|
||||
Edits []Edit `json:"edits"`
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
@@ -52,20 +49,12 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Description: "Edit a file by replacing exact text. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit (relative or absolute)",
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
@@ -85,7 +74,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
Required: []string{"path", "edits"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -163,36 +152,11 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
if len(args.Edits) == 0 {
|
||||
return nil, fmt.Errorf("edits array is required and must not be empty")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
|
||||
+62
-44
@@ -389,9 +389,11 @@ func TestExecuteEdit_ExactMatch(t *testing.T) {
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -426,9 +428,11 @@ func TestExecuteEdit_ExactMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
target := lines[49]
|
||||
replacement := "REPLACED_LINE_50"
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -470,9 +474,11 @@ func TestExecuteEdit_FuzzyMatch_TrailingWhitespace(t *testing.T) {
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -519,9 +525,11 @@ func TestExecuteEdit_FuzzyMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
search := strings.Repeat("x", 10) + "\n" + strings.Repeat("x", 10)
|
||||
// But this matches lines 1-2, 2-3, etc. — should fail due to ambiguity.
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -546,9 +554,11 @@ func TestExecuteEdit_MultipleMatches_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -575,9 +585,11 @@ func TestExecuteEdit_NoMatch_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -601,9 +613,11 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
writeFileOrFail(t, path, "line1\r\nline2\r\nline3\r\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -622,8 +636,10 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -636,9 +652,11 @@ func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_NonexistentFile(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -661,9 +679,11 @@ func TestExecuteEdit_DiffContainsHunkHeader(t *testing.T) {
|
||||
writeFileOrFail(t, path, strings.Join(lines, "\n")+"\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -684,9 +704,11 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
writeFileOrFail(t, path, "old content\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -905,18 +927,14 @@ func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
func TestExecuteEdit_EmptyEditsArray_Fails(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -924,10 +942,10 @@ func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
if !strings.Contains(resp.Content, "required") {
|
||||
t.Errorf("expected 'required' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ Example use cases:
|
||||
},
|
||||
"model": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional model override (e.g. 'anthropic/claude-haiku-3-5-20241022' for faster/cheaper tasks)",
|
||||
"description": "Optional model override. Empty string uses the current model.",
|
||||
},
|
||||
"system_prompt": map[string]any{
|
||||
"type": "string",
|
||||
@@ -94,7 +94,7 @@ Example use cases:
|
||||
},
|
||||
"timeout_seconds": map[string]any{
|
||||
"type": "number",
|
||||
"description": "Maximum execution time in seconds (default: 300, max: 1800)",
|
||||
"description": "Maximum execution time in seconds (default: 300, max: 1800, minimum recommended: 240)",
|
||||
},
|
||||
},
|
||||
Required: []string{"task"},
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// Package extbridge wires the public Kit SDK to the internal extensions
|
||||
// package. It exists so that cmd/ and internal/acpserver/ don't both
|
||||
// reimplement the same SDK→extension event/subagent conversions.
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
|
||||
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
|
||||
// (Type=="") for events that don't map to anything useful — callers should
|
||||
// drop those.
|
||||
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
|
||||
// the result/events back into the extension-facing types. The returned
|
||||
// handle is always nil — the SDK path runs synchronously and does not
|
||||
// expose a separate process handle. Callers that need non-blocking
|
||||
// behaviour should run this in their own goroutine.
|
||||
//
|
||||
// This function consolidates the previously-duplicated wiring in
|
||||
// cmd/root.go (interactive + runtime contexts) and
|
||||
// internal/acpserver/session.go.
|
||||
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: cfg.Prompt,
|
||||
Model: cfg.Model,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := SDKEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
cfg.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := k.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
}
|
||||
+202
-1
@@ -918,7 +918,7 @@ type ExtensionEntry struct {
|
||||
type ContextMessage struct {
|
||||
// Index is the position of this message in the original context array
|
||||
// (0-based). When returning messages from a ContextPrepareResult,
|
||||
// messages with Index >= 0 reuse the original fantasy.Message at that
|
||||
// messages with Index >= 0 reuse the original LLM message at that
|
||||
// position (preserving tool calls, reasoning, and other complex parts).
|
||||
// Set Index to -1 for newly injected messages (created from Role + Content).
|
||||
Index int
|
||||
@@ -1063,6 +1063,9 @@ type PrintBlockOpts struct {
|
||||
type API struct {
|
||||
// Event-specific registration functions (wired by the loader).
|
||||
onToolCall func(func(ToolCallEvent, Context) *ToolCallResult)
|
||||
onToolCallInputStart func(func(ToolCallInputStartEvent, Context))
|
||||
onToolCallInputDelta func(func(ToolCallInputDeltaEvent, Context))
|
||||
onToolCallInputEnd func(func(ToolCallInputEndEvent, Context))
|
||||
onToolExecStart func(func(ToolExecutionStartEvent, Context))
|
||||
onToolExecEnd func(func(ToolExecutionEndEvent, Context))
|
||||
onToolOutput func(func(ToolOutputEvent, Context))
|
||||
@@ -1091,6 +1094,14 @@ type API struct {
|
||||
onSubagentStart func(func(SubagentStartEvent, Context))
|
||||
onSubagentChunk func(func(SubagentChunkEvent, Context))
|
||||
onSubagentEnd func(func(SubagentEndEvent, Context))
|
||||
onStepStart func(func(StepStartEvent, Context))
|
||||
onStepFinish func(func(StepFinishEvent, Context))
|
||||
onReasoningStart func(func(ReasoningStartEvent, Context))
|
||||
onWarnings func(func(WarningsEvent, Context))
|
||||
onSource func(func(SourceEvent, Context))
|
||||
onError func(func(ErrorEvent, Context))
|
||||
onRetry func(func(RetryEvent, Context))
|
||||
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -1099,6 +1110,26 @@ func (a *API) OnToolCall(handler func(ToolCallEvent, Context) *ToolCallResult) {
|
||||
a.onToolCall(handler)
|
||||
}
|
||||
|
||||
// OnToolCallInputStart registers a handler that fires when the LLM begins
|
||||
// generating tool call arguments. The tool name is known but the full
|
||||
// argument JSON is still being streamed. Useful for showing a "running"
|
||||
// indicator immediately without waiting for the full arguments.
|
||||
func (a *API) OnToolCallInputStart(handler func(ToolCallInputStartEvent, Context)) {
|
||||
a.onToolCallInputStart(handler)
|
||||
}
|
||||
|
||||
// OnToolCallInputDelta registers a handler that fires for each streamed
|
||||
// fragment of tool call arguments as they arrive from the LLM.
|
||||
func (a *API) OnToolCallInputDelta(handler func(ToolCallInputDeltaEvent, Context)) {
|
||||
a.onToolCallInputDelta(handler)
|
||||
}
|
||||
|
||||
// OnToolCallInputEnd registers a handler that fires when tool argument
|
||||
// streaming is complete, before the tool call is parsed and execution begins.
|
||||
func (a *API) OnToolCallInputEnd(handler func(ToolCallInputEndEvent, Context)) {
|
||||
a.onToolCallInputEnd(handler)
|
||||
}
|
||||
|
||||
// OnToolExecutionStart registers a handler for tool execution start.
|
||||
func (a *API) OnToolExecutionStart(handler func(ToolExecutionStartEvent, Context)) {
|
||||
a.onToolExecStart(handler)
|
||||
@@ -1278,6 +1309,56 @@ func (a *API) OnBeforeCompact(handler func(BeforeCompactEvent, Context) *BeforeC
|
||||
a.onBeforeCompact(handler)
|
||||
}
|
||||
|
||||
// OnStepStart registers a handler that fires when a new LLM call begins
|
||||
// within a multi-step agent turn.
|
||||
func (a *API) OnStepStart(handler func(StepStartEvent, Context)) {
|
||||
a.onStepStart(handler)
|
||||
}
|
||||
|
||||
// OnStepFinish registers a handler that fires when a step completes,
|
||||
// providing step number, finish reason, and decomposed token usage.
|
||||
func (a *API) OnStepFinish(handler func(StepFinishEvent, Context)) {
|
||||
a.onStepFinish(handler)
|
||||
}
|
||||
|
||||
// OnReasoningStart registers a handler that fires when the LLM begins
|
||||
// reasoning/thinking.
|
||||
func (a *API) OnReasoningStart(handler func(ReasoningStartEvent, Context)) {
|
||||
a.onReasoningStart(handler)
|
||||
}
|
||||
|
||||
// OnWarnings registers a handler that fires when the LLM provider returns
|
||||
// warnings about the request.
|
||||
func (a *API) OnWarnings(handler func(WarningsEvent, Context)) {
|
||||
a.onWarnings(handler)
|
||||
}
|
||||
|
||||
// OnSource registers a handler that fires when the LLM references a source
|
||||
// (e.g. from web search tools).
|
||||
func (a *API) OnSource(handler func(SourceEvent, Context)) {
|
||||
a.onSource(handler)
|
||||
}
|
||||
|
||||
// OnError registers a handler that fires when an agent-level error occurs
|
||||
// during streaming.
|
||||
func (a *API) OnError(handler func(ErrorEvent, Context)) {
|
||||
a.onError(handler)
|
||||
}
|
||||
|
||||
// OnRetry registers a handler that fires when the LLM provider request is
|
||||
// retried after a transient error.
|
||||
func (a *API) OnRetry(handler func(RetryEvent, Context)) {
|
||||
a.onRetry(handler)
|
||||
}
|
||||
|
||||
// OnPrepareStep registers a handler that fires between steps within a
|
||||
// multi-step agent turn, after steering messages are injected and before
|
||||
// messages are sent to the LLM. Return a non-nil PrepareStepResult with
|
||||
// Messages to replace the context window for this step.
|
||||
func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStepResult) {
|
||||
a.onPrepareStep(handler)
|
||||
}
|
||||
|
||||
// RegisterToolRenderer registers a custom renderer for a specific tool's
|
||||
// display in the TUI. The renderer controls the header (parameter summary)
|
||||
// and/or body (result display) of the tool's output block. If multiple
|
||||
@@ -1890,6 +1971,34 @@ type ToolCallResult struct {
|
||||
|
||||
func (ToolCallResult) isResult() {}
|
||||
|
||||
// ToolCallInputStartEvent fires when the LLM begins generating tool call
|
||||
// arguments. The tool name is known but the full argument JSON is still
|
||||
// being streamed.
|
||||
type ToolCallInputStartEvent struct {
|
||||
ToolCallID string
|
||||
ToolName string
|
||||
ToolKind string // Tool classification: "execute", "edit", "read", "search", "agent"
|
||||
}
|
||||
|
||||
func (e ToolCallInputStartEvent) Type() EventType { return ToolCallInputStart }
|
||||
|
||||
// ToolCallInputDeltaEvent fires for each streamed fragment of tool call
|
||||
// arguments as they arrive from the LLM.
|
||||
type ToolCallInputDeltaEvent struct {
|
||||
ToolCallID string
|
||||
Delta string // JSON fragment of tool arguments
|
||||
}
|
||||
|
||||
func (e ToolCallInputDeltaEvent) Type() EventType { return ToolCallInputDelta }
|
||||
|
||||
// ToolCallInputEndEvent fires when tool argument streaming is complete,
|
||||
// before the tool call is parsed and execution begins.
|
||||
type ToolCallInputEndEvent struct {
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
func (e ToolCallInputEndEvent) Type() EventType { return ToolCallInputEnd }
|
||||
|
||||
// ToolExecutionStartEvent fires when a tool begins executing.
|
||||
type ToolExecutionStartEvent struct {
|
||||
ToolCallID string
|
||||
@@ -2202,6 +2311,98 @@ type SubagentEndEvent struct {
|
||||
|
||||
func (e SubagentEndEvent) Type() EventType { return SubagentEnd }
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Step lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// StepStartEvent fires when a new LLM call begins within a multi-step agent turn.
|
||||
type StepStartEvent struct {
|
||||
StepNumber int
|
||||
}
|
||||
|
||||
func (e StepStartEvent) Type() EventType { return StepStart }
|
||||
|
||||
// StepFinishEvent fires when a step completes, providing step metadata and
|
||||
// token usage. Usage fields are plain int64 (not LLMUsage) because Yaegi
|
||||
// cannot handle fantasy types across the interpreter boundary.
|
||||
type StepFinishEvent struct {
|
||||
StepNumber int
|
||||
HasToolCalls bool
|
||||
FinishReason string
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CacheReadTokens int64
|
||||
CacheWriteTokens int64
|
||||
}
|
||||
|
||||
func (e StepFinishEvent) Type() EventType { return StepFinish }
|
||||
|
||||
// ReasoningStartEvent fires when the LLM begins reasoning/thinking.
|
||||
type ReasoningStartEvent struct {
|
||||
ID string
|
||||
}
|
||||
|
||||
func (e ReasoningStartEvent) Type() EventType { return ReasoningStart }
|
||||
|
||||
// WarningsEvent fires when the LLM provider returns warnings about the request.
|
||||
type WarningsEvent struct {
|
||||
Warnings []string
|
||||
}
|
||||
|
||||
func (e WarningsEvent) Type() EventType { return Warnings }
|
||||
|
||||
// SourceEvent fires when the LLM references a source (e.g. from web search).
|
||||
type SourceEvent struct {
|
||||
SourceType string
|
||||
ID string
|
||||
URL string
|
||||
Title string
|
||||
}
|
||||
|
||||
func (e SourceEvent) Type() EventType { return Source }
|
||||
|
||||
// ErrorEvent fires when an agent-level error occurs during streaming.
|
||||
// Uses string instead of error because Yaegi cannot handle the error
|
||||
// interface reliably across the interpreter boundary.
|
||||
type ErrorEvent struct {
|
||||
Error string
|
||||
}
|
||||
|
||||
func (e ErrorEvent) Type() EventType { return Error }
|
||||
|
||||
// RetryEvent fires when the LLM provider request is retried after a
|
||||
// transient error.
|
||||
type RetryEvent struct {
|
||||
Attempt int
|
||||
Error string
|
||||
}
|
||||
|
||||
func (e RetryEvent) Type() EventType { return Retry }
|
||||
|
||||
// PrepareStepEvent fires between steps within a multi-step agent turn,
|
||||
// after steering messages are injected and before messages are sent to
|
||||
// the LLM. Handlers can inspect and replace the context window.
|
||||
type PrepareStepEvent struct {
|
||||
// StepNumber is the zero-based step index within the current turn.
|
||||
StepNumber int
|
||||
// Messages is the current context window that will be sent to the LLM.
|
||||
Messages []ContextMessage
|
||||
}
|
||||
|
||||
func (e PrepareStepEvent) Type() EventType { return PrepareStep }
|
||||
|
||||
// PrepareStepResult allows extensions to replace the context window between
|
||||
// steps. Return nil Messages to leave the context unchanged.
|
||||
type PrepareStepResult struct {
|
||||
// Messages replaces the entire context window for this step. If nil,
|
||||
// the original messages are used unchanged. Messages with a non-negative
|
||||
// Index reuse the original message at that position; messages with
|
||||
// Index < 0 are created fresh from Role + Content.
|
||||
Messages []ContextMessage
|
||||
}
|
||||
|
||||
func (PrepareStepResult) isResult() {}
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -13,6 +13,19 @@ const (
|
||||
// ToolCall fires before a tool executes. Handlers can block execution.
|
||||
ToolCall EventType = "tool_call"
|
||||
|
||||
// ToolCallInputStart fires when the LLM begins generating tool call
|
||||
// arguments. The tool name is known but the full argument JSON is still
|
||||
// being streamed.
|
||||
ToolCallInputStart EventType = "tool_call_input_start"
|
||||
|
||||
// ToolCallInputDelta fires for each streamed fragment of tool call
|
||||
// arguments as they arrive from the LLM.
|
||||
ToolCallInputDelta EventType = "tool_call_input_delta"
|
||||
|
||||
// ToolCallInputEnd fires when tool argument streaming is complete,
|
||||
// before the tool call is parsed and execution begins.
|
||||
ToolCallInputEnd EventType = "tool_call_input_end"
|
||||
|
||||
// ToolExecutionStart fires when a tool begins executing.
|
||||
ToolExecutionStart EventType = "tool_execution_start"
|
||||
|
||||
@@ -83,18 +96,50 @@ const (
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
|
||||
// StepStart fires when a new LLM call begins within a multi-step
|
||||
// agent turn.
|
||||
StepStart EventType = "step_start"
|
||||
|
||||
// StepFinish fires when a step completes, providing step number,
|
||||
// finish reason, and token usage.
|
||||
StepFinish EventType = "step_finish"
|
||||
|
||||
// ReasoningStart fires when the LLM begins reasoning/thinking.
|
||||
ReasoningStart EventType = "reasoning_start"
|
||||
|
||||
// Warnings fires when the LLM provider returns warnings.
|
||||
Warnings EventType = "warnings"
|
||||
|
||||
// Source fires when the LLM references a source (e.g. web search).
|
||||
Source EventType = "source"
|
||||
|
||||
// Error fires when an agent-level error occurs during streaming.
|
||||
Error EventType = "error"
|
||||
|
||||
// Retry fires when the LLM provider request is retried after a
|
||||
// transient error.
|
||||
Retry EventType = "retry"
|
||||
|
||||
// PrepareStep fires between steps within a multi-step agent turn,
|
||||
// after steering messages are injected and before messages are sent
|
||||
// to the LLM. Handlers can replace the context window for this step.
|
||||
PrepareStep EventType = "prepare_step"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
func AllEventTypes() []EventType {
|
||||
return []EventType{
|
||||
ToolCall, ToolExecutionStart, ToolExecutionEnd, ToolResult,
|
||||
ToolCall, ToolCallInputStart, ToolCallInputDelta, ToolCallInputEnd,
|
||||
ToolExecutionStart, ToolExecutionEnd, ToolResult,
|
||||
Input, BeforeAgentStart, AgentStart, AgentEnd,
|
||||
MessageStart, MessageUpdate, MessageEnd,
|
||||
SessionStart, SessionShutdown,
|
||||
ModelChange, ContextPrepare,
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
|
||||
PrepareStep,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 21 {
|
||||
t.Fatalf("expected 21 event types, got %d", len(all))
|
||||
if len(all) != 32 {
|
||||
t.Fatalf("expected 32 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,9 @@ func TestEventType_TypeMethod(t *testing.T) {
|
||||
want EventType
|
||||
}{
|
||||
{ToolCallEvent{ToolName: "test"}, ToolCall},
|
||||
{ToolCallInputStartEvent{ToolCallID: "x", ToolName: "test"}, ToolCallInputStart},
|
||||
{ToolCallInputDeltaEvent{ToolCallID: "x", Delta: "{"}, ToolCallInputDelta},
|
||||
{ToolCallInputEndEvent{ToolCallID: "x"}, ToolCallInputEnd},
|
||||
{ToolExecutionStartEvent{ToolName: "test"}, ToolExecutionStart},
|
||||
{ToolExecutionEndEvent{ToolName: "test"}, ToolExecutionEnd},
|
||||
{ToolResultEvent{ToolName: "test"}, ToolResult},
|
||||
|
||||
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
|
||||
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
|
||||
// saveManifest round-trip against a temp directory, ensuring an absent
|
||||
// manifest loads as empty and a saved manifest reads back identically.
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
err = installer.saveManifest(manifest, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
t.Fatalf("saveManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was written to expected path
|
||||
if _, err := os.Stat(manifestPath); err != nil {
|
||||
t.Fatalf("manifest file not created: %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
loaded, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
|
||||
// followed by removeFromManifest leaves the manifest in its original
|
||||
// (empty) state, using a temp-directory installer scope.
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
|
||||
t.Fatalf("addToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
|
||||
t.Fatalf("removeFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
manifest, err = installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindInManifest writes a manifest file directly to the path
|
||||
// resolved by the package-level manifestPathForScope helper and then
|
||||
// confirms FindInManifest locates the entry by identity (and returns
|
||||
// nil for a non-existent identity).
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
t.Setenv("XDG_DATA_HOME", tempDir)
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
// Write a manifest entry directly via the package-level path resolver
|
||||
// so FindInManifest (which uses manifestPathForScope) can read it back.
|
||||
manifestPath := manifestPathForScope(ScopeGlobal)
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
|
||||
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
|
||||
@@ -429,6 +429,24 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onToolCallInputStart: func(h func(ToolCallInputStartEvent, Context)) {
|
||||
reg(ToolCallInputStart, func(e Event, c Context) Result {
|
||||
h(e.(ToolCallInputStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolCallInputDelta: func(h func(ToolCallInputDeltaEvent, Context)) {
|
||||
reg(ToolCallInputDelta, func(e Event, c Context) Result {
|
||||
h(e.(ToolCallInputDeltaEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolCallInputEnd: func(h func(ToolCallInputEndEvent, Context)) {
|
||||
reg(ToolCallInputEnd, func(e Event, c Context) Result {
|
||||
h(e.(ToolCallInputEndEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onToolExecStart: func(h func(ToolExecutionStartEvent, Context)) {
|
||||
reg(ToolExecutionStart, func(e Event, c Context) Result {
|
||||
h(e.(ToolExecutionStartEvent), c)
|
||||
@@ -600,6 +618,57 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onStepStart: func(h func(StepStartEvent, Context)) {
|
||||
reg(StepStart, func(e Event, c Context) Result {
|
||||
h(e.(StepStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onStepFinish: func(h func(StepFinishEvent, Context)) {
|
||||
reg(StepFinish, func(e Event, c Context) Result {
|
||||
h(e.(StepFinishEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onReasoningStart: func(h func(ReasoningStartEvent, Context)) {
|
||||
reg(ReasoningStart, func(e Event, c Context) Result {
|
||||
h(e.(ReasoningStartEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onWarnings: func(h func(WarningsEvent, Context)) {
|
||||
reg(Warnings, func(e Event, c Context) Result {
|
||||
h(e.(WarningsEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onSource: func(h func(SourceEvent, Context)) {
|
||||
reg(Source, func(e Event, c Context) Result {
|
||||
h(e.(SourceEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onError: func(h func(ErrorEvent, Context)) {
|
||||
reg(Error, func(e Event, c Context) Result {
|
||||
h(e.(ErrorEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onRetry: func(h func(RetryEvent, Context)) {
|
||||
reg(Retry, func(e Event, c Context) Result {
|
||||
h(e.(RetryEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onPrepareStep: func(h func(PrepareStepEvent, Context) *PrepareStepResult) {
|
||||
reg(PrepareStep, func(e Event, c Context) Result {
|
||||
r := h(e.(PrepareStepEvent), c)
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return *r
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
|
||||
@@ -1,21 +1,93 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// reentrantMu — a per-extension mutex that allows the same goroutine to
|
||||
// re-enter (e.g. handler → ctx.EmitCustomEvent → handler in same extension).
|
||||
// Different goroutines are serialized, preventing concurrent state mutation.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type reentrantMu struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
owner int64 // goroutine ID that holds the lock, or 0
|
||||
depth int // re-entrancy depth
|
||||
}
|
||||
|
||||
// initReentrantMu initializes the reentrant mutex in-place. Must be called
|
||||
// after the struct is at its final memory location (not before copying).
|
||||
func (r *reentrantMu) init() {
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
}
|
||||
|
||||
// lock acquires the mutex. If the calling goroutine already holds it, the
|
||||
// call succeeds immediately (re-entrant). Every call to lock must be paired
|
||||
// with a call to unlock.
|
||||
func (r *reentrantMu) lock() {
|
||||
gid := goroutineID()
|
||||
r.mu.Lock()
|
||||
if r.owner == gid {
|
||||
// Re-entrant: same goroutine already holds the lock.
|
||||
r.depth++
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Wait for the current owner to release.
|
||||
for r.owner != 0 {
|
||||
r.cond.Wait() // releases mu, blocks, re-acquires mu on wake
|
||||
}
|
||||
r.owner = gid
|
||||
r.depth = 1
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// unlock releases the mutex (or decrements re-entrancy depth).
|
||||
func (r *reentrantMu) unlock() {
|
||||
r.mu.Lock()
|
||||
r.depth--
|
||||
if r.depth == 0 {
|
||||
r.owner = 0
|
||||
r.cond.Signal()
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// goroutineID extracts the current goroutine's ID from runtime.Stack output.
|
||||
// This is a well-known technique used by Go testing infrastructure.
|
||||
func goroutineID() int64 {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
// Stack output starts with "goroutine NNN ["
|
||||
s := buf[:n]
|
||||
s = s[len("goroutine "):]
|
||||
s = s[:bytes.IndexByte(s, ' ')]
|
||||
id, _ := strconv.ParseInt(string(s), 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
// Runner manages loaded extensions and dispatches events to their handlers
|
||||
// sequentially. Handlers execute in extension
|
||||
// load order; for cancellable events the first blocking result wins.
|
||||
//
|
||||
// Each extension has a dedicated reentrant mutex so that handlers for the
|
||||
// same extension are serialized (preventing data races on shared package-level
|
||||
// state), while handlers for different extensions may execute concurrently.
|
||||
type Runner struct {
|
||||
extensions []LoadedExtension
|
||||
extMu []reentrantMu // per-extension reentrant mutex, indexed by extension position
|
||||
ctx Context
|
||||
widgets map[string]WidgetConfig // keyed by widget ID
|
||||
statusEntries map[string]StatusBarEntry // keyed by status key
|
||||
@@ -26,9 +98,20 @@ type Runner struct {
|
||||
disabledTools map[string]bool // nil = all tools enabled
|
||||
customEventSubs map[string][]func(string) // inter-extension event bus
|
||||
optionOverrides map[string]string // runtime option overrides
|
||||
configStore *viper.Viper // per-instance config store (nil = global)
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetConfigStore sets the per-instance configuration store used by GetOption
|
||||
// to resolve "options.<name>" config values. When unset (nil), GetOption falls
|
||||
// back to the process-global viper store. Threading a per-Kit store keeps
|
||||
// extension option resolution isolated between Kit instances.
|
||||
func (r *Runner) SetConfigStore(v *viper.Viper) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.configStore = v
|
||||
}
|
||||
|
||||
// ShortcutEntry pairs a shortcut definition with its handler.
|
||||
type ShortcutEntry struct {
|
||||
Def ShortcutDef
|
||||
@@ -52,7 +135,11 @@ type LoadedExtension struct {
|
||||
|
||||
// NewRunner creates a Runner from a set of loaded extensions.
|
||||
func NewRunner(exts []LoadedExtension) *Runner {
|
||||
return &Runner{extensions: exts}
|
||||
mus := make([]reentrantMu, len(exts))
|
||||
for i := range mus {
|
||||
mus[i].init()
|
||||
}
|
||||
return &Runner{extensions: exts, extMu: mus}
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
@@ -367,6 +454,11 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
for i := range r.extensions {
|
||||
ext := &r.extensions[i]
|
||||
handlers := ext.Handlers[event.Type()]
|
||||
if len(handlers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r.extMu[i].lock()
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
@@ -379,6 +471,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
|
||||
// Check for blocking/short-circuit results.
|
||||
if isBlocking(result) {
|
||||
r.extMu[i].unlock()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -386,6 +479,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
// the caller is responsible for applying the modifications.
|
||||
accumulated = result
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
return accumulated, nil
|
||||
}
|
||||
@@ -712,11 +806,17 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
|
||||
// Extension-registered handlers first (in load order).
|
||||
for i := range r.extensions {
|
||||
for _, h := range r.extensions[i].CustomEventHandlers[name] {
|
||||
extHandlers := r.extensions[i].CustomEventHandlers[name]
|
||||
if len(extHandlers) == 0 {
|
||||
continue
|
||||
}
|
||||
r.extMu[i].lock()
|
||||
for _, h := range extHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
// Then dynamic subscriptions.
|
||||
// Then dynamic subscriptions (not extension-scoped, no per-ext lock).
|
||||
for _, h := range dynamicHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
@@ -783,7 +883,13 @@ func (r *Runner) GetOption(name string) string {
|
||||
|
||||
// 3. Viper config: options.<name>
|
||||
configKey := "options." + name
|
||||
if v := viper.GetString(configKey); v != "" {
|
||||
r.mu.RLock()
|
||||
store := r.configStore
|
||||
r.mu.RUnlock()
|
||||
if store == nil {
|
||||
store = viper.GetViper()
|
||||
}
|
||||
if v := store.GetString(configKey); v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -571,3 +572,142 @@ func TestRunner_ContextPrintNilSafe(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitSameExtension(t *testing.T) {
|
||||
// Verify that concurrent Emit calls for the same extension are serialized
|
||||
// and don't cause data races on shared handler state.
|
||||
var counter int
|
||||
ext := makeHandlerExt("shared-state.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// Read-modify-write: racy without serialization.
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
SubagentChunk: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
_, _ = r.Emit(SubagentChunkEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if counter != goroutines*iterations*2 {
|
||||
t.Errorf("expected counter=%d, got %d (race detected)", goroutines*iterations*2, counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitDifferentExtensions(t *testing.T) {
|
||||
// Two extensions with independent state should not block each other
|
||||
// and should both run correctly under concurrent Emit calls.
|
||||
var counter1, counter2 int
|
||||
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter1
|
||||
counter1 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter2
|
||||
counter2 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext1, ext2)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
expected := goroutines * iterations
|
||||
if counter1 != expected {
|
||||
t.Errorf("ext1 counter: expected %d, got %d", expected, counter1)
|
||||
}
|
||||
if counter2 != expected {
|
||||
t.Errorf("ext2 counter: expected %d, got %d", expected, counter2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ReentrantEmitCustomEvent(t *testing.T) {
|
||||
// Verify that a handler can call EmitCustomEvent (which dispatches to
|
||||
// the same extension's custom event handlers) without deadlocking.
|
||||
var order []string
|
||||
ext := LoadedExtension{
|
||||
Path: "reentrant.go",
|
||||
Handlers: map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
// This triggers EmitCustomEvent for the same extension
|
||||
// via a direct runner call (simulating ctx.EmitCustomEvent).
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
CustomEventHandlers: map[string][]func(string){
|
||||
"test-event": {
|
||||
func(data string) {
|
||||
order = append(order, "custom:"+data)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := makeRunner(ext)
|
||||
|
||||
// Wire up the handler to call EmitCustomEvent re-entrantly.
|
||||
ext.Handlers[SessionStart] = []HandlerFunc{
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
r.EmitCustomEvent("test-event", "hello")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
r.extensions[0] = ext
|
||||
// Rebuild mutexes after modifying extensions slice.
|
||||
r.extMu = make([]reentrantMu, len(r.extensions))
|
||||
for i := range r.extMu {
|
||||
r.extMu[i].init()
|
||||
}
|
||||
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 2 || order[0] != "session_start" || order[1] != "custom:hello" {
|
||||
t.Errorf("expected [session_start, custom:hello], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,22 +2,15 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
|
||||
@@ -152,6 +152,9 @@ func Symbols() interp.Exports {
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
"ToolCallInputStartEvent": reflect.ValueOf((*ToolCallInputStartEvent)(nil)),
|
||||
"ToolCallInputDeltaEvent": reflect.ValueOf((*ToolCallInputDeltaEvent)(nil)),
|
||||
"ToolCallInputEndEvent": reflect.ValueOf((*ToolCallInputEndEvent)(nil)),
|
||||
"ToolExecutionStartEvent": reflect.ValueOf((*ToolExecutionStartEvent)(nil)),
|
||||
"ToolExecutionEndEvent": reflect.ValueOf((*ToolExecutionEndEvent)(nil)),
|
||||
"ToolOutputEvent": reflect.ValueOf((*ToolOutputEvent)(nil)),
|
||||
@@ -169,6 +172,17 @@ func Symbols() interp.Exports {
|
||||
"SessionStartEvent": reflect.ValueOf((*SessionStartEvent)(nil)),
|
||||
"SessionShutdownEvent": reflect.ValueOf((*SessionShutdownEvent)(nil)),
|
||||
"ModelChangeEvent": reflect.ValueOf((*ModelChangeEvent)(nil)),
|
||||
|
||||
// Step lifecycle events
|
||||
"StepStartEvent": reflect.ValueOf((*StepStartEvent)(nil)),
|
||||
"StepFinishEvent": reflect.ValueOf((*StepFinishEvent)(nil)),
|
||||
"ReasoningStartEvent": reflect.ValueOf((*ReasoningStartEvent)(nil)),
|
||||
"WarningsEvent": reflect.ValueOf((*WarningsEvent)(nil)),
|
||||
"SourceEvent": reflect.ValueOf((*SourceEvent)(nil)),
|
||||
"ErrorEvent": reflect.ValueOf((*ErrorEvent)(nil)),
|
||||
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
|
||||
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
|
||||
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,11 +28,11 @@ func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantas
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
|
||||
// into fantasy.AgentTool implementations so the LLM can invoke them.
|
||||
// ExtensionToolsAsLLMTools converts ToolDef values registered by extensions
|
||||
// into LLM agent tool implementations so the LLM can invoke them.
|
||||
// The runner is optional; if provided, ToolContext.OnProgress routes
|
||||
// progress messages through the runner's Print function.
|
||||
func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
tools := make([]fantasy.AgentTool, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
tools = append(tools, &extensionTool{def: def, runner: runner})
|
||||
@@ -90,8 +90,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
// 0. Check if tool is disabled via SetActiveTools.
|
||||
if w.runner.IsToolDisabled(toolName) {
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)),
|
||||
fmt.Errorf("tool %q disabled by extension", toolName)
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
@@ -111,8 +110,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
if reason == "" {
|
||||
reason = "blocked by extension"
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)),
|
||||
fmt.Errorf("tool blocked by extension: %s", reason)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("Error: %s", reason)), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,7 +152,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
|
||||
// extensionTool — wraps a ToolDef into an LLM agent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type extensionTool struct {
|
||||
@@ -182,7 +180,7 @@ func (t *extensionTool) Info() fantasy.ToolInfo {
|
||||
info.Parameters = props
|
||||
} else {
|
||||
// Schema doesn't have "properties" — use as-is (may be
|
||||
// a flat property map already matching fantasy's format).
|
||||
// a flat property map already matching the expected format).
|
||||
info.Parameters = schema
|
||||
}
|
||||
// Extract required fields if present.
|
||||
@@ -238,7 +236,7 @@ func (t *extensionTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), err
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(result), nil
|
||||
}
|
||||
|
||||
@@ -142,8 +142,8 @@ func TestWrappedTool_BlockExecution(t *testing.T) {
|
||||
if toolRan {
|
||||
t.Error("tool should not have run after block")
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("expected error from blocked tool")
|
||||
if err != nil {
|
||||
t.Error("expected nil error for blocked tool (error is conveyed via IsError response)")
|
||||
}
|
||||
if resp.IsError != true {
|
||||
t.Error("expected IsError=true from blocked response")
|
||||
@@ -192,7 +192,7 @@ func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
func TestExtensionToolsAsLLMTools(t *testing.T) {
|
||||
defs := []ToolDef{
|
||||
{
|
||||
Name: "greet",
|
||||
@@ -202,7 +202,7 @@ func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
@@ -232,10 +232,10 @@ func TestExtensionTool_Error(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
if err != nil {
|
||||
t.Error("expected nil error (error is conveyed via IsError response)")
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected IsError=true")
|
||||
@@ -259,7 +259,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
}
|
||||
|
||||
// Without runner, OnProgress is a no-op.
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -285,7 +285,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools2 := ExtensionToolsAsFantasy(defs2, runner)
|
||||
tools2 := ExtensionToolsAsLLMTools(defs2, runner)
|
||||
_, err = tools2[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -306,7 +306,7 @@ func TestExtensionTool_ExecuteWithContextPriority(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -330,7 +330,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
_, _ = tools[0].Run(ctx, fantasy.ToolCall{Input: ""})
|
||||
if !sawCancelled {
|
||||
t.Error("expected IsCancelled=true for cancelled context")
|
||||
@@ -339,7 +339,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
|
||||
func TestExtensionTool_ProviderOptions(t *testing.T) {
|
||||
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
|
||||
// Initially nil.
|
||||
opts := tools[0].ProviderOptions()
|
||||
|
||||
+70
-43
@@ -46,9 +46,9 @@ type AgentSetupOptions struct {
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
// BuildProviderConfig(). Callers (e.g. Kit.New) pre-build this from their
|
||||
// per-instance config store and pass it here, so the slow agent/MCP
|
||||
// initialisation can run without further config reads.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
@@ -72,6 +72,14 @@ type AgentSetupOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
// Viper is the per-instance configuration store. When set, it is used for
|
||||
// any fallback config reads (debug, no-extensions, max-steps, stream,
|
||||
// extension paths) and is attached to the extension runner. When nil, the
|
||||
// process-global viper store is used.
|
||||
Viper *viper.Viper
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -84,57 +92,62 @@ type AgentSetupResult struct {
|
||||
ExtRunner *extensions.Runner
|
||||
}
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the supplied viper
|
||||
// store (or the process-global store when v is nil). All entry points (root,
|
||||
// script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
func BuildProviderConfig(v *viper.Viper) (*models.ProviderConfig, string, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
systemPrompt, err := config.LoadSystemPrompt(v.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
numGPU := int32(v.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(v.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
ModelString: v.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
ProviderAPIKey: v.GetString("provider-api-key"),
|
||||
ProviderURL: v.GetString("provider-url"),
|
||||
MaxTokens: v.GetInt("max-tokens"),
|
||||
StopSequences: v.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
TLSSkipVerify: v.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(v.GetString("thinking-level")),
|
||||
ConfigStore: v,
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
if v.IsSet("temperature") {
|
||||
val := float32(v.GetFloat64("temperature"))
|
||||
cfg.Temperature = &val
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
if v.IsSet("top-p") {
|
||||
val := float32(v.GetFloat64("top-p"))
|
||||
cfg.TopP = &val
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
if v.IsSet("top-k") {
|
||||
val := int32(v.GetInt("top-k"))
|
||||
cfg.TopK = &val
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
if v.IsSet("frequency-penalty") {
|
||||
val := float32(v.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &val
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
if v.IsSet("presence-penalty") {
|
||||
val := float32(v.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &val
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
@@ -146,14 +159,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
// Resolve the config store: prefer the per-instance store, falling back to
|
||||
// the process-global store.
|
||||
v := opts.Viper
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after building the
|
||||
// per-instance store). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -161,13 +181,13 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
debugEnabled := opts.Debug || v.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || v.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
maxSteps = v.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
streamingEnabled := opts.StreamingEnabled || v.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
@@ -186,7 +206,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
extRunner, extCreationOpts, extErr = loadExtensions(v)
|
||||
if extErr != nil {
|
||||
fmt.Printf("Warning: Failed to load extensions: %v\n", extErr)
|
||||
}
|
||||
@@ -229,6 +249,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
@@ -249,9 +270,14 @@ type extensionCreationOpts struct {
|
||||
}
|
||||
|
||||
// loadExtensions discovers and loads Yaegi extensions, builds the runner,
|
||||
// and returns the tool wrapper/extra tools.
|
||||
func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
// and returns the tool wrapper/extra tools. The supplied store is used to
|
||||
// resolve the "extension" config key and is attached to the runner so
|
||||
// extension option lookups stay isolated to this Kit instance.
|
||||
func loadExtensions(v *viper.Viper) (*extensions.Runner, extensionCreationOpts, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
extraPaths := v.GetStringSlice("extension")
|
||||
loaded, err := extensions.LoadExtensions(extraPaths)
|
||||
if err != nil {
|
||||
return nil, extensionCreationOpts{}, err
|
||||
@@ -262,12 +288,13 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
}
|
||||
|
||||
runner := extensions.NewRunner(loaded)
|
||||
runner.SetConfigStore(v)
|
||||
|
||||
wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
}
|
||||
|
||||
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools(), runner)
|
||||
extTools := extensions.ExtensionToolsAsLLMTools(runner.RegisteredTools(), runner)
|
||||
|
||||
return runner, extensionCreationOpts{
|
||||
toolWrapper: wrapper,
|
||||
|
||||
@@ -325,12 +325,6 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
|
||||
// mixed TextPart and ToolCallPart content. Tool-role messages produce
|
||||
// ToolResultPart entries.
|
||||
func (m *Message) ToLLMMessages() []fantasy.Message {
|
||||
return m.ToFantasyMessages()
|
||||
}
|
||||
|
||||
// Deprecated: Use ToLLMMessages instead.
|
||||
// ToFantasyMessages converts a Message to one or more LLM message values.
|
||||
func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
switch m.Role {
|
||||
case RoleAssistant:
|
||||
var parts []fantasy.MessagePart
|
||||
@@ -431,13 +425,6 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
// FromLLMMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromLLMMessage(msg fantasy.Message) Message {
|
||||
return FromFantasyMessage(msg)
|
||||
}
|
||||
|
||||
// Deprecated: Use FromLLMMessage instead.
|
||||
// FromFantasyMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
m := Message{
|
||||
Role: MessageRole(msg.Role),
|
||||
Parts: make([]ContentPart, 0),
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNpmToWireProtocol documents the wire protocols that the auto-router
|
||||
// understands. Provider-specific bundles (azure, bedrock, vercel, openrouter,
|
||||
// google-vertex*) are intentionally absent — they have native top-level cases
|
||||
// in CreateProvider and never reach the auto-router.
|
||||
func TestNpmToWireProtocol(t *testing.T) {
|
||||
want := map[string]wireProtocol{
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
}
|
||||
for npm, wire := range want {
|
||||
if got := npmToWireProtocol[npm]; got != wire {
|
||||
t.Errorf("npmToWireProtocol[%q] = %d, want %d", npm, got, wire)
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle packages must NOT be in the table (regression guard against the
|
||||
// old npmToLLMProvider map that listed 10 entries but only handled 3).
|
||||
for _, npm := range []string{
|
||||
"@ai-sdk/google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic",
|
||||
"@ai-sdk/amazon-bedrock",
|
||||
"@ai-sdk/azure",
|
||||
"@openrouter/ai-sdk-provider",
|
||||
"@ai-sdk/vercel",
|
||||
} {
|
||||
if _, ok := npmToWireProtocol[npm]; ok {
|
||||
t.Errorf("npmToWireProtocol unexpectedly contains bundle package %q", npm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestRegistry builds a registry containing a single proxy-style provider
|
||||
// ("testproxy") with the given default npm, plus one model that carries the
|
||||
// given per-model npm override.
|
||||
func newTestRegistry(api, defaultNPM, modelID, modelNPMOverride string) *ModelsRegistry {
|
||||
return &ModelsRegistry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"testproxy": {
|
||||
ID: "testproxy",
|
||||
Name: "Test Proxy",
|
||||
Env: []string{"TESTPROXY_API_KEY"},
|
||||
NPM: defaultNPM,
|
||||
API: api,
|
||||
Models: map[string]ModelInfo{
|
||||
modelID: {
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
ProviderNPM: modelNPMOverride,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_WireRouting verifies that autoRouteProvider routes each
|
||||
// npm package to the correct fantasy provider implementation. This is the core
|
||||
// regression test for issue #41: previously any npm that resolved to a
|
||||
// non-openai/anthropic/openaicompat LLM provider (notably @ai-sdk/google) hit a
|
||||
// dead `default` branch and failed with "has no LLM provider mapping".
|
||||
func TestAutoRouteProvider_WireRouting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
defaultNPM string
|
||||
overrideNPM string
|
||||
// wantType is the concrete fantasy LanguageModel type the model should
|
||||
// be routed to, identified by reflect type string.
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "openai-compatible default",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
{
|
||||
name: "anthropic override",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/anthropic",
|
||||
wantType: "anthropic.languageModel",
|
||||
},
|
||||
{
|
||||
name: "openai (responses) override",
|
||||
modelID: "gpt-4o",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/openai",
|
||||
wantType: "openai.responsesLanguageModel",
|
||||
},
|
||||
{
|
||||
// The bug: opencode's gemini-* models override the default
|
||||
// openai-compatible npm with @ai-sdk/google.
|
||||
name: "google override (issue #41)",
|
||||
modelID: "gemini-3.5-flash",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/google",
|
||||
wantType: "*google.languageModel",
|
||||
},
|
||||
{
|
||||
// Unknown npm but provider has an API URL → openai-compatible fallback.
|
||||
name: "unknown npm with API URL falls back to openai-compat",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/some-future-thing",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", tt.defaultNPM, tt.modelID, tt.overrideNPM)
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
result, err := autoRouteProvider(context.Background(), config, "testproxy", tt.modelID, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("autoRouteProvider returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.Model == nil {
|
||||
t.Fatalf("autoRouteProvider returned nil model")
|
||||
}
|
||||
|
||||
gotType := reflect.TypeOf(result.Model).String()
|
||||
if gotType != tt.wantType {
|
||||
t.Errorf("routed to %s, want %s", gotType, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownNpmNoAPI verifies the improved error message for
|
||||
// a provider whose npm has no known wire protocol and that has no API URL to
|
||||
// fall back on.
|
||||
func TestAutoRouteProvider_UnknownNpmNoAPI(t *testing.T) {
|
||||
reg := newTestRegistry("", "@ai-sdk/unmapped", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "testproxy", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown npm with no API URL, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cannot auto-route provider testproxy") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--provider-url") {
|
||||
t.Errorf("error should suggest --provider-url, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownProvider verifies the not-in-database error.
|
||||
func TestAutoRouteProvider_UnknownProvider(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", "@ai-sdk/openai-compatible", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "does-not-exist", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found in model database") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsProviderLLMSupported_Google verifies that a provider whose npm is
|
||||
// @ai-sdk/google is reported as supported (it now maps to a wire protocol).
|
||||
func TestIsProviderLLMSupported_Google(t *testing.T) {
|
||||
info := &ProviderInfo{ID: "testproxy", NPM: "@ai-sdk/google"}
|
||||
if !isProviderLLMSupported("testproxy", info) {
|
||||
t.Error("expected @ai-sdk/google provider to be LLM-supported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionedBasePath verifies detection of proxy base URLs that already
|
||||
// carry an API version segment (which collides with the genai SDK's injected
|
||||
// version).
|
||||
func TestVersionedBasePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
rawURL string
|
||||
want string
|
||||
}{
|
||||
{"https://opencode.ai/zen/v1", "/zen/v1"},
|
||||
{"https://opencode.ai/zen/v1/", "/zen/v1"},
|
||||
{"https://example.com/api/v1beta", "/api/v1beta"},
|
||||
{"https://example.com/api/v2alpha", "/api/v2alpha"},
|
||||
{"https://generativelanguage.googleapis.com", ""},
|
||||
{"https://proxy.example/openai", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := versionedBasePath(tt.rawURL); got != tt.want {
|
||||
t.Errorf("versionedBasePath(%q) = %q, want %q", tt.rawURL, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordingRoundTripper captures the path of the request it receives.
|
||||
type recordingRoundTripper struct{ gotPath string }
|
||||
|
||||
func (r *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
r.gotPath = req.URL.Path
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGeminiProxyTransport_StripsInjectedVersion verifies that the transport
|
||||
// collapses the genai-injected "/v1beta" segment that follows a proxy base
|
||||
// URL which already carries its own version segment. This is the second-order
|
||||
// fix that makes opencode/gemini-* actually reach the proxy (issue #41).
|
||||
func TestGeminiProxyTransport_StripsInjectedVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
basePath string
|
||||
reqPath string
|
||||
wantPath string
|
||||
}{
|
||||
{
|
||||
name: "strips doubled v1beta after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "strips doubled v1beta1 after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta1/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "leaves non-matching path untouched",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/other/v1beta/models/x:generateContent",
|
||||
wantPath: "/other/v1beta/models/x:generateContent",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := &recordingRoundTripper{}
|
||||
tr := &geminiProxyTransport{base: rec, basePath: tt.basePath}
|
||||
req, err := http.NewRequest(http.MethodPost, "https://host"+tt.reqPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
if _, err := tr.RoundTrip(req); err != nil {
|
||||
t.Fatalf("RoundTrip: %v", err)
|
||||
}
|
||||
if rec.gotPath != tt.wantPath {
|
||||
t.Errorf("forwarded path = %q, want %q", rec.gotPath, tt.wantPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package models
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
+51
-18
@@ -10,14 +10,24 @@ import (
|
||||
|
||||
// loadCustomModelsFromConfig loads custom model definitions from the config file
|
||||
// and returns them as a map of model ID -> ModelInfo. Returns nil if no custom
|
||||
// models are configured.
|
||||
// models are configured. Reads from the process-global viper store (the model
|
||||
// registry is a process-global singleton).
|
||||
func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
if !viper.IsSet("customModels") {
|
||||
return loadCustomModelsFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// loadCustomModelsFrom loads custom model definitions from the supplied store.
|
||||
// When v is nil the process-global store is used.
|
||||
func loadCustomModelsFrom(v *viper.Viper) map[string]ModelInfo {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("customModels") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var customModels map[string]CustomModelConfig
|
||||
if err := viper.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
if err := v.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
log.Printf("Warning: Failed to parse customModels: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -60,15 +70,26 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
// from the process-global viper store. Keys are "provider/model" strings.
|
||||
// Returns nil if no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return LoadModelSettingsFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
|
||||
// supplied per-instance store. When v is nil the process-global store is used.
|
||||
// Keys are "provider/model" strings. Returns nil if no model settings are
|
||||
// configured.
|
||||
func LoadModelSettingsFrom(v *viper.Viper) map[string]*GenerationParams {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
if err := v.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -148,12 +169,17 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the config store: prefer the per-instance store carried on the
|
||||
// ProviderConfig (set by BuildProviderConfig / Kit.New), falling back to
|
||||
// the process-global store for callers that don't thread one through.
|
||||
store := config.ConfigStore
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
if settings := LoadModelSettingsFrom(store); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
@@ -173,28 +199,28 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
if params.MaxTokens != nil && !isExplicitlySet(store, "max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
if params.Temperature != nil && !isExplicitlySet(store, "temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
if params.TopP != nil && !isExplicitlySet(store, "top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
if params.TopK != nil && !isExplicitlySet(store, "top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet(store, "frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
if params.PresencePenalty != nil && !isExplicitlySet(store, "presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet(store, "stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet(store, "thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
@@ -228,7 +254,14 @@ func LoadSystemPromptValue(input string) string {
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
//
|
||||
// The check runs against the supplied per-instance store when non-nil,
|
||||
// otherwise the process-global store. This keeps the "explicit vs unset"
|
||||
// precedence contract per-Kit-instance once a store is threaded through.
|
||||
func isExplicitlySet(v *viper.Viper, key string) bool {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
@@ -239,7 +272,7 @@ func isExplicitlySet(key string) bool {
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
return v.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
|
||||
File diff suppressed because one or more lines are too long
+24
-14
@@ -48,18 +48,28 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
"@ai-sdk/google-vertex": "google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic": "google-vertex-anthropic",
|
||||
"@ai-sdk/amazon-bedrock": "bedrock",
|
||||
"@ai-sdk/azure": "azure",
|
||||
"@openrouter/ai-sdk-provider": "openrouter",
|
||||
"@ai-sdk/vercel": "vercel",
|
||||
"@ai-sdk/openai-compatible": "openaicompat",
|
||||
// wireProtocol identifies which LLM API protocol an npm package speaks.
|
||||
// Fantasy implements three native protocols (openai, anthropic, google);
|
||||
// everything else in its providers/ tree is a thin wrapper around one of
|
||||
// them with a pre-baked default URL or auth scheme.
|
||||
type wireProtocol int
|
||||
|
||||
const (
|
||||
wireUnknown wireProtocol = iota
|
||||
wireOpenAI
|
||||
wireAnthropic
|
||||
wireGoogle
|
||||
)
|
||||
|
||||
// npmToWireProtocol maps npm package names from models.dev to the wire
|
||||
// protocol they speak. Provider-specific bundles (azure, bedrock, vercel,
|
||||
// openrouter, google-vertex, google-vertex-anthropic) are intentionally
|
||||
// absent — they have native top-level cases in CreateProvider and never
|
||||
// reach the auto-router. Providers not in this map but with an api URL
|
||||
// are auto-routed through the OpenAI-compatible wire.
|
||||
var npmToWireProtocol = map[string]wireProtocol{
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
+271
-29
@@ -9,7 +9,9 @@ import (
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +27,7 @@ import (
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -85,6 +88,7 @@ type ThinkingLevel string
|
||||
|
||||
const (
|
||||
ThinkingOff ThinkingLevel = "off"
|
||||
ThinkingNone ThinkingLevel = "none"
|
||||
ThinkingMinimal ThinkingLevel = "minimal"
|
||||
ThinkingLow ThinkingLevel = "low"
|
||||
ThinkingMedium ThinkingLevel = "medium"
|
||||
@@ -93,12 +97,14 @@ const (
|
||||
|
||||
// ThinkingLevels returns the ordered list of available thinking levels for cycling.
|
||||
func ThinkingLevels() []ThinkingLevel {
|
||||
return []ThinkingLevel{ThinkingOff, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh}
|
||||
return []ThinkingLevel{ThinkingOff, ThinkingNone, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh}
|
||||
}
|
||||
|
||||
// thinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off".
|
||||
// thinkingBudgetTokens returns the token budget for a thinking level, or 0 for "off" or "none".
|
||||
func thinkingBudgetTokens(level ThinkingLevel) int64 {
|
||||
switch level {
|
||||
case ThinkingNone:
|
||||
return 1024
|
||||
case ThinkingMinimal:
|
||||
return 1024
|
||||
case ThinkingLow:
|
||||
@@ -117,6 +123,8 @@ func ThinkingLevelDescription(level ThinkingLevel) string {
|
||||
switch level {
|
||||
case ThinkingOff:
|
||||
return "No reasoning"
|
||||
case ThinkingNone:
|
||||
return "Minimal reasoning (OpenAI 'none')"
|
||||
case ThinkingMinimal:
|
||||
return "Very brief reasoning (~1k tokens)"
|
||||
case ThinkingLow:
|
||||
@@ -133,7 +141,7 @@ func ThinkingLevelDescription(level ThinkingLevel) string {
|
||||
// ParseThinkingLevel converts a string to a ThinkingLevel, defaulting to ThinkingOff.
|
||||
func ParseThinkingLevel(s string) ThinkingLevel {
|
||||
switch ThinkingLevel(s) {
|
||||
case ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh:
|
||||
case ThinkingNone, ThinkingMinimal, ThinkingLow, ThinkingMedium, ThinkingHigh:
|
||||
return ThinkingLevel(s)
|
||||
default:
|
||||
return ThinkingOff
|
||||
@@ -159,6 +167,13 @@ type ProviderConfig struct {
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
|
||||
// ConfigStore is the per-instance configuration store used to resolve
|
||||
// "explicitly set" precedence checks (isExplicitlySet), per-model
|
||||
// settings, and right-sizing. When nil, the process-global viper store is
|
||||
// used. Threading a per-Kit store here keeps generation-parameter
|
||||
// precedence isolated between Kit instances in the same process.
|
||||
ConfigStore *viper.Viper
|
||||
|
||||
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
|
||||
// for long operations like Ollama model pulls. The returned io.ReadCloser
|
||||
// must be closed when done. When nil, the raw reader is consumed directly
|
||||
@@ -207,8 +222,10 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
//
|
||||
// Native providers: anthropic, openai, google, ollama, azure, google-vertex-anthropic,
|
||||
// openrouter, bedrock, vercel.
|
||||
// Any provider in models.dev with an api URL or openai-compatible npm package
|
||||
// is auto-routed through fantasy's openaicompat provider.
|
||||
// Any other provider in models.dev is auto-routed by wire protocol: its npm
|
||||
// package (or per-model override) selects the OpenAI, Anthropic, or Google
|
||||
// transport, using the provider's api URL as the base. Providers with an api
|
||||
// URL but an unrecognized npm package fall back to the OpenAI-compatible wire.
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
@@ -251,6 +268,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Auto-raise MaxTokens toward the model's known output ceiling when the
|
||||
// user hasn't explicitly set --max-tokens and no per-model override
|
||||
// applied. Runs after ApplyModelSettings so explicit modelSettings win.
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
@@ -295,9 +317,18 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
// Only add cache options for providers that don't already have
|
||||
// options set, to avoid type conflicts (e.g., Anthropic has
|
||||
// different types for regular options vs cache control options).
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
//
|
||||
// For OpenAI Responses API models, we skip merging entirely because
|
||||
// ResponsesProviderOptions and ProviderOptions are incompatible types.
|
||||
skipMerge := false
|
||||
if provider == "openai" && openai.IsResponsesModel(modelName) {
|
||||
skipMerge = true
|
||||
}
|
||||
if !skipMerge {
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -308,43 +339,62 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
// It routes on wire protocol (openai, anthropic, google) rather than per-npm
|
||||
// provider name: fantasy implements three native wire protocols, and every other
|
||||
// entry in its providers/ tree is a thin wrapper around one of them. Using the
|
||||
// provider's api URL from models.dev as the base URL, any proxy that re-flavors
|
||||
// one of these protocols (e.g. opencode's Gemini routes) Just Works.
|
||||
//
|
||||
// Models may carry a provider override that specifies a different npm package
|
||||
// than the provider's default (e.g. opencode's claude-* uses @ai-sdk/anthropic
|
||||
// and its gemini-* uses @ai-sdk/google), which is resolved first.
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
// Resolve npm: per-model override > provider default.
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
llmProvider = "openaicompat"
|
||||
wire, known := npmToWireProtocol[npmPackage]
|
||||
if !known {
|
||||
// Unknown npm but the provider has an API URL → assume OpenAI-compatible.
|
||||
// (Preserves the long-standing "any provider in models.dev with an api URL
|
||||
// is auto-routed through openaicompat" behaviour.)
|
||||
if providerInfo.API == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"cannot auto-route provider %s: npm package %q has no known wire protocol "+
|
||||
"and the registry has no API URL (use --provider-url to override)",
|
||||
provider, npmPackage,
|
||||
)
|
||||
}
|
||||
wire = wireOpenAI
|
||||
}
|
||||
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
// All three wires use the provider's API URL from models.dev as the base.
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
}
|
||||
|
||||
switch wire {
|
||||
case wireOpenAI:
|
||||
// The native OpenAI SDK package (@ai-sdk/openai) speaks the Responses
|
||||
// API; openai-compatible proxies (and unknown-npm fallbacks) use the
|
||||
// chat-completions wire via fantasy's openaicompat provider.
|
||||
if npmPackage == "@ai-sdk/openai" {
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
}
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
}
|
||||
case wireAnthropic:
|
||||
return createAutoRoutedAnthropicProvider(ctx, config, modelName, providerInfo)
|
||||
case "openai":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
case wireGoogle:
|
||||
return createAutoRoutedGoogleProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
return nil, fmt.Errorf("internal error: unknown wire protocol for provider %s (npm: %s)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,6 +511,115 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createAutoRoutedGoogleProvider creates a Google (Gemini) provider for
|
||||
// third-party providers that expose a Gemini-compatible API (e.g. opencode's
|
||||
// Gemini routes, which carry an @ai-sdk/google per-model override).
|
||||
//
|
||||
// The underlying genai SDK always injects its own API version segment
|
||||
// ("v1beta") between the base URL and the resource path. When the proxy's
|
||||
// base URL from models.dev already carries a version segment (e.g. opencode's
|
||||
// https://opencode.ai/zen/v1), that produces a doubled ".../v1/v1beta/..."
|
||||
// path that the proxy rejects. In that case we install a transport that
|
||||
// strips the injected segment so the proxy's own version is used.
|
||||
func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
}
|
||||
|
||||
opts := []google.Option{
|
||||
google.WithGeminiAPIKey(apiKey),
|
||||
google.WithName(info.ID),
|
||||
}
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, google.WithBaseURL(config.ProviderURL))
|
||||
}
|
||||
|
||||
// Decide whether the genai-injected version segment needs stripping.
|
||||
var httpClient *http.Client
|
||||
if basePath := versionedBasePath(config.ProviderURL); basePath != "" {
|
||||
httpClient = newGeminiProxyHTTPClient(basePath, config.TLSSkipVerify)
|
||||
} else if config.TLSSkipVerify {
|
||||
httpClient = createHTTPClientWithTLSConfig(true)
|
||||
}
|
||||
if httpClient != nil {
|
||||
opts = append(opts, google.WithHTTPClient(httpClient))
|
||||
}
|
||||
|
||||
p, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// versionSegmentRe matches a trailing API version segment in a URL path,
|
||||
// e.g. "/v1", "/v1beta", "/v1beta1", "/v2alpha".
|
||||
var versionSegmentRe = regexp.MustCompile(`/v\d+(?:beta\d*|alpha\d*)?$`)
|
||||
|
||||
// versionedBasePath returns the path component of rawURL when that path ends
|
||||
// with an API version segment (e.g. opencode's ".../zen/v1" → "/zen/v1").
|
||||
// It returns "" when rawURL is empty, unparseable, or has no version suffix
|
||||
// — in which case the genai SDK's default version injection is correct and
|
||||
// no rewriting is needed.
|
||||
func versionedBasePath(rawURL string) string {
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSuffix(u.Path, "/")
|
||||
if versionSegmentRe.MatchString(path) {
|
||||
return path
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// newGeminiProxyHTTPClient builds an HTTP client whose transport strips the
|
||||
// genai-injected version segment ("v1beta"/"v1beta1") that directly follows
|
||||
// basePath, collapsing "{basePath}/v1beta/..." back to "{basePath}/...".
|
||||
func newGeminiProxyHTTPClient(basePath string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &geminiProxyTransport{base: base, basePath: basePath},
|
||||
}
|
||||
}
|
||||
|
||||
// geminiProxyTransport removes the redundant API version segment that the
|
||||
// genai SDK injects after a proxy base URL that already carries its own
|
||||
// version segment.
|
||||
type geminiProxyTransport struct {
|
||||
base http.RoundTripper
|
||||
basePath string
|
||||
}
|
||||
|
||||
func (t *geminiProxyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for _, injected := range []string{"/v1beta1", "/v1beta"} {
|
||||
prefix := t.basePath + injected + "/"
|
||||
if strings.HasPrefix(req.URL.Path, prefix) {
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Path = t.basePath + strings.TrimPrefix(req.URL.Path, t.basePath+injected)
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// resolveAPIKey returns the first non-empty API key from the explicit key
|
||||
// or the environment variables.
|
||||
func resolveAPIKey(explicitKey string, envVars []string) string {
|
||||
@@ -489,6 +648,37 @@ func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
// defaultRightSizeCap bounds auto-raised MaxTokens so that we don't silently
|
||||
// allocate enormous output budgets for models with very high ceilings (e.g.
|
||||
// Devstral at 262144, Mistral at 128000). Users who genuinely want more can
|
||||
// pass --max-tokens explicitly or set modelSettings[...].maxTokens in config.
|
||||
const defaultRightSizeCap = 32768
|
||||
|
||||
// rightSizeMaxTokens raises config.MaxTokens toward the model's known output
|
||||
// ceiling when:
|
||||
// - the user has not explicitly set --max-tokens (or the KIT_MAX_TOKENS env
|
||||
// var, or the top-level max-tokens key in config.yaml), AND
|
||||
// - no per-model override already bumped MaxTokens (ApplyModelSettings runs
|
||||
// before this function), AND
|
||||
// - modelInfo.Limit.Output is known and larger than the current MaxTokens.
|
||||
//
|
||||
// The raised value is capped at defaultRightSizeCap to keep accidental
|
||||
// allocations reasonable on very-large-output models. This prevents the
|
||||
// common "ghost" where the agent's reply is silently truncated at the 8192
|
||||
// default even though the selected model supports 64k or 262k output tokens.
|
||||
func rightSizeMaxTokens(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
if modelInfo == nil || modelInfo.Limit.Output <= 0 {
|
||||
return
|
||||
}
|
||||
if isExplicitlySet(config.ConfigStore, "max-tokens") {
|
||||
return
|
||||
}
|
||||
target := min(modelInfo.Limit.Output, defaultRightSizeCap)
|
||||
if config.MaxTokens < target {
|
||||
config.MaxTokens = target
|
||||
}
|
||||
}
|
||||
|
||||
// clearConflictingAnthropicSamplingParams ensures that temperature and top_p are
|
||||
// not both sent to the Anthropic API, which rejects requests containing both.
|
||||
// When both are set (typically from defaults), top_p is cleared so that
|
||||
@@ -535,6 +725,8 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
// Returns nil for ThinkingOff (use the model's default).
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingNone:
|
||||
return new(openai.ReasoningEffortNone)
|
||||
case ThinkingMinimal:
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
@@ -548,6 +740,56 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidThinkingLevelForModel checks if a thinking level is valid for the given
|
||||
// model. Some OpenAI models like gpt-5.4 don't support "minimal" and require
|
||||
// "none" instead.
|
||||
func IsValidThinkingLevelForModel(level ThinkingLevel, modelName string) bool {
|
||||
if level == ThinkingOff {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if this is an OpenAI model that doesn't support "minimal"
|
||||
// gpt-5.4 and newer gpt-5.x models use "none" instead of "minimal"
|
||||
if level == ThinkingMinimal {
|
||||
if strings.Contains(modelName, "gpt-5.4") ||
|
||||
strings.Contains(modelName, "gpt-5-pro") ||
|
||||
strings.Contains(modelName, "gpt-5-chat") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is an OpenAI model that doesn't support "none"
|
||||
// Older gpt-5 models only support "minimal", not "none"
|
||||
if level == ThinkingNone {
|
||||
if strings.Contains(modelName, "gpt-5") &&
|
||||
!strings.Contains(modelName, "gpt-5.4") &&
|
||||
!strings.Contains(modelName, "gpt-5-pro") &&
|
||||
!strings.Contains(modelName, "gpt-5-chat") {
|
||||
// Older gpt-5 models might not support "none"
|
||||
// They only added "none" support in newer versions
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// All other levels are generally valid for reasoning models
|
||||
return true
|
||||
}
|
||||
|
||||
// SuggestThinkingLevelFallback returns a recommended fallback level when the
|
||||
// requested level is not valid for the model. Returns ThinkingOff if no
|
||||
// suitable fallback exists.
|
||||
func SuggestThinkingLevelFallback(level ThinkingLevel, modelName string) ThinkingLevel {
|
||||
if level == ThinkingMinimal && !IsValidThinkingLevelForModel(level, modelName) {
|
||||
// For models that don't support "minimal", suggest "none" (~same token budget)
|
||||
return ThinkingNone
|
||||
}
|
||||
if level == ThinkingNone && !IsValidThinkingLevelForModel(level, modelName) {
|
||||
// For models that don't support "none", suggest "minimal" (~same token budget)
|
||||
return ThinkingMinimal
|
||||
}
|
||||
return ThinkingOff
|
||||
}
|
||||
|
||||
// buildAnthropicProviderOptions returns fantasy.ProviderOptions configured for
|
||||
// Anthropic models with extended thinking. When thinking is enabled, it sets
|
||||
// SendReasoning to true and configures the thinking budget. For thinking-off
|
||||
|
||||
+26
-13
@@ -4,6 +4,7 @@ import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -111,13 +112,30 @@ func NewModelsRegistry() *ModelsRegistry {
|
||||
}
|
||||
|
||||
// buildFromModelsDB converts models.dev provider data into our internal format.
|
||||
// It tries the on-disk cache first and falls back to the embedded database.
|
||||
// It starts from the compile-time embedded database and merges on-disk cached
|
||||
// data from `kit update-models` on top. Cached provider metadata replaces
|
||||
// embedded metadata, and model entries are merged with cached models taking
|
||||
// precedence. This means newly synced models are available while embedded
|
||||
// models that haven't been synced yet are still reachable.
|
||||
func buildFromModelsDB() map[string]ProviderInfo {
|
||||
// Try cached data first (from `kit update-models`)
|
||||
dbProviders, _ := LoadCachedProviders()
|
||||
if len(dbProviders) == 0 {
|
||||
// Fall back to compile-time embedded data
|
||||
dbProviders = loadEmbeddedProviders()
|
||||
// Start with compile-time embedded data as the base.
|
||||
dbProviders := loadEmbeddedProviders()
|
||||
if dbProviders == nil {
|
||||
dbProviders = make(ModelsDBProviders)
|
||||
}
|
||||
|
||||
// Merge on-disk cached data on top (cached takes precedence).
|
||||
if cached, _ := LoadCachedProviders(); len(cached) > 0 {
|
||||
for providerID, cp := range cached {
|
||||
if existing, ok := dbProviders[providerID]; ok {
|
||||
// Merge models: embedded base + cached overrides.
|
||||
mergedModels := make(map[string]modelsDBModel, len(existing.Models)+len(cp.Models))
|
||||
maps.Copy(mergedModels, existing.Models)
|
||||
maps.Copy(mergedModels, cp.Models)
|
||||
cp.Models = mergedModels
|
||||
}
|
||||
dbProviders[providerID] = cp
|
||||
}
|
||||
}
|
||||
|
||||
providers := make(map[string]ProviderInfo, len(dbProviders))
|
||||
@@ -379,11 +397,6 @@ func (r *ModelsRegistry) GetLLMProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMProviders instead.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
return r.GetLLMProviders()
|
||||
}
|
||||
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama and custom are always supported (model names are user-defined).
|
||||
@@ -391,8 +404,8 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
// Check if npm maps to a known wire protocol
|
||||
if _, ok := npmToWireProtocol[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// bindMaxTokensFlag wires a fresh pflag-backed "max-tokens" key into viper so
|
||||
// isExplicitlySet behaves the same way it does in production. Returns a
|
||||
// cleanup function that removes the binding so sibling tests see a clean
|
||||
// state.
|
||||
func bindMaxTokensFlag(t *testing.T, args []string) func() {
|
||||
t.Helper()
|
||||
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
fs.Int("max-tokens", 8192, "")
|
||||
if err := viper.BindPFlag("max-tokens", fs.Lookup("max-tokens")); err != nil {
|
||||
t.Fatalf("BindPFlag: %v", err)
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
t.Fatalf("fs.Parse: %v", err)
|
||||
}
|
||||
return func() {
|
||||
viper.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_RaisesWhenBelowCeiling(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil) // no args → flag.Changed = false
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "claude-sonnet-4-5",
|
||||
Limit: Limit{Context: 200000, Output: 64000},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 32768 {
|
||||
t.Errorf("expected MaxTokens raised to defaultRightSizeCap (32768), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_CapsAtDefaultRightSizeCap(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Mistral Devstral has 262144 output — we should still cap at 32768.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "devstral-medium-latest",
|
||||
Limit: Limit{Context: 262144, Output: 262144},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != defaultRightSizeCap {
|
||||
t.Errorf("expected MaxTokens capped at %d, got %d", defaultRightSizeCap, config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_UsesExactOutputWhenBelowCap(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 4096}
|
||||
// Model with output limit smaller than the cap.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "gpt-4",
|
||||
Limit: Limit{Context: 8192, Output: 8192},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens raised to model output ceiling (8192), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_DoesNotLowerCurrentValue(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
// User (via per-model settings, applied earlier) already bumped MaxTokens
|
||||
// above the cap — we must not clobber their choice.
|
||||
config := &ProviderConfig{MaxTokens: 100000}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "devstral-medium-latest",
|
||||
Limit: Limit{Context: 262144, Output: 262144},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 100000 {
|
||||
t.Errorf("expected MaxTokens preserved at 100000, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_RespectsExplicitFlag(t *testing.T) {
|
||||
// Simulate `--max-tokens 4096` on the command line.
|
||||
cleanup := bindMaxTokensFlag(t, []string{"--max-tokens", "4096"})
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 4096}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "claude-sonnet-4-5",
|
||||
Limit: Limit{Context: 200000, Output: 64000},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected explicit --max-tokens to be preserved (4096), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_NilModelInfo(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Custom model / Ollama / unknown provider → no model info.
|
||||
rightSizeMaxTokens(config, nil)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens unchanged with nil modelInfo, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_ZeroOutputLimit(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Model present in catalog but with no known output limit.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "unknown-model",
|
||||
Limit: Limit{Context: 0, Output: 0},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens unchanged with zero output limit, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
+39
-35
@@ -36,15 +36,17 @@ type Diagnostic struct {
|
||||
}
|
||||
|
||||
// LoadAll discovers and loads all prompt templates from standard locations
|
||||
// and any extra paths. Templates are loaded in order of precedence (lowest
|
||||
// to highest), with later templates overriding earlier ones of the same name.
|
||||
// and any extra paths. Templates are loaded in order of precedence (highest
|
||||
// to lowest); the first source to define a given name wins, later definitions
|
||||
// of the same name are dropped with a diagnostic.
|
||||
//
|
||||
// Discovery paths searched in order:
|
||||
// 1. Default templates (if IncludeDefaults)
|
||||
// 2. ~/.kit/prompts/ (global user templates)
|
||||
// 3. .kit/prompts/ (project-local templates)
|
||||
// 4. ConfigPaths (from configuration)
|
||||
// 5. ExtraPaths (explicit paths, highest precedence)
|
||||
// 2. ~/.kit/prompts/ (legacy global)
|
||||
// 3. $XDG_CONFIG_HOME/kit/prompts/ (XDG global, default ~/.config/kit/prompts/)
|
||||
// 4. <cwd>/.kit/prompts/ (project-local templates)
|
||||
// 5. ConfigPaths (from configuration)
|
||||
// 6. ExtraPaths (explicit paths, lowest precedence)
|
||||
func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
if opts.Cwd == "" {
|
||||
opts.Cwd, _ = os.Getwd()
|
||||
@@ -88,13 +90,21 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
addTemplates(defaults, "default")
|
||||
}
|
||||
|
||||
// 2. Global user templates: ~/.kit/prompts/
|
||||
globalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(globalDir); err == nil {
|
||||
// 2. Legacy global user templates: ~/.kit/prompts/
|
||||
legacyGlobalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(legacyGlobalDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
|
||||
// 3. Project-local templates: .kit/prompts/
|
||||
// 3. XDG global user templates: $XDG_CONFIG_HOME/kit/prompts/
|
||||
// Default: ~/.config/kit/prompts/. Aligns with extensions and skills.
|
||||
if xdgDir := GlobalDir(); xdgDir != "" && xdgDir != legacyGlobalDir {
|
||||
if templates, err := LoadFromDir(xdgDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Project-local templates: .kit/prompts/
|
||||
localDir := filepath.Join(opts.Cwd, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(localDir); err == nil {
|
||||
addTemplates(templates, "local")
|
||||
@@ -179,31 +189,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
|
||||
// It returns the deduplicated list and diagnostics for any collisions.
|
||||
// This is a standalone function for when you need to deduplicate an existing list.
|
||||
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
|
||||
seen := make(map[string]*PromptTemplate)
|
||||
var result []*PromptTemplate
|
||||
var diagnostics []Diagnostic
|
||||
|
||||
for _, tpl := range templates {
|
||||
if existing, ok := seen[tpl.Name]; ok {
|
||||
diagnostics = append(diagnostics, Diagnostic{
|
||||
Name: tpl.Name,
|
||||
KeptPath: existing.FilePath,
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: "duplicate template name (first-match-wins)",
|
||||
})
|
||||
} else {
|
||||
seen[tpl.Name] = tpl
|
||||
result = append(result, tpl)
|
||||
}
|
||||
}
|
||||
|
||||
return result, diagnostics
|
||||
}
|
||||
|
||||
// loadDefaultTemplates returns the built-in default templates.
|
||||
// These are embedded templates that ship with Kit.
|
||||
func loadDefaultTemplates() []*PromptTemplate {
|
||||
@@ -211,3 +196,22 @@ func loadDefaultTemplates() []*PromptTemplate {
|
||||
// For now, return an empty slice - users can define their own templates
|
||||
return nil
|
||||
}
|
||||
|
||||
// GlobalDir returns the XDG-aligned global prompts directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/prompts/. Returns an empty
|
||||
// string if the user's home directory cannot be resolved.
|
||||
//
|
||||
// This is the canonical location for user-wide prompt templates and aligns
|
||||
// with the discovery paths used for extensions ($XDG_CONFIG_HOME/kit/extensions/)
|
||||
// and skills ($XDG_CONFIG_HOME/kit/skills/).
|
||||
func GlobalDir() string {
|
||||
base := os.Getenv("XDG_CONFIG_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(base, "kit", "prompts")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestCompactionParentCycleRegression tests that after multiple compactions,
|
||||
// newly appended messages always have a valid parent chain and BuildContext
|
||||
// returns the correct messages.
|
||||
func TestCompactionParentCycleRegression(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Simulate a long conversation with multiple compactions.
|
||||
msg1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
|
||||
msg2, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
|
||||
|
||||
// First compaction
|
||||
comp1, _ := tm.AppendCompaction("Summary 1", msg1, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
msg3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
|
||||
msg4, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg4"}}})
|
||||
|
||||
// Second compaction
|
||||
comp2, _ := tm.AppendCompaction("Summary 2", msg3, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
msg5, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg5"}}})
|
||||
msg6, _ := tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg6"}}})
|
||||
|
||||
// Verify parent chain integrity
|
||||
for _, id := range []string{msg1, msg2, comp1, msg3, msg4, comp2, msg5, msg6} {
|
||||
entry := tm.GetEntry(id)
|
||||
if entry == nil {
|
||||
t.Fatalf("entry %s not found in index", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Walk parent chain from msg6 — must reach root without cycles
|
||||
visited := make(map[string]bool)
|
||||
current := msg6
|
||||
for current != "" {
|
||||
if visited[current] {
|
||||
t.Fatalf("cycle detected at entry %s", current)
|
||||
}
|
||||
visited[current] = true
|
||||
entry := tm.GetEntry(current)
|
||||
if entry == nil {
|
||||
t.Fatalf("entry %s missing from index during parent walk", current)
|
||||
}
|
||||
parent := ""
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
parent = e.ParentID
|
||||
case *CompactionEntry:
|
||||
parent = e.ParentID
|
||||
}
|
||||
current = parent
|
||||
}
|
||||
|
||||
// BuildContext should return: Summary2 + msg6 + msg5 + msg3 + msg4 = 5 messages
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
if len(msgs) != 5 {
|
||||
t.Fatalf("expected 5 messages, got %d: %+v", len(msgs), msgs)
|
||||
}
|
||||
}
|
||||
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
|
||||
// Kept messages must appear BEFORE post-compaction messages so the LLM
|
||||
// sees the conversation in chronological order. Otherwise the latest
|
||||
// post-compaction user message would be followed by an older kept user
|
||||
// message, breaking user/assistant alternation and causing the model to
|
||||
// respond as if the post-compaction turn never happened.
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
// Verify order: summary, M3 (kept), M4 (new)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeCwdForDir verifies the working-directory → session-directory
|
||||
// name encoding strips characters that are illegal on Windows (notably the
|
||||
// drive-letter colon, see issue #18) while preserving the previous output
|
||||
// for the typical Unix paths.
|
||||
func TestEncodeCwdForDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "unix absolute path",
|
||||
cwd: "/home/user/proj",
|
||||
want: "home--user--proj",
|
||||
},
|
||||
{
|
||||
name: "unix relative path",
|
||||
cwd: "proj/sub",
|
||||
want: "proj--sub",
|
||||
},
|
||||
{
|
||||
name: "windows drive root",
|
||||
cwd: `C:\test`,
|
||||
want: "C--test",
|
||||
},
|
||||
{
|
||||
name: "windows nested path",
|
||||
cwd: `C:\Users\User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows secondary drive",
|
||||
cwd: `S:\work\repo`,
|
||||
want: "S--work--repo",
|
||||
},
|
||||
{
|
||||
name: "windows mixed separators",
|
||||
cwd: `C:\Users/User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows other illegal chars stripped",
|
||||
cwd: `C:\a<b>c|d?e*f"g`,
|
||||
want: "C--abcdefg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := encodeCwdForDir(tc.cwd)
|
||||
if got != tc.want {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q, want %q", tc.cwd, got, tc.want)
|
||||
}
|
||||
// Encoded directory must never contain characters that are
|
||||
// illegal in Windows directory names.
|
||||
for _, bad := range []string{":", "<", ">", "\"", "|", "?", "*", "\\", "/"} {
|
||||
if strings.Contains(got, bad) {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q contains illegal char %q", tc.cwd, got, bad)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -97,6 +99,11 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
//
|
||||
// Per-file extraction is parallelized across a small worker pool because each
|
||||
// file requires a full JSONL scan to compute MessageCount and FirstMessage —
|
||||
// for users with many sessions this is the dominant cost of opening the
|
||||
// session picker.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -107,20 +114,47 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
return nil, fmt.Errorf("failed to read directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
var sessions []SessionInfo
|
||||
// Collect candidate paths first so we can parallelize the heavy work.
|
||||
paths := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
|
||||
continue
|
||||
}
|
||||
paths = append(paths, filepath.Join(dir, entry.Name()))
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
info, err := extractSessionInfo(path)
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
results := make([]*SessionInfo, len(paths))
|
||||
|
||||
// Worker pool sized to GOMAXPROCS, capped to avoid thrashing for tiny lists.
|
||||
workers := max(min(runtime.GOMAXPROCS(0), len(paths)), 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan int, len(paths))
|
||||
for range workers {
|
||||
wg.Go(func() {
|
||||
for i := range jobs {
|
||||
info, err := extractSessionInfo(paths[i])
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
results[i] = info
|
||||
}
|
||||
})
|
||||
}
|
||||
for i := range paths {
|
||||
jobs <- i
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
|
||||
sessions := make([]SessionInfo, 0, len(results))
|
||||
for i, info := range results {
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
// Clean up and skip empty sessions (no messages).
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
_ = os.Remove(paths[i])
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestDetectCycleWithCorruptedParentChain tests that cycle detection works
|
||||
// when a corrupted session has circular parent references.
|
||||
func TestDetectCycleWithCorruptedParentChain(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Create normal chain: msg1 -> msg2 -> msg3
|
||||
id1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
|
||||
_, _ = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
|
||||
id3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
|
||||
|
||||
// Simulate corruption: manually set msg1's parent to msg3, creating cycle
|
||||
// This simulates the condition seen in the user's session
|
||||
for _, entry := range tm.entries {
|
||||
if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
|
||||
e.ParentID = id3 // Create cycle: msg1 -> msg3 -> ... -> msg1
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// DetectCycle should find the cycle
|
||||
// The cycle is: id1 -> id3 -> id2 -> id1
|
||||
// So detecting from id3 should find id1 as the repeat
|
||||
cycle, entry := tm.DetectCycle(id3)
|
||||
if !cycle {
|
||||
t.Fatal("expected to detect cycle, but none found")
|
||||
}
|
||||
// The cycle entry could be id1 or id3 depending on where we start
|
||||
if entry != id1 && entry != id3 {
|
||||
t.Fatalf("expected cycle at %s or %s, got %s", id1, id3, entry)
|
||||
}
|
||||
|
||||
// BuildContext should still work (it has its own cycle detection)
|
||||
// but will truncate at the cycle point
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
if len(msgs) == 0 {
|
||||
t.Fatal("BuildContext returned no messages")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAppendMessageRejectsInvalidParent tests that AppendMessage rejects
|
||||
// appending when the current leaf has a broken parent chain.
|
||||
func TestAppendMessageRejectsInvalidParent(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Create normal message
|
||||
id1, err := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append msg1: %v", err)
|
||||
}
|
||||
|
||||
// Simulate corruption: set leafID to a non-existent ID
|
||||
tm.leafID = "non-existent-id"
|
||||
|
||||
// Next append should fail validation
|
||||
_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error when appending with invalid leafID, got nil")
|
||||
}
|
||||
|
||||
// Restore valid leafID
|
||||
tm.leafID = id1
|
||||
|
||||
// Append should succeed now
|
||||
_, err = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append msg3 after restoring leafID: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildContextHandlesCycleGracefully tests that BuildContext handles
|
||||
// cycles gracefully by truncating the branch.
|
||||
func TestBuildContextHandlesCycleGracefully(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Create messages
|
||||
id1, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg1"}}})
|
||||
_, _ = tm.AppendMessage(message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "msg2"}}})
|
||||
id3, _ := tm.AppendMessage(message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "msg3"}}})
|
||||
|
||||
// Verify normal case works
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(msgs))
|
||||
}
|
||||
|
||||
// Simulate cycle: set msg1's parent to msg3
|
||||
for _, entry := range tm.entries {
|
||||
if e, ok := entry.(*MessageEntry); ok && e.ID == id1 {
|
||||
e.ParentID = id3
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// BuildContext should handle cycle gracefully (getBranchLocked has cycle detection)
|
||||
msgs, _, _ = tm.BuildContext()
|
||||
// Should only include messages from the cycle: msg3, msg2, msg1
|
||||
// (msg3 is leaf, walks to msg2 -> msg1 -> msg3 (cycle detected, stops))
|
||||
if len(msgs) != 3 {
|
||||
t.Fatalf("expected 3 messages in cycle case, got %d: %+v", len(msgs), msgs)
|
||||
}
|
||||
}
|
||||
@@ -63,6 +63,11 @@ type TreeManager struct {
|
||||
|
||||
// file is the open file handle for appending entries. Nil for in-memory.
|
||||
file *os.File
|
||||
|
||||
// writer is a buffered writer wrapping file. Writes go through this
|
||||
// buffer and are flushed to disk at explicit sync points (after each
|
||||
// public Append* call, in Close, etc.) to reduce syscall overhead.
|
||||
writer *bufio.Writer
|
||||
}
|
||||
|
||||
// --- Constructors ---
|
||||
@@ -105,11 +110,16 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
|
||||
return nil, fmt.Errorf("failed to create session file: %w", err)
|
||||
}
|
||||
tm.file = f
|
||||
tm.writer = bufio.NewWriter(f)
|
||||
|
||||
if err := tm.writeEntry(&header); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to write session header: %w", err)
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to flush session header: %w", err)
|
||||
}
|
||||
|
||||
return tm, nil
|
||||
}
|
||||
@@ -150,6 +160,7 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
|
||||
return nil, fmt.Errorf("failed to recreate session file: %w", err)
|
||||
}
|
||||
newTm.file = f
|
||||
newTm.writer = bufio.NewWriter(f)
|
||||
|
||||
if err := newTm.writeEntry(&newTm.header); err != nil {
|
||||
_ = f.Close()
|
||||
@@ -289,6 +300,12 @@ func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManag
|
||||
}
|
||||
}
|
||||
|
||||
// Flush all buffered writes from the fork in a single syscall.
|
||||
if err := newTm.flushLocked(); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to flush forked session: %w", err)
|
||||
}
|
||||
|
||||
// Set the leaf to the last entry in the new session.
|
||||
newTm.leafID = prevNewID
|
||||
|
||||
@@ -365,12 +382,16 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
|
||||
}
|
||||
|
||||
// Validate tree integrity and log diagnostics
|
||||
tm.LogTreeDiagnostics()
|
||||
|
||||
// Open file for appending.
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open session file for append: %w", err)
|
||||
}
|
||||
tm.file = f
|
||||
tm.writer = bufio.NewWriter(f)
|
||||
|
||||
return tm, nil
|
||||
}
|
||||
@@ -410,6 +431,12 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
// Validate parent chain before appending to detect/prevent cycles
|
||||
// that could be caused by external file corruption or race conditions.
|
||||
if err := tm.validateParentChainLocked(tm.leafID, ""); err != nil {
|
||||
return "", fmt.Errorf("parent chain validation failed: %w", err)
|
||||
}
|
||||
|
||||
entry, err := NewMessageEntry(tm.leafID, msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -418,6 +445,9 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush message: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -442,6 +472,9 @@ func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, erro
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush model change: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -456,6 +489,9 @@ func (tm *TreeManager) AppendBranchSummary(fromID, summary string) (string, erro
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush branch summary: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -470,6 +506,9 @@ func (tm *TreeManager) AppendLabel(targetID, label string) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush label: %w", err)
|
||||
}
|
||||
|
||||
tm.labels[targetID] = label
|
||||
tm.leafID = entry.ID
|
||||
@@ -485,6 +524,9 @@ func (tm *TreeManager) AppendSessionInfo(name string) (string, error) {
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush session info: %w", err)
|
||||
}
|
||||
|
||||
tm.sessionName = name
|
||||
tm.leafID = entry.ID
|
||||
@@ -501,6 +543,9 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush extension data: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -518,6 +563,13 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
// Validate that firstKeptEntryID exists if provided
|
||||
if firstKeptEntryID != "" {
|
||||
if _, ok := tm.index[firstKeptEntryID]; !ok {
|
||||
return "", fmt.Errorf("first kept entry %q does not exist", firstKeptEntryID)
|
||||
}
|
||||
}
|
||||
|
||||
// The compaction entry has no parent, making it a new "root" for the
|
||||
// post-compaction branch. This ensures old compacted messages are not
|
||||
// traversed when walking from the current leaf.
|
||||
@@ -525,6 +577,9 @@ func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokens
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := tm.flushLocked(); err != nil {
|
||||
return "", fmt.Errorf("failed to flush compaction: %w", err)
|
||||
}
|
||||
|
||||
tm.leafID = entry.ID
|
||||
return entry.ID, nil
|
||||
@@ -700,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
// If there is a compaction, inject the summary first, then the
|
||||
// preserved "kept" messages (chronologically before the compaction),
|
||||
// then the post-compaction messages (chronologically after).
|
||||
//
|
||||
// Order matters: the kept messages must come BEFORE the post-compaction
|
||||
// branch so the LLM sees the conversation in chronological order. If the
|
||||
// kept messages were appended last, the latest user message in the
|
||||
// current branch would be followed by an older kept user message,
|
||||
// breaking the strict user/assistant alternation that providers expect
|
||||
// and causing the model to respond as if the previous turn never
|
||||
// happened.
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -713,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
})
|
||||
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
// Step 1: collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not on the current branch (the compaction entry is a
|
||||
// new root with no parent), so we iterate tm.entries in append order
|
||||
// and stop when we reach the compaction entry itself.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
@@ -770,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
// Stop when we reach the compaction entry itself; messages
|
||||
// after it are collected from the branch walk below.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -805,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: collect entries on the current branch after the compaction
|
||||
// entry (these are post-compaction messages). The compaction entry
|
||||
// itself is skipped — its summary was already injected above.
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Summary already injected above.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
@@ -910,11 +969,31 @@ func (tm *TreeManager) IsEmpty() bool {
|
||||
return tm.MessageCount() == 0
|
||||
}
|
||||
|
||||
// Close closes the underlying file handle.
|
||||
// Flush writes any buffered data to the underlying file.
|
||||
func (tm *TreeManager) Flush() error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// flushLocked writes buffered data to disk. Caller must hold the lock.
|
||||
func (tm *TreeManager) flushLocked() error {
|
||||
if tm.writer != nil {
|
||||
return tm.writer.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close flushes any buffered writes and closes the underlying file handle.
|
||||
func (tm *TreeManager) Close() error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
if tm.file != nil {
|
||||
// Flush buffered data before closing.
|
||||
if tm.writer != nil {
|
||||
_ = tm.writer.Flush()
|
||||
tm.writer = nil
|
||||
}
|
||||
err := tm.file.Close()
|
||||
tm.file = nil
|
||||
return err
|
||||
@@ -955,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
// If there's a compaction, we collect IDs in the same order as
|
||||
// BuildContext: [summary placeholder, kept messages, post-compaction
|
||||
// messages]. This ordering must stay in sync with BuildContext so a
|
||||
// cut-point index can be mapped back to the correct entry ID.
|
||||
if lastCompaction != nil {
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
|
||||
// Iterate tm.entries in append order and stop at the compaction
|
||||
// entry to avoid double-counting post-compaction messages.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
@@ -1001,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
@@ -1025,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: IDs of entries after the compaction entry on the current
|
||||
// branch (post-compaction messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
@@ -1074,13 +1152,22 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
// All entries are buffered and flushed to disk in a single batch.
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
entry, err := NewMessageEntry(tm.leafID, message.FromLLMMessage(msg))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return err
|
||||
}
|
||||
tm.leafID = entry.ID
|
||||
}
|
||||
return nil
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
@@ -1132,12 +1219,20 @@ func (tm *TreeManager) appendAndPersist(entry any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeEntry serializes an entry and appends it as a line to the file.
|
||||
// writeEntry serializes an entry and appends it to the buffered writer.
|
||||
// The data is not flushed to disk until flushLocked is called.
|
||||
func (tm *TreeManager) writeEntry(entry any) error {
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal entry: %w", err)
|
||||
}
|
||||
if tm.writer != nil {
|
||||
if _, err := tm.writer.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return tm.writer.WriteByte('\n')
|
||||
}
|
||||
// Fallback for direct file writes (shouldn't happen in normal flow).
|
||||
data = append(data, '\n')
|
||||
_, err = tm.file.Write(data)
|
||||
return err
|
||||
@@ -1213,12 +1308,32 @@ func (tm *TreeManager) getBranchLocked(fromID string) []any {
|
||||
}
|
||||
|
||||
// buildTreeNode recursively builds a TreeNode from an entry ID.
|
||||
// It includes a depth limit to prevent infinite recursion in case of
|
||||
// corrupted parent-child relationships.
|
||||
func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
|
||||
return tm.buildTreeNodeDepth(id, 0, make(map[string]bool))
|
||||
}
|
||||
|
||||
// buildTreeNodeDepth is the internal implementation with depth tracking.
|
||||
func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[string]bool) *TreeNode {
|
||||
const maxDepth = 1000
|
||||
if depth > maxDepth {
|
||||
// Cycle or extremely deep tree detected, stop recursing
|
||||
return nil
|
||||
}
|
||||
if visited[id] {
|
||||
// Cycle detected, stop recursing
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, ok := tm.index[id]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
visited[id] = true
|
||||
defer delete(visited, id)
|
||||
|
||||
node := &TreeNode{
|
||||
Entry: entry,
|
||||
ID: id,
|
||||
@@ -1226,7 +1341,7 @@ func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
|
||||
}
|
||||
|
||||
for _, childID := range tm.childIndex[id] {
|
||||
child := tm.buildTreeNode(childID)
|
||||
child := tm.buildTreeNodeDepth(childID, depth+1, visited)
|
||||
if child != nil {
|
||||
node.Children = append(node.Children, child)
|
||||
}
|
||||
@@ -1238,15 +1353,44 @@ func (tm *TreeManager) buildTreeNode(id string) *TreeNode {
|
||||
// --- Path conventions ---
|
||||
|
||||
// DefaultSessionDir returns the default session storage directory for a cwd.
|
||||
// Convention: ~/.kit/sessions/--<cwd-path>--/
|
||||
// Convention: ~/.kit/sessions/<encoded-cwd>, where path separators are
|
||||
// encoded as "--" with no leading or trailing dashes — e.g.
|
||||
// /home/user/proj becomes home--user--proj. See encodeCwdForDir for the
|
||||
// full encoding rules (including Windows path handling).
|
||||
func DefaultSessionDir(cwd string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
// Convert path separators to double dashes.
|
||||
safeCwd := strings.ReplaceAll(cwd, string(filepath.Separator), "--")
|
||||
return filepath.Join(home, ".kit", "sessions", encodeCwdForDir(cwd))
|
||||
}
|
||||
|
||||
// encodeCwdForDir converts a working-directory path into a single, filesystem-
|
||||
// safe directory name. Path separators are replaced with double dashes and
|
||||
// characters that are illegal in Windows directory names — most importantly
|
||||
// the colon that follows the drive letter (e.g. `C:\foo` → `C--foo`) — are
|
||||
// stripped. The result is identical to the previous Unix-only encoding for
|
||||
// paths that do not contain such characters, so existing session directories
|
||||
// are preserved.
|
||||
func encodeCwdForDir(cwd string) string {
|
||||
// Convert both `/` and `\` to double dashes so encoding is stable across
|
||||
// platforms and remains correct on Windows where `filepath.Separator`
|
||||
// would otherwise miss forward-slash style paths.
|
||||
safeCwd := strings.ReplaceAll(cwd, "\\", "--")
|
||||
safeCwd = strings.ReplaceAll(safeCwd, "/", "--")
|
||||
// Remove leading separator replacement.
|
||||
safeCwd = strings.TrimPrefix(safeCwd, "--")
|
||||
return filepath.Join(home, ".kit", "sessions", safeCwd)
|
||||
// Strip characters that are illegal in directory names on Windows
|
||||
// (`< > : " | ? *`). On Unix these characters are legal but rare in
|
||||
// practice; stripping them keeps the encoding portable.
|
||||
replacer := strings.NewReplacer(
|
||||
":", "",
|
||||
"<", "",
|
||||
">", "",
|
||||
"\"", "",
|
||||
"|", "",
|
||||
"?", "",
|
||||
"*", "",
|
||||
)
|
||||
return replacer.Replace(safeCwd)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// ValidateParentChain checks that the parent ID points to an existing entry
|
||||
// and that appending this entry would not create a cycle. This should be called
|
||||
// before appending any entry to the tree.
|
||||
// Returns an error if the parent is invalid or would create a cycle.
|
||||
func (tm *TreeManager) ValidateParentChain(parentID string, newEntryID string) error {
|
||||
if parentID == "" {
|
||||
// Empty parent is valid (root entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check that parent exists
|
||||
if _, ok := tm.index[parentID]; !ok {
|
||||
return fmt.Errorf("parent entry %q does not exist in index", parentID)
|
||||
}
|
||||
|
||||
// Check that we're not creating a cycle by walking up the parent chain
|
||||
// from parentID and ensuring we don't hit newEntryID (or any node that
|
||||
// has newEntryID as an ancestor, but since newEntryID is new, just check
|
||||
// that parentID isn't newEntryID, which it can't be since we check existence)
|
||||
visited := make(map[string]bool)
|
||||
current := parentID
|
||||
for current != "" {
|
||||
if visited[current] {
|
||||
return fmt.Errorf("existing cycle detected at entry %q", current)
|
||||
}
|
||||
visited[current] = true
|
||||
|
||||
// Safety check: if somehow we reach the new entry ID, that's a cycle
|
||||
if current == newEntryID {
|
||||
return fmt.Errorf("would create cycle: entry %q cannot be its own ancestor", newEntryID)
|
||||
}
|
||||
|
||||
entry, ok := tm.index[current]
|
||||
if !ok {
|
||||
return fmt.Errorf("broken parent chain: entry %q not found", current)
|
||||
}
|
||||
current = tm.entryParentID(entry)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DetectCycle walks the parent chain from the given entry ID and returns true
|
||||
// if a cycle is detected. This is used for diagnostics.
|
||||
func (tm *TreeManager) DetectCycle(fromID string) (cycleDetected bool, cycleEntry string) {
|
||||
visited := make(map[string]bool)
|
||||
current := fromID
|
||||
for current != "" {
|
||||
if visited[current] {
|
||||
return true, current
|
||||
}
|
||||
visited[current] = true
|
||||
entry, ok := tm.index[current]
|
||||
if !ok {
|
||||
return false, ""
|
||||
}
|
||||
current = tm.entryParentID(entry)
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// LogTreeDiagnostics logs information about the tree structure for debugging.
|
||||
// Call this after OpenTreeSession or when anomalies are detected.
|
||||
func (tm *TreeManager) LogTreeDiagnostics() {
|
||||
tm.mu.RLock()
|
||||
defer tm.mu.RUnlock()
|
||||
|
||||
log.Printf("[TreeManager] Entry count: %d, Leaf ID: %s", len(tm.entries), tm.leafID)
|
||||
|
||||
// Check for cycles from leaf
|
||||
if tm.leafID != "" {
|
||||
if cycle, entry := tm.detectCycleLocked(tm.leafID); cycle {
|
||||
log.Printf("[TreeManager] WARNING: Cycle detected in tree at entry %s", entry)
|
||||
}
|
||||
}
|
||||
|
||||
// Count entries by type
|
||||
counts := make(map[EntryType]int)
|
||||
for _, entry := range tm.entries {
|
||||
var et EntryType
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
et = e.Type
|
||||
case *ModelChangeEntry:
|
||||
et = e.Type
|
||||
case *BranchSummaryEntry:
|
||||
et = e.Type
|
||||
case *LabelEntry:
|
||||
et = e.Type
|
||||
case *SessionInfoEntry:
|
||||
et = e.Type
|
||||
case *ExtensionDataEntry:
|
||||
et = e.Type
|
||||
case *CompactionEntry:
|
||||
et = e.Type
|
||||
default:
|
||||
et = "unknown"
|
||||
}
|
||||
counts[et]++
|
||||
}
|
||||
log.Printf("[TreeManager] Entry types: %+v", counts)
|
||||
}
|
||||
|
||||
// detectCycleLocked is the internal version of DetectCycle (must hold read lock)
|
||||
func (tm *TreeManager) detectCycleLocked(fromID string) (bool, string) {
|
||||
visited := make(map[string]bool)
|
||||
current := fromID
|
||||
for current != "" {
|
||||
if visited[current] {
|
||||
return true, current
|
||||
}
|
||||
visited[current] = true
|
||||
entry, ok := tm.index[current]
|
||||
if !ok {
|
||||
return false, ""
|
||||
}
|
||||
current = tm.entryParentID(entry)
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// validateParentChainLocked is the internal version used by append methods.
|
||||
// Must be called with the write lock held.
|
||||
func (tm *TreeManager) validateParentChainLocked(parentID string, newEntryID string) error {
|
||||
if parentID == "" {
|
||||
return nil
|
||||
}
|
||||
if _, ok := tm.index[parentID]; !ok {
|
||||
return fmt.Errorf("parent entry %q does not exist", parentID)
|
||||
}
|
||||
// Check for existing cycles in the parent chain
|
||||
if cycle, entry := tm.detectCycleLocked(parentID); cycle {
|
||||
return fmt.Errorf("existing cycle detected at entry %q in parent chain", entry)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
|
||||
@@ -47,6 +47,7 @@ type MCPConnection struct {
|
||||
client client.MCPClient
|
||||
serverName string
|
||||
serverConfig config.MCPServerConfig
|
||||
initResult *mcp.InitializeResult // captured at handshake; nil before initialize
|
||||
lastUsed time.Time
|
||||
isHealthy bool
|
||||
errorCount int
|
||||
@@ -63,7 +64,6 @@ type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
@@ -75,9 +75,8 @@ type MCPConnectionPool struct {
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
// If config is nil, default configuration values will be used. The pool starts a background
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
@@ -86,7 +85,6 @@ func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageMo
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
@@ -246,10 +244,12 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
oauthEnabled := p.oauthFlow != nil && !serverConfig.NoOAuth
|
||||
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if oauthEnabled && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
@@ -263,15 +263,17 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
conn := &MCPConnection{}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if oauthEnabled && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
@@ -281,15 +283,11 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
isHealthy: true,
|
||||
errorCount: 0,
|
||||
lastError: nil,
|
||||
}
|
||||
conn.client = mcpClient
|
||||
conn.serverName = serverName
|
||||
conn.serverConfig = serverConfig
|
||||
conn.lastUsed = time.Now()
|
||||
conn.isHealthy = true
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Created connection for %s", serverName))
|
||||
@@ -308,6 +306,8 @@ func (p *MCPConnectionPool) createMCPClient(ctx context.Context, serverName stri
|
||||
return p.createSSEClient(ctx, serverConfig)
|
||||
case "streamable":
|
||||
return p.createStreamableClient(ctx, serverConfig)
|
||||
case "inprocess":
|
||||
return p.createInProcessClient(serverConfig)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
|
||||
}
|
||||
@@ -345,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig
|
||||
return stdioClient, nil
|
||||
}
|
||||
|
||||
// createSSEClient creates an SSE client
|
||||
// parseHeaders parses "Key: Value" header strings into a map.
|
||||
func parseHeaders(raw []string) map[string]string {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
headers := make(map[string]string)
|
||||
for _, header := range raw {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// buildOAuthConfig constructs a transport.OAuthConfig from the server config
|
||||
// and the pool's OAuth flow. Returns nil if OAuth is not applicable.
|
||||
func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) {
|
||||
if p.oauthFlow == nil || serverConfig.NoOAuth {
|
||||
return nil, nil
|
||||
}
|
||||
tokenStore, err := p.createTokenStore(serverConfig.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", err)
|
||||
}
|
||||
cfg := &transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
cfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
cfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
cfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.ClientOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
|
||||
// PubMed) set NoOAuth to skip dynamic client registration and token
|
||||
// exchange, which would otherwise fail with a 404.
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
@@ -406,45 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.StreamableHTTPCOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth.
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithHTTPOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
@@ -459,6 +453,22 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createInProcessClient creates an in-process MCP client that communicates
|
||||
// directly with an *server.MCPServer in the same process. No subprocess is
|
||||
// spawned and no network I/O occurs — calls go through JSON marshal →
|
||||
// MCPServer.HandleMessage → JSON unmarshal, all in-memory.
|
||||
func (p *MCPConnectionPool) createInProcessClient(serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
srv, ok := serverConfig.InProcessServer.(*server.MCPServer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("InProcessServer must be *server.MCPServer, got %T", serverConfig.InProcessServer)
|
||||
}
|
||||
inProcessClient, err := client.NewInProcessClient(srv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create in-process client: %w", err)
|
||||
}
|
||||
return inProcessClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
@@ -469,8 +479,10 @@ func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenS
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
// initializeClient initializes the client and captures the server's
|
||||
// initialize result on the supplied connection so callers can later
|
||||
// inspect advertised capabilities (e.g. task support).
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, c client.MCPClient, conn *MCPConnection) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -480,12 +492,21 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
Name: "kit",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
|
||||
// Advertise task support so servers may return CreateTaskResult for
|
||||
// long-running tools/call requests instead of blocking the connection
|
||||
// until completion. The client is responsible for polling tasks/get and
|
||||
// tasks/result until the task reaches a terminal state.
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{
|
||||
Tasks: mcp.NewTasksCapability(),
|
||||
}
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
initResult, err := c.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
if conn != nil {
|
||||
conn.initResult = initResult
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug("[POOL] Initialized MCP client")
|
||||
@@ -600,6 +621,54 @@ func (c *MCPConnection) ServerName() string {
|
||||
return c.serverName
|
||||
}
|
||||
|
||||
// InitializeResult returns the result captured from the server's initialize
|
||||
// response, or nil if the connection was created before initialize completed.
|
||||
// Callers can inspect ServerCapabilities.Tasks to discover task-related
|
||||
// capability advertisements.
|
||||
func (c *MCPConnection) InitializeResult() *mcp.InitializeResult {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.initResult
|
||||
}
|
||||
|
||||
// SupportsToolTasks reports whether the server advertised support for
|
||||
// task-augmented tools/call requests. Returns false when the connection has
|
||||
// not yet completed initialization or when the server omitted task
|
||||
// capabilities.
|
||||
func (c *MCPConnection) SupportsToolTasks() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return supportsToolTasksFromInit(c.initResult)
|
||||
}
|
||||
|
||||
// supportsToolTasksFromInit reports whether the supplied InitializeResult
|
||||
// advertises task-augmented tools/call support. Extracted to a free function
|
||||
// for unit testing without standing up a connection.
|
||||
func supportsToolTasksFromInit(init *mcp.InitializeResult) bool {
|
||||
if init == nil || init.Capabilities.Tasks == nil {
|
||||
return false
|
||||
}
|
||||
req := init.Capabilities.Tasks.Requests
|
||||
if req == nil || req.Tools == nil {
|
||||
return false
|
||||
}
|
||||
return req.Tools.Call != nil
|
||||
}
|
||||
|
||||
// ServerSupportsToolTasks reports whether the named server's connection
|
||||
// advertises task-augmented tools/call support. Returns false when no
|
||||
// connection exists for the server or when the server didn't advertise the
|
||||
// capability.
|
||||
func (p *MCPConnectionPool) ServerSupportsToolTasks(serverName string) bool {
|
||||
p.mu.RLock()
|
||||
conn, ok := p.connections[serverName]
|
||||
p.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return conn.SupportsToolTasks()
|
||||
}
|
||||
|
||||
// GetClients returns a map of all MCP clients currently in the pool.
|
||||
// The map keys are server names and values are the corresponding MCP client instances.
|
||||
// The returned map is a copy and modifications won't affect the pool.
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// mcpFantasyTool adapts an MCP tool to the fantasy.AgentTool interface.
|
||||
// It bridges the MCP tool protocol with fantasy's agent tool system, handling
|
||||
// name prefixing, schema conversion, connection pooling, and result marshaling.
|
||||
type mcpFantasyTool struct {
|
||||
toolInfo fantasy.ToolInfo
|
||||
mapping *toolMapping
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpFantasyTool) Info() fantasy.ToolInfo {
|
||||
return t.toolInfo
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by routing through the connection pool.
|
||||
// It maps the prefixed tool name back to the original name, retrieves a healthy
|
||||
// connection, invokes the tool, and converts the MCP result to a fantasy ToolResponse.
|
||||
func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
// Parse and validate JSON arguments
|
||||
var arguments any
|
||||
input := call.Input
|
||||
if input == "" || input == "{}" {
|
||||
arguments = nil
|
||||
} else {
|
||||
var temp any
|
||||
if err := json.Unmarshal([]byte(input), &temp); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid JSON arguments: %v", err)), nil
|
||||
}
|
||||
arguments = json.RawMessage(input)
|
||||
}
|
||||
|
||||
// Get connection from pool with health check
|
||||
conn, err := t.mapping.manager.connectionPool.GetConnectionWithHealthCheck(
|
||||
ctx, t.mapping.serverName, t.mapping.serverConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
||||
}
|
||||
|
||||
// Call the MCP tool using the original (unprefixed) name
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if t.mapping.manager.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := t.mapping.manager.connectionPool.oauthFlow.RunAuthFlow(ctx, t.mapping.serverName, err); flowErr != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", t.mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
marshaledResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
|
||||
// Return as text response, preserving error status from MCP
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(string(marshaledResult)), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(marshaledResult)), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTestInProcessServer creates a simple MCP server with one tool for testing.
|
||||
func newTestInProcessServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("test-server", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description("Name to greet")),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("Hello, " + name + "!"), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestInProcessTransportType(t *testing.T) {
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessTransportTypeInferred(t *testing.T) {
|
||||
// When Type is empty but InProcessServer is set, infer "inprocess".
|
||||
cfg := config.MCPServerConfig{
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessValidation(t *testing.T) {
|
||||
// Valid: InProcessServer is set.
|
||||
validCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := validCfg.Validate(); err != nil {
|
||||
t.Errorf("expected valid config, got error: %v", err)
|
||||
}
|
||||
|
||||
// Invalid: type is inprocess but InProcessServer is nil.
|
||||
invalidCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := invalidCfg.Validate(); err == nil {
|
||||
t.Error("expected validation error for nil InProcessServer, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessClient(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the connection is healthy and functional.
|
||||
if !conn.isHealthy {
|
||||
t.Error("expected connection to be healthy")
|
||||
}
|
||||
|
||||
// List tools to verify the connection works end-to-end.
|
||||
toolsResp, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools failed: %v", err)
|
||||
}
|
||||
if len(toolsResp.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(toolsResp.Tools))
|
||||
}
|
||||
if toolsResp.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", toolsResp.Tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessToolExecution(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the tool.
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: "greet",
|
||||
Arguments: map[string]any{"name": "World"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected at least one content block")
|
||||
}
|
||||
text, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("expected TextContent, got %T", result.Content[0])
|
||||
}
|
||||
if text.Text != "Hello, World!" {
|
||||
t.Errorf("expected 'Hello, World!', got %q", text.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolManagerInProcess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
mgr := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
count, err := mgr.AddServer(ctx, "myserver", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", count)
|
||||
}
|
||||
|
||||
tools := mgr.GetTools()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Name != "myserver__greet" {
|
||||
t.Errorf("expected tool name 'myserver__greet', got %q", tools[0].Name)
|
||||
}
|
||||
|
||||
// Execute the tool.
|
||||
input, _ := json.Marshal(map[string]any{"name": "SDK"})
|
||||
result, err := mgr.ExecuteTool(ctx, "myserver__greet", string(input))
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if result.Content == "" {
|
||||
t.Error("expected non-empty result content")
|
||||
}
|
||||
|
||||
// Verify result contains our greeting.
|
||||
if !strings.Contains(result.Content, "Hello, SDK!") {
|
||||
t.Errorf("expected 'Hello, SDK!' in result, got %q", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessInvalidServer(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pass a non-*server.MCPServer value.
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: "not a server",
|
||||
}
|
||||
|
||||
_, err := pool.GetConnection(ctx, "bad", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid InProcessServer type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessReuse(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
// Get connection twice — should reuse.
|
||||
conn1, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("first GetConnection failed: %v", err)
|
||||
}
|
||||
conn2, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("second GetConnection failed: %v", err)
|
||||
}
|
||||
if conn1 != conn2 {
|
||||
t.Error("expected same connection object on reuse")
|
||||
}
|
||||
}
|
||||
+882
-46
File diff suppressed because it is too large
Load Diff
@@ -101,7 +101,7 @@ func TestMCPToolManager_AddServer_Integration(t *testing.T) {
|
||||
// Verify tool names are prefixed.
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
toolNames[tool.Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo'")
|
||||
@@ -234,8 +234,8 @@ func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
|
||||
|
||||
// Remaining tools should all be from server-b.
|
||||
for _, tool := range tools {
|
||||
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
|
||||
if !strings.HasPrefix(tool.Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Name)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ func TestMCPToolManager_Close_NilPool(t *testing.T) {
|
||||
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
|
||||
// non-existent connection returns an error.
|
||||
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
err := pool.RemoveConnection("nonexistent")
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mcpclient "github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTestPromptServer creates an in-process MCP server with prompt capabilities
|
||||
// and the specified prompts + handlers. Returns an initialized MCPClient.
|
||||
func newTestPromptServer(t *testing.T, prompts ...server.ServerPrompt) mcpclient.MCPClient {
|
||||
t.Helper()
|
||||
|
||||
mcpServer := server.NewMCPServer(
|
||||
"test-prompt-server", "1.0.0",
|
||||
server.WithPromptCapabilities(true),
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
if len(prompts) > 0 {
|
||||
mcpServer.AddPrompts(prompts...)
|
||||
}
|
||||
|
||||
// Add a dummy tool so loadServerTools has something to list.
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool("noop", mcp.WithDescription("no-op tool")),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
)
|
||||
|
||||
client, err := mcpclient.NewInProcessClient(mcpServer)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInProcessClient: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := client.Start(ctx); err != nil {
|
||||
t.Fatalf("client.Start: %v", err)
|
||||
}
|
||||
|
||||
initReq := mcp.InitializeRequest{}
|
||||
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
|
||||
if _, err := client.Initialize(ctx, initReq); err != nil {
|
||||
t.Fatalf("client.Initialize: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
return client
|
||||
}
|
||||
|
||||
// injectClientIntoManager sets up an MCPToolManager with a pre-connected
|
||||
// in-process client, bypassing the normal connection pool flow.
|
||||
func injectClientIntoManager(t *testing.T, serverName string, client mcpclient.MCPClient) *MCPToolManager {
|
||||
t.Helper()
|
||||
|
||||
m := NewMCPToolManager()
|
||||
|
||||
// Create a minimal connection pool and inject our client.
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
pool.mu.Lock()
|
||||
pool.connections[serverName] = &MCPConnection{
|
||||
client: client,
|
||||
serverName: serverName,
|
||||
isHealthy: true,
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
m.connectionPool = pool
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func TestLoadServerPrompts_Basic(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("review-pr",
|
||||
mcp.WithPromptDescription("Review a pull request"),
|
||||
mcp.WithArgument("pr_number",
|
||||
mcp.ArgumentDescription("The PR number to review"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
mcp.WithArgument("focus",
|
||||
mcp.ArgumentDescription("Area to focus on"),
|
||||
),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
prNum := req.Params.Arguments["pr_number"]
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "PR review prompt",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Please review PR #%s", prNum),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("explain-code",
|
||||
mcp.WithPromptDescription("Explain a piece of code"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: "Please explain the following code.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "github", client)
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
serverName: "github",
|
||||
isHealthy: true,
|
||||
}
|
||||
m.loadServerPrompts(ctx, "github", conn)
|
||||
|
||||
prompts := m.GetPrompts()
|
||||
if len(prompts) != 2 {
|
||||
t.Fatalf("expected 2 prompts, got %d", len(prompts))
|
||||
}
|
||||
|
||||
// Find review-pr prompt.
|
||||
var reviewPR *MCPPrompt
|
||||
for i := range prompts {
|
||||
if prompts[i].Name == "review-pr" {
|
||||
reviewPR = &prompts[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if reviewPR == nil {
|
||||
t.Fatal("review-pr prompt not found")
|
||||
}
|
||||
if reviewPR.Description != "Review a pull request" {
|
||||
t.Errorf("unexpected description: %q", reviewPR.Description)
|
||||
}
|
||||
if reviewPR.ServerName != "github" {
|
||||
t.Errorf("unexpected server name: %q", reviewPR.ServerName)
|
||||
}
|
||||
if len(reviewPR.Arguments) != 2 {
|
||||
t.Fatalf("expected 2 arguments, got %d", len(reviewPR.Arguments))
|
||||
}
|
||||
|
||||
// Verify argument metadata.
|
||||
arg0 := reviewPR.Arguments[0]
|
||||
if arg0.Name != "pr_number" {
|
||||
t.Errorf("expected first arg name 'pr_number', got %q", arg0.Name)
|
||||
}
|
||||
if !arg0.Required {
|
||||
t.Error("expected first arg to be required")
|
||||
}
|
||||
arg1 := reviewPR.Arguments[1]
|
||||
if arg1.Name != "focus" {
|
||||
t.Errorf("expected second arg name 'focus', got %q", arg1.Name)
|
||||
}
|
||||
if arg1.Required {
|
||||
t.Error("expected second arg to be optional")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_ExpandsWithArgs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("greet",
|
||||
mcp.WithPromptDescription("Greet someone"),
|
||||
mcp.WithArgument("name", mcp.RequiredArgument()),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
name := req.Params.Arguments["name"]
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Greeting",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Hello, %s!", name),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "myserver", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "myserver", "greet", map[string]string{"name": "World"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if result.Description != "Greeting" {
|
||||
t.Errorf("unexpected description: %q", result.Description)
|
||||
}
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Role != "user" {
|
||||
t.Errorf("unexpected role: %q", result.Messages[0].Role)
|
||||
}
|
||||
if result.Messages[0].Content != "Hello, World!" {
|
||||
t.Errorf("unexpected content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_MultipleMessages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("chat-starter"),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "What is Go?"},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Go is a programming language."},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Tell me more."},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "server", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "server", "chat-starter", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Role != "user" {
|
||||
t.Errorf("msg[0] role: got %q, want 'user'", result.Messages[0].Role)
|
||||
}
|
||||
if result.Messages[1].Role != "assistant" {
|
||||
t.Errorf("msg[1] role: got %q, want 'assistant'", result.Messages[1].Role)
|
||||
}
|
||||
if result.Messages[2].Content != "Tell me more." {
|
||||
t.Errorf("msg[2] content: got %q, want 'Tell me more.'", result.Messages[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_ServerNotFound(t *testing.T) {
|
||||
m := NewMCPToolManager()
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
m.connectionPool = pool
|
||||
|
||||
_, err := m.GetPrompt(context.Background(), "nonexistent", "foo", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_NoPool(t *testing.T) {
|
||||
m := NewMCPToolManager()
|
||||
|
||||
_, err := m.GetPrompt(context.Background(), "any", "foo", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error with no pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveServer_RemovesPrompts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("my-prompt",
|
||||
mcp.WithPromptDescription("A test prompt"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "hi"}},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "testsvr", client)
|
||||
|
||||
// Manually populate tools and prompts as loadServerTools would.
|
||||
conn := m.connectionPool.connections["testsvr"]
|
||||
m.loadServerPrompts(ctx, "testsvr", conn)
|
||||
|
||||
// Also add a fake tool mapping so RemoveServer finds the server.
|
||||
m.toolMap["testsvr__noop"] = &toolMapping{
|
||||
serverName: "testsvr",
|
||||
originalName: "noop",
|
||||
}
|
||||
m.tools = append(m.tools, MCPTool{
|
||||
Name: "testsvr__noop",
|
||||
ServerName: "testsvr",
|
||||
})
|
||||
|
||||
// Verify prompts exist before removal.
|
||||
if got := len(m.GetPrompts()); got != 1 {
|
||||
t.Fatalf("expected 1 prompt before removal, got %d", got)
|
||||
}
|
||||
|
||||
// Remove the server.
|
||||
err := m.RemoveServer("testsvr")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer error: %v", err)
|
||||
}
|
||||
|
||||
// Verify prompts are gone.
|
||||
if got := len(m.GetPrompts()); got != 0 {
|
||||
t.Fatalf("expected 0 prompts after removal, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadServerPrompts_NoPromptCapability(t *testing.T) {
|
||||
// Server without prompt capabilities — ListPrompts should fail gracefully.
|
||||
mcpServer := server.NewMCPServer("no-prompts", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
// No WithPromptCapabilities
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool("noop"),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
)
|
||||
|
||||
client, err := mcpclient.NewInProcessClient(mcpServer)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInProcessClient: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
_ = client.Start(ctx)
|
||||
initReq := mcp.InitializeRequest{}
|
||||
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
|
||||
_, _ = client.Initialize(ctx, initReq)
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
m := NewMCPToolManager()
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
serverName: "no-prompts",
|
||||
isHealthy: true,
|
||||
}
|
||||
|
||||
// Should not panic or error — just silently skip.
|
||||
m.loadServerPrompts(ctx, "no-prompts", conn)
|
||||
|
||||
if got := len(m.GetPrompts()); got != 0 {
|
||||
t.Fatalf("expected 0 prompts from server without prompt capability, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPromptContent(t *testing.T) {
|
||||
t.Run("TextContent", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.TextContent{Type: "text", Text: "hello world"})
|
||||
if text != "hello world" {
|
||||
t.Errorf("text = %q, want %q", text, "hello world")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImageContent", func(t *testing.T) {
|
||||
// base64 of "fake image"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("fake image"))
|
||||
text, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: encoded,
|
||||
MIMEType: "image/png",
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "image/png" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "image/png")
|
||||
}
|
||||
if parts[0].Filename != "image.png" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "image.png")
|
||||
}
|
||||
if string(parts[0].Data) != "fake image" {
|
||||
t.Errorf("data = %q, want %q", string(parts[0].Data), "fake image")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImageContent_DefaultMIME", func(t *testing.T) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("img"))
|
||||
_, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: encoded,
|
||||
// no MIMEType → should default to image/png
|
||||
})
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "image/png" {
|
||||
t.Errorf("default MIME = %q, want %q", parts[0].MediaType, "image/png")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AudioContent", func(t *testing.T) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("fake audio"))
|
||||
text, parts := extractPromptContent(mcp.AudioContent{
|
||||
Type: "audio",
|
||||
Data: encoded,
|
||||
MIMEType: "audio/mp3",
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "audio/mp3" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "audio/mp3")
|
||||
}
|
||||
if parts[0].Filename != "audio.wav" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "audio.wav")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmbeddedResource_Text", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.TextResourceContents{
|
||||
URI: "file:///project/main.go",
|
||||
MIMEType: "text/x-go",
|
||||
Text: "package main",
|
||||
},
|
||||
})
|
||||
if text == "" {
|
||||
t.Fatal("expected non-empty text for text resource")
|
||||
}
|
||||
if !strings.Contains(text, "package main") {
|
||||
t.Errorf("text should contain resource content, got %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "file:///project/main.go") {
|
||||
t.Errorf("text should contain URI, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for text resource, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmbeddedResource_Blob", func(t *testing.T) {
|
||||
blobData := []byte("binary content")
|
||||
encoded := base64.StdEncoding.EncodeToString(blobData)
|
||||
text, parts := extractPromptContent(mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.BlobResourceContents{
|
||||
URI: "file:///project/data.bin",
|
||||
MIMEType: "application/octet-stream",
|
||||
Blob: encoded,
|
||||
},
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text for blob resource, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part for blob resource, got %d", len(parts))
|
||||
}
|
||||
if parts[0].Filename != "data.bin" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "data.bin")
|
||||
}
|
||||
if parts[0].MediaType != "application/octet-stream" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "application/octet-stream")
|
||||
}
|
||||
if string(parts[0].Data) != "binary content" {
|
||||
t.Errorf("data = %q, want %q", string(parts[0].Data), "binary content")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ResourceLink", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.ResourceLink{
|
||||
Type: "resource_link",
|
||||
URI: "file:///docs/readme.md",
|
||||
Name: "readme.md",
|
||||
})
|
||||
if text == "" {
|
||||
t.Fatal("expected non-empty text for resource link")
|
||||
}
|
||||
if !strings.Contains(text, "file:///docs/readme.md") {
|
||||
t.Errorf("text should contain URI, got %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "readme.md") {
|
||||
t.Errorf("text should contain name, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for resource link, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidBase64", func(t *testing.T) {
|
||||
_, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: "not-valid-base64!!!",
|
||||
MIMEType: "image/png",
|
||||
})
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for invalid base64, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NilContent", func(t *testing.T) {
|
||||
text, parts := extractPromptContent((*mcp.TextContent)(nil))
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text for nil, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 parts for nil, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilenameFromURI(t *testing.T) {
|
||||
tests := []struct {
|
||||
uri string
|
||||
want string
|
||||
}{
|
||||
{"file:///path/to/image.png", "image.png"},
|
||||
{"file:///single.txt", "single.txt"},
|
||||
{"resource://server/data.json", "data.json"},
|
||||
{"nopath", "nopath"},
|
||||
{"", "resource"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.uri, func(t *testing.T) {
|
||||
got := filenameFromURI(tt.uri)
|
||||
if got != tt.want {
|
||||
t.Errorf("filenameFromURI(%q) = %q, want %q", tt.uri, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_EmbeddedResources(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
imgData := base64.StdEncoding.EncodeToString([]byte("fake-png"))
|
||||
blobData := base64.StdEncoding.EncodeToString([]byte("binary-blob"))
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("review-with-files",
|
||||
mcp.WithPromptDescription("Review with embedded resources"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Review prompt with embedded files",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Please review these files:"},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.TextResourceContents{
|
||||
URI: "file:///src/main.go",
|
||||
MIMEType: "text/x-go",
|
||||
Text: "package main\n\nfunc main() {}",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: imgData,
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.BlobResourceContents{
|
||||
URI: "file:///data/model.bin",
|
||||
MIMEType: "application/octet-stream",
|
||||
Blob: blobData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "test", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "test", "review-with-files", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if result.Description != "Review prompt with embedded files" {
|
||||
t.Errorf("unexpected description: %q", result.Description)
|
||||
}
|
||||
|
||||
// Should have 4 messages: text, embedded text resource, image, embedded blob
|
||||
if len(result.Messages) != 4 {
|
||||
t.Fatalf("expected 4 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
// Message 0: plain text
|
||||
msg0 := result.Messages[0]
|
||||
if msg0.Content != "Please review these files:" {
|
||||
t.Errorf("msg[0] content = %q", msg0.Content)
|
||||
}
|
||||
if len(msg0.FileParts) != 0 {
|
||||
t.Errorf("msg[0] expected 0 file parts, got %d", len(msg0.FileParts))
|
||||
}
|
||||
|
||||
// Message 1: embedded text resource → inlined as text
|
||||
msg1 := result.Messages[1]
|
||||
if !strings.Contains(msg1.Content, "package main") {
|
||||
t.Errorf("msg[1] should contain resource text, got %q", msg1.Content)
|
||||
}
|
||||
if len(msg1.FileParts) != 0 {
|
||||
t.Errorf("msg[1] expected 0 file parts (text resource), got %d", len(msg1.FileParts))
|
||||
}
|
||||
|
||||
// Message 2: image → file part
|
||||
msg2 := result.Messages[2]
|
||||
if msg2.Content != "" {
|
||||
t.Errorf("msg[2] expected empty text for image, got %q", msg2.Content)
|
||||
}
|
||||
if len(msg2.FileParts) != 1 {
|
||||
t.Fatalf("msg[2] expected 1 file part, got %d", len(msg2.FileParts))
|
||||
}
|
||||
if msg2.FileParts[0].MediaType != "image/png" {
|
||||
t.Errorf("msg[2] file part MIME = %q", msg2.FileParts[0].MediaType)
|
||||
}
|
||||
if string(msg2.FileParts[0].Data) != "fake-png" {
|
||||
t.Errorf("msg[2] file part data = %q", string(msg2.FileParts[0].Data))
|
||||
}
|
||||
|
||||
// Message 3: embedded blob resource → file part
|
||||
msg3 := result.Messages[3]
|
||||
if msg3.Content != "" {
|
||||
t.Errorf("msg[3] expected empty text for blob resource, got %q", msg3.Content)
|
||||
}
|
||||
if len(msg3.FileParts) != 1 {
|
||||
t.Fatalf("msg[3] expected 1 file part, got %d", len(msg3.FileParts))
|
||||
}
|
||||
if msg3.FileParts[0].Filename != "model.bin" {
|
||||
t.Errorf("msg[3] filename = %q, want %q", msg3.FileParts[0].Filename, "model.bin")
|
||||
}
|
||||
if string(msg3.FileParts[0].Data) != "binary-blob" {
|
||||
t.Errorf("msg[3] file part data = %q", string(msg3.FileParts[0].Data))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,404 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// MCPTaskMode controls when the connection pool augments tools/call requests
|
||||
// with MCP task metadata. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks.
|
||||
type MCPTaskMode string
|
||||
|
||||
const (
|
||||
// MCPTaskModeAuto augments tools/call with task metadata only when the
|
||||
// server advertises tasks/toolCalls capability during initialize.
|
||||
MCPTaskModeAuto MCPTaskMode = "auto"
|
||||
// MCPTaskModeNever forces every tools/call to be issued synchronously
|
||||
// (no Task field in the request), regardless of server capability.
|
||||
MCPTaskModeNever MCPTaskMode = "never"
|
||||
// MCPTaskModeAlways always sets a Task field on the tools/call request,
|
||||
// even when the server didn't advertise task support. The server may
|
||||
// still respond synchronously; this just opts in unconditionally on
|
||||
// the client side.
|
||||
MCPTaskModeAlways MCPTaskMode = "always"
|
||||
)
|
||||
|
||||
// ParseTaskMode normalises a per-server tasks-mode string from
|
||||
// configuration. Empty input maps to MCPTaskModeAuto. Unknown values are
|
||||
// also treated as MCPTaskModeAuto so a stray config typo never breaks
|
||||
// existing flows.
|
||||
func ParseTaskMode(s string) MCPTaskMode {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "auto":
|
||||
return MCPTaskModeAuto
|
||||
case "never", "off", "disabled":
|
||||
return MCPTaskModeNever
|
||||
case "always", "force":
|
||||
return MCPTaskModeAlways
|
||||
default:
|
||||
return MCPTaskModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// MCPTaskInfo is the connection-layer view of an MCP Task. It mirrors the
|
||||
// upstream mcp.Task but exposes Go-native types and includes the originating
|
||||
// server name. SDK-level wrappers re-export this under public-facing names.
|
||||
type MCPTaskInfo struct {
|
||||
// Server is the configured MCP server name this task lives on.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the task.
|
||||
TaskID string
|
||||
// Status is the current task lifecycle state.
|
||||
Status mcp.TaskStatus
|
||||
// StatusMessage is an optional human-readable description.
|
||||
StatusMessage string
|
||||
// CreatedAt is the wall-clock time the task was created (best-effort
|
||||
// parsed from the server's ISO-8601 timestamp; zero on parse failure).
|
||||
CreatedAt time.Time
|
||||
// UpdatedAt is the wall-clock time the task was last updated (best-
|
||||
// effort parsed; zero on parse failure).
|
||||
UpdatedAt time.Time
|
||||
// TTL is the time-to-live the server intends to retain the task after
|
||||
// creation. Zero means the server did not advertise a TTL.
|
||||
TTL time.Duration
|
||||
// PollInterval is the suggested polling interval. Zero means use the
|
||||
// client's default.
|
||||
PollInterval time.Duration
|
||||
}
|
||||
|
||||
// MCPTaskProgress is emitted while the connection pool is waiting on a
|
||||
// task-augmented tool call. It provides minimal feedback for SDK consumers
|
||||
// that want to render progress widgets without subscribing to the full
|
||||
// notifications/tasks/status channel (Phase 2).
|
||||
type MCPTaskProgress struct {
|
||||
Server string
|
||||
TaskID string
|
||||
Status mcp.TaskStatus
|
||||
Message string
|
||||
}
|
||||
|
||||
// MCPTaskProgressHandler is invoked once after a task is accepted and on
|
||||
// every status transition observed by the polling loop. The final
|
||||
// invocation always carries a terminal status. Implementations must not
|
||||
// block; long work should be queued on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// MCPTaskConfig configures task-aware tool execution on the manager.
|
||||
// All fields are optional; the zero value disables progress callbacks and
|
||||
// applies sensible defaults.
|
||||
type MCPTaskConfig struct {
|
||||
// PerServerMode overrides the per-server TasksMode resolved from
|
||||
// MCPServerConfig. Keys are server names. Missing entries fall back
|
||||
// to the value from config. Used by SDK consumers that want to set
|
||||
// modes programmatically.
|
||||
PerServerMode map[string]MCPTaskMode
|
||||
|
||||
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
|
||||
// tools/call. Zero means omit the TTL — let the server pick its own.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// PollInterval is the fallback interval between tasks/get requests
|
||||
// when the server does not suggest one. Zero defaults to 1 second.
|
||||
PollInterval time.Duration
|
||||
|
||||
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
|
||||
MaxPollInterval time.Duration
|
||||
|
||||
// Timeout is the maximum wall-clock duration to wait for a task to
|
||||
// reach a terminal state. Zero defaults to 15 minutes. Independent
|
||||
// of the per-call context deadline; whichever fires first wins.
|
||||
Timeout time.Duration
|
||||
|
||||
// Progress, if non-nil, receives every status transition observed by
|
||||
// the polling loop.
|
||||
Progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
func (c MCPTaskConfig) resolved() MCPTaskConfig {
|
||||
if c.PollInterval <= 0 {
|
||||
c.PollInterval = 1 * time.Second
|
||||
}
|
||||
if c.MaxPollInterval <= 0 {
|
||||
c.MaxPollInterval = 5 * time.Second
|
||||
}
|
||||
if c.Timeout <= 0 {
|
||||
c.Timeout = 15 * time.Minute
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// requestIDCounter generates monotonically increasing JSON-RPC request IDs
|
||||
// for low-level tools/call invocations that bypass the upstream client's
|
||||
// ParseCallToolResult helper (necessary because that helper rejects task
|
||||
// responses for lacking a "content" field).
|
||||
//
|
||||
// The counter is process-wide rather than per-manager so multiple managers
|
||||
// or repeated calls within the same connection produce unique IDs.
|
||||
var requestIDCounter atomic.Int64
|
||||
|
||||
func nextRequestID() mcp.RequestId {
|
||||
return mcp.NewRequestId(requestIDCounter.Add(1))
|
||||
}
|
||||
|
||||
// callToolWithTask issues tools/call directly on the transport so we can
|
||||
// observe both response shapes:
|
||||
//
|
||||
// - {"content": [...], ...} — synchronous CallToolResult.
|
||||
// - {"task": {...}, ...} — asynchronous CreateTaskResult.
|
||||
//
|
||||
// On success exactly one of (callResult, taskResult) is non-nil. The
|
||||
// upstream client.CallTool helper parses the response with
|
||||
// mcp.ParseCallToolResult which requires a "content" field, so it cannot
|
||||
// be used for task-augmented calls.
|
||||
func callToolWithTask(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
params mcp.CallToolParams,
|
||||
) (callResult *mcp.CallToolResult, taskResult *mcp.CreateTaskResult, err error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
Params: params,
|
||||
}
|
||||
|
||||
resp, sendErr := tr.SendRequest(ctx, req)
|
||||
if sendErr != nil {
|
||||
return nil, nil, sendErr
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Peek at the raw result to decide which shape we got.
|
||||
var probe struct {
|
||||
Task json.RawMessage `json:"task"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
raw := resp.Result
|
||||
if len(raw) == 0 {
|
||||
return nil, nil, errors.New("empty tools/call result")
|
||||
}
|
||||
if uErr := json.Unmarshal(raw, &probe); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode tools/call result: %w", uErr)
|
||||
}
|
||||
|
||||
if len(probe.Task) > 0 && string(probe.Task) != "null" {
|
||||
// Task-augmented response.
|
||||
var ct mcp.CreateTaskResult
|
||||
if uErr := json.Unmarshal(raw, &ct); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode CreateTaskResult: %w", uErr)
|
||||
}
|
||||
return nil, &ct, nil
|
||||
}
|
||||
|
||||
// Synchronous response — defer to the upstream parser so content blocks
|
||||
// are typed correctly (TextContent, ImageContent, ResourceLink, etc.).
|
||||
cr, pErr := mcp.ParseCallToolResult(&raw)
|
||||
if pErr != nil {
|
||||
return nil, nil, fmt.Errorf("parse CallToolResult: %w", pErr)
|
||||
}
|
||||
return cr, nil, nil
|
||||
}
|
||||
|
||||
// pollTaskUntilTerminal blocks until the task reaches a terminal status,
|
||||
// the context is cancelled, or the configured timeout elapses. On
|
||||
// cancellation it best-effort issues tasks/cancel before returning.
|
||||
func pollTaskUntilTerminal(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
serverName string,
|
||||
task mcp.Task,
|
||||
cfg MCPTaskConfig,
|
||||
progress MCPTaskProgressHandler,
|
||||
) (*mcp.TaskResultResult, error) {
|
||||
cfg = cfg.resolved()
|
||||
deadline := time.Now().Add(cfg.Timeout)
|
||||
|
||||
emit := func(status mcp.TaskStatus, msg string) {
|
||||
if progress != nil {
|
||||
progress(MCPTaskProgress{Server: serverName, TaskID: task.TaskId, Status: status, Message: msg})
|
||||
}
|
||||
}
|
||||
|
||||
emit(task.Status, task.StatusMessage)
|
||||
|
||||
current := task
|
||||
interval := cfg.PollInterval
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = time.Duration(*current.PollInterval) * time.Millisecond
|
||||
}
|
||||
if interval > cfg.MaxPollInterval {
|
||||
interval = cfg.MaxPollInterval
|
||||
}
|
||||
|
||||
for !current.Status.IsTerminal() {
|
||||
if time.Now().After(deadline) {
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, fmt.Errorf("task %s timed out after %s", current.TaskId, cfg.Timeout)
|
||||
}
|
||||
|
||||
// Wait between polls or abort early on context cancellation.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
}
|
||||
|
||||
got, err := c.GetTask(ctx, mcp.GetTaskRequest{
|
||||
Params: mcp.GetTaskParams{TaskId: current.TaskId},
|
||||
})
|
||||
if err != nil {
|
||||
// Transient transport hiccup — propagate immediately. The
|
||||
// upstream agent layer treats this like any other tool error.
|
||||
return nil, fmt.Errorf("tasks/get failed: %w", err)
|
||||
}
|
||||
current = got.Task
|
||||
if current.Status != task.Status || current.StatusMessage != task.StatusMessage {
|
||||
emit(current.Status, current.StatusMessage)
|
||||
task = current
|
||||
}
|
||||
|
||||
// Honour any updated suggested poll interval, capped at the limit.
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = min(time.Duration(*current.PollInterval)*time.Millisecond, cfg.MaxPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// Terminal state reached. Emit one last progress event and fetch the
|
||||
// definitive tool result.
|
||||
emit(current.Status, current.StatusMessage)
|
||||
|
||||
if current.Status == mcp.TaskStatusCancelled {
|
||||
return nil, fmt.Errorf("task %s was cancelled", current.TaskId)
|
||||
}
|
||||
|
||||
res, err := fetchTaskResult(ctx, c, current.TaskId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tasks/result failed: %w", err)
|
||||
}
|
||||
if current.Status == mcp.TaskStatusFailed && res != nil && !res.IsError {
|
||||
// The server flagged the task as failed but didn't decorate the
|
||||
// result. Surface the status message so the caller still sees a
|
||||
// useful tool-error.
|
||||
return nil, fmt.Errorf("task %s failed: %s", current.TaskId, current.StatusMessage)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// fetchTaskResult issues tasks/result on the transport and parses the raw
|
||||
// response. The upstream client.TaskResult helper delegates to
|
||||
// mcp.ParseTaskResultResult which (as of mcp-go v0.51.0) looks for the
|
||||
// content array under a nested "result" key that never exists in the
|
||||
// wire format — leading to systematically empty Content. Doing the
|
||||
// parse here keeps the polling path working until that is fixed upstream.
|
||||
func fetchTaskResult(ctx context.Context, c *client.Client, taskID string) (*mcp.TaskResultResult, error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodTasksResult),
|
||||
Params: mcp.TaskResultParams{TaskId: taskID},
|
||||
}
|
||||
resp, err := tr.SendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Manually decode the wire shape: {"_meta": {...}, "content": [...],
|
||||
// "structuredContent": ..., "isError": bool}.
|
||||
var shape struct {
|
||||
Meta json.RawMessage `json:"_meta"`
|
||||
Content []json.RawMessage `json:"content"`
|
||||
StructuredContent any `json:"structuredContent"`
|
||||
IsError bool `json:"isError"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Result, &shape); err != nil {
|
||||
return nil, fmt.Errorf("decode tasks/result: %w", err)
|
||||
}
|
||||
|
||||
out := &mcp.TaskResultResult{
|
||||
StructuredContent: shape.StructuredContent,
|
||||
IsError: shape.IsError,
|
||||
}
|
||||
if len(shape.Meta) > 0 && string(shape.Meta) != "null" {
|
||||
var metaMap map[string]any
|
||||
if err := json.Unmarshal(shape.Meta, &metaMap); err == nil {
|
||||
out.Meta = mcp.NewMetaFromMap(metaMap)
|
||||
}
|
||||
}
|
||||
for _, raw := range shape.Content {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(raw, &contentMap); err != nil {
|
||||
return nil, fmt.Errorf("decode content block: %w", err)
|
||||
}
|
||||
parsed, err := mcp.ParseContent(contentMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse content block: %w", err)
|
||||
}
|
||||
out.Content = append(out.Content, parsed)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cancelTaskBestEffort issues tasks/cancel and ignores any error. Used on
|
||||
// context cancellation paths where the connection is already going away.
|
||||
func cancelTaskBestEffort(c *client.Client, taskID string) {
|
||||
if c == nil || taskID == "" {
|
||||
return
|
||||
}
|
||||
cancelCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _ = c.CancelTask(cancelCtx, mcp.CancelTaskRequest{
|
||||
Params: mcp.CancelTaskParams{TaskId: taskID},
|
||||
})
|
||||
}
|
||||
|
||||
// taskFromMCP converts a wire-format mcp.Task to our richer connection-
|
||||
// layer view. Unparseable timestamps surface as the zero time.
|
||||
func taskFromMCP(serverName string, t mcp.Task) MCPTaskInfo {
|
||||
out := MCPTaskInfo{
|
||||
Server: serverName,
|
||||
TaskID: t.TaskId,
|
||||
Status: t.Status,
|
||||
StatusMessage: t.StatusMessage,
|
||||
}
|
||||
if t.CreatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.CreatedAt); err == nil {
|
||||
out.CreatedAt = v
|
||||
}
|
||||
}
|
||||
if t.LastUpdatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.LastUpdatedAt); err == nil {
|
||||
out.UpdatedAt = v
|
||||
}
|
||||
}
|
||||
if t.TTL != nil {
|
||||
out.TTL = time.Duration(*t.TTL) * time.Millisecond
|
||||
}
|
||||
if t.PollInterval != nil {
|
||||
out.PollInterval = time.Duration(*t.PollInterval) * time.Millisecond
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTaskTestInProcessServer builds an in-process MCP server with a
|
||||
// task-augmented tool. The handler simulates work by sleeping briefly
|
||||
// before completing.
|
||||
//
|
||||
// Important: the upstream mcp-go server cancels the request context as
|
||||
// soon as the synchronous part of the tools/call returns (see
|
||||
// request_handler.go:85, `defer cancel()`). Task goroutines spawned by
|
||||
// AddTaskTool inherit that context and therefore see context.Canceled
|
||||
// the instant they start. Real-world transports (stdio, SSE, streamable
|
||||
// HTTP) don't trip this because they keep the connection — and a
|
||||
// background context — alive across the async work, but the in-process
|
||||
// transport runs entirely on the request goroutine. To test the polling
|
||||
// path realistically we detach from the request context here.
|
||||
func newTaskTestInProcessServer(t *testing.T, workDuration time.Duration) *server.MCPServer {
|
||||
t.Helper()
|
||||
srv := server.NewMCPServer("task-test", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
// list=true, cancel=true, toolCallTasks=true so capability detection,
|
||||
// cancellation, and tool augmentation all flow through.
|
||||
server.WithTaskCapabilities(true, true, true),
|
||||
)
|
||||
srv.AddTaskTool(
|
||||
mcp.Tool{
|
||||
Name: "long_running",
|
||||
Description: "Sleep, then echo the input string.",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"msg": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
Execution: &mcp.ToolExecution{
|
||||
TaskSupport: mcp.TaskSupportRequired,
|
||||
},
|
||||
},
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CreateTaskResult, error) {
|
||||
msg, _ := req.GetArguments()["msg"].(string)
|
||||
// Detach from the request context so the task handler can
|
||||
// outlive the synchronous request — see comment above.
|
||||
time.Sleep(workDuration)
|
||||
_ = ctx
|
||||
return &mcp.CreateTaskResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "echo:" + msg},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
// newSyncOnlyServer is a server that does NOT advertise task capability.
|
||||
// Used to verify the auto-detect path keeps the sync semantics.
|
||||
func newSyncOnlyServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("sync-only", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required()),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("hi " + name), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestConnectionPoolAdvertisesTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
srv := newTaskTestInProcessServer(t, 0)
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: srv}
|
||||
|
||||
conn, err := pool.GetConnection(context.Background(), "tasks", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
|
||||
init := conn.InitializeResult()
|
||||
if init == nil {
|
||||
t.Fatal("InitializeResult is nil after GetConnection")
|
||||
}
|
||||
if init.Capabilities.Tasks == nil {
|
||||
t.Fatal("server did not advertise Tasks capability — initialize handshake regressed")
|
||||
}
|
||||
if !conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be true for a server with toolCallTasks=true")
|
||||
}
|
||||
if !pool.ServerSupportsToolTasks("tasks") {
|
||||
t.Error("ServerSupportsToolTasks should mirror the connection's value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolDetectsAbsentTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: newSyncOnlyServer()}
|
||||
conn, err := pool.GetConnection(context.Background(), "sync", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
if conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be false for a server that didn't advertise the capability")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsToolTasksFromInit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *mcp.InitializeResult
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"no tasks", &mcp.InitializeResult{}, false},
|
||||
{"tasks no requests", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: &mcp.TasksCapability{}},
|
||||
}, false},
|
||||
{"tasks with toolCalls", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: mcp.NewTasksCapability()},
|
||||
}, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := supportsToolTasksFromInit(tc.in); got != tc.want {
|
||||
t.Errorf("supportsToolTasksFromInit() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTaskMode(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want MCPTaskMode
|
||||
}{
|
||||
{"", MCPTaskModeAuto},
|
||||
{"auto", MCPTaskModeAuto},
|
||||
{"AUTO", MCPTaskModeAuto},
|
||||
{"never", MCPTaskModeNever},
|
||||
{"off", MCPTaskModeNever},
|
||||
{"always", MCPTaskModeAlways},
|
||||
{"force", MCPTaskModeAlways},
|
||||
{"bogus", MCPTaskModeAuto},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := ParseTaskMode(tc.in); got != tc.want {
|
||||
t.Errorf("ParseTaskMode(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolPollsTaskToCompletion(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 20 * time.Millisecond,
|
||||
MaxPollInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hello"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if res.IsError {
|
||||
t.Fatalf("expected non-error result, got %s", res.Content)
|
||||
}
|
||||
if !strings.Contains(res.Content, "echo:hello") {
|
||||
t.Errorf("expected result to contain 'echo:hello', got %s", res.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolHonorsNeverMode(t *testing.T) {
|
||||
// Even though the server advertises tasks/toolCalls, "never" should
|
||||
// keep the call synchronous. Since the tool is TaskSupportRequired,
|
||||
// the server returns an error rather than running it sync — we just
|
||||
// verify the error surfaces (not a poll-loop hang).
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PerServerMode: map[string]MCPTaskMode{"tasks": MCPTaskModeNever},
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// We don't care which way the server fails the sync call; we just want
|
||||
// to confirm we didn't hang in the polling loop and didn't panic.
|
||||
_, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"x"}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error when forcing sync execution of a task-required tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolEmitsProgress(t *testing.T) {
|
||||
var statuses []mcp.TaskStatus
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 10 * time.Millisecond,
|
||||
MaxPollInterval: 25 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
Progress: func(p MCPTaskProgress) {
|
||||
statuses = append(statuses, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 30*time.Millisecond),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hi"}`); err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if len(statuses) == 0 {
|
||||
t.Fatal("expected at least one progress event")
|
||||
}
|
||||
last := statuses[len(statuses)-1]
|
||||
if !last.IsTerminal() {
|
||||
t.Errorf("last progress event should be terminal, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGetCancelMCPTasksOnLoadedServer(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// tasks/list — no in-flight tasks yet, so we just verify the call
|
||||
// succeeds and returns an empty slice (or any slice; the exact length
|
||||
// depends on server retention policy).
|
||||
if _, err := mgr.ListServerTasks(ctx, "tasks"); err != nil {
|
||||
t.Errorf("ListServerTasks: %v", err)
|
||||
}
|
||||
|
||||
// Unknown server should error cleanly without panicking.
|
||||
if _, err := mgr.GetServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("GetServerTask on unknown server should error")
|
||||
}
|
||||
if _, err := mgr.CancelServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("CancelServerTask on unknown server should error")
|
||||
}
|
||||
}
|
||||
@@ -103,14 +103,12 @@ func TestMCPToolManager_EmptyConfig(t *testing.T) {
|
||||
|
||||
// Test that we can get tool info for each tool
|
||||
for _, tool := range tools {
|
||||
info := tool.Info()
|
||||
|
||||
// Check that the tool has a valid name
|
||||
if info.Name == "" {
|
||||
if tool.Name == "" {
|
||||
t.Error("Tool has empty name")
|
||||
}
|
||||
|
||||
t.Logf("Tool: %s, Description: %s", info.Name, info.Description)
|
||||
t.Logf("Tool: %s, Description: %s", tool.Name, tool.Description)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -28,15 +28,6 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoBorder returns a renderingOption that disables all borders on the
|
||||
// block, rendering content with only padding.
|
||||
func WithNoBorder() renderingOption {
|
||||
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackground returns a renderingOption that sets the background color
|
||||
// for the entire block. The color parameter accepts any color.Color value,
|
||||
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
|
||||
func WithBackground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.background = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
}
|
||||
}
|
||||
|
||||
// renderContentBlock renders content with configurable styling options
|
||||
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
|
||||
renderer := &blockRenderer{
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
|
||||
// newTestInput creates an InputComponent with the given AppController (may be nil).
|
||||
func newTestInput(ctrl AppController) *InputComponent {
|
||||
return NewInputComponent(80, "test input", ctrl)
|
||||
return NewInputComponent(80, ctrl)
|
||||
}
|
||||
|
||||
// sendInputMsg calls component.Update with the given message, returns the
|
||||
@@ -69,30 +69,6 @@ func TestInputComponent_SubmitEmitsSubmitMsg(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputComponent_CtrlD_SubmitEmitsSubmitMsg verifies that ctrl+d also
|
||||
// submits the text.
|
||||
func TestInputComponent_CtrlD_SubmitEmitsSubmitMsg(t *testing.T) {
|
||||
ctrl := &stubAppController{}
|
||||
c := newTestInput(ctrl)
|
||||
|
||||
c.textarea.SetValue("ctrl+d submit")
|
||||
c.lastValue = "ctrl+d submit"
|
||||
|
||||
_, cmd := sendInputMsg(c, tea.KeyPressMsg{Code: 'd', Mod: tea.ModCtrl})
|
||||
|
||||
msg := runCmd(cmd)
|
||||
if msg == nil {
|
||||
t.Fatal("expected a cmd from ctrl+d on non-empty input")
|
||||
}
|
||||
sm, ok := msg.(core.SubmitMsg)
|
||||
if !ok {
|
||||
t.Fatalf("expected submitMsg from ctrl+d, got %T", msg)
|
||||
}
|
||||
if sm.Text != "ctrl+d submit" {
|
||||
t.Fatalf("expected Text='ctrl+d submit', got %q", sm.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputComponent_EmptySubmit_NoCmd verifies that submitting an empty or
|
||||
// whitespace-only string produces no cmd.
|
||||
func TestInputComponent_EmptySubmit_NoCmd(t *testing.T) {
|
||||
|
||||
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
|
||||
@@ -84,7 +84,7 @@ var SlashCommands = []SlashCommand{
|
||||
},
|
||||
{
|
||||
Name: "/thinking",
|
||||
Description: "Set thinking/reasoning level (off, minimal, low, medium, high)",
|
||||
Description: "Set thinking/reasoning level (off, none, minimal, low, medium, high)",
|
||||
Category: "System",
|
||||
Aliases: []string{"/think"},
|
||||
Complete: func(prefix string) []string {
|
||||
@@ -161,6 +161,12 @@ var SlashCommands = []SlashCommand{
|
||||
Category: "Navigation",
|
||||
Aliases: []string{"/r"},
|
||||
},
|
||||
{
|
||||
Name: "/copy",
|
||||
Description: "Copy the last message to the system clipboard",
|
||||
Category: "System",
|
||||
Aliases: []string{"/cp"},
|
||||
},
|
||||
{
|
||||
Name: "/export",
|
||||
Description: "Export session (JSONL by default, or /export path.jsonl)",
|
||||
@@ -199,18 +205,6 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
names = append(names, cmd.Name)
|
||||
names = append(names, cmd.Aliases...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ExtensionCommand is a slash command registered by an extension. Unlike
|
||||
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
|
||||
// extension commands carry their own Execute callback.
|
||||
|
||||
@@ -25,6 +25,11 @@ type SubmitMsg struct {
|
||||
// presses ESC a second time, the canceling state is reset to false.
|
||||
type CancelTimerExpiredMsg struct{}
|
||||
|
||||
// CtrlCResetMsg is sent after a short delay when the user presses Ctrl+C to
|
||||
// clear input. If the user doesn't press Ctrl+C again within the timeout,
|
||||
// the ctrlCPressedOnce flag is reset so the next Ctrl+C will clear again.
|
||||
type CtrlCResetMsg struct{}
|
||||
|
||||
// --- Tree session events ---
|
||||
|
||||
// TreeNodeSelectedMsg is sent when the user selects a node in the tree selector.
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
}
|
||||
|
||||
// Format the message to include all the debug info in a structured way
|
||||
var formattedMessage string
|
||||
|
||||
// Check if this is a multi-line debug output (like connection info)
|
||||
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
|
||||
// Extract the tag and content
|
||||
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
|
||||
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Add appropriate emoji based on the message content
|
||||
if strings.Contains(content, "Creating new connection") {
|
||||
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
|
||||
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
|
||||
} else if strings.Contains(content, "Reusing") {
|
||||
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
|
||||
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
|
||||
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
|
||||
} else if strings.Contains(content, "closed") {
|
||||
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Failed to close") {
|
||||
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
|
||||
} else {
|
||||
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
|
||||
// Use the CLI's debug message rendering
|
||||
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
@@ -29,9 +29,16 @@ type (
|
||||
ExtensionCommand = commands.ExtensionCommand
|
||||
)
|
||||
|
||||
// Re-export functions from fileutil package
|
||||
// Re-export functions and types from fileutil package
|
||||
var ProcessFileAttachments = fileutil.ProcessFileAttachments
|
||||
|
||||
// Re-export types from fileutil
|
||||
type (
|
||||
FileAttachmentResult = fileutil.FileAttachmentResult
|
||||
FilePart = fileutil.FilePart
|
||||
MCPResourceReader = fileutil.MCPResourceReader
|
||||
)
|
||||
|
||||
// Re-export from prefs package
|
||||
var (
|
||||
LoadThemePreference = prefs.LoadThemePreference
|
||||
|
||||
@@ -6,22 +6,78 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileSuggestion represents a single file or directory suggestion for the @
|
||||
// autocomplete popup.
|
||||
// FileSuggestion represents a single file, directory, or MCP resource
|
||||
// suggestion for the @ autocomplete popup.
|
||||
type FileSuggestion struct {
|
||||
// RelPath is the path relative to the search base (e.g. "cmd/kit/main.go").
|
||||
// RelPath is the path relative to the search base (e.g. "cmd/kit/main.go")
|
||||
// or a display name for MCP resources (e.g. "mcp:server/resource-name").
|
||||
RelPath string
|
||||
// IsDir is true when the entry is a directory.
|
||||
IsDir bool
|
||||
// Score is the fuzzy match score (higher is better).
|
||||
Score int
|
||||
// IsMCPResource is true for MCP resource entries.
|
||||
IsMCPResource bool
|
||||
// MCPServerName is the MCP server name (set when IsMCPResource is true).
|
||||
MCPServerName string
|
||||
// MCPResourceURI is the MCP resource URI (set when IsMCPResource is true).
|
||||
MCPResourceURI string
|
||||
// MCPMIMEType is the MIME type hint from the MCP server.
|
||||
MCPMIMEType string
|
||||
}
|
||||
|
||||
// maxFileSuggestions is the maximum number of file suggestions returned.
|
||||
const maxFileSuggestions = 20
|
||||
|
||||
// fileListCache caches the result of listFiles() keyed by directory to avoid
|
||||
// re-running git subprocesses on every keystroke during @file completion.
|
||||
var fileListCache struct {
|
||||
mu sync.Mutex
|
||||
dir string // searchDir that produced the cached entries
|
||||
cwd string // cwd used for the git query
|
||||
entries []FileSuggestion // cached file list
|
||||
expireAt time.Time // when the cache entry expires
|
||||
}
|
||||
|
||||
// fileListCacheTTL controls how long a cached file list stays valid.
|
||||
// During rapid typing the list is reused; after the TTL a fresh git
|
||||
// ls-files is executed so newly created files become visible.
|
||||
const fileListCacheTTL = 3 * time.Second
|
||||
|
||||
// getCachedFileList returns the file list for searchDir, using a short-lived
|
||||
// cache to avoid repeated subprocess calls during @file autocompletion.
|
||||
func getCachedFileList(searchDir, cwd string) []FileSuggestion {
|
||||
fileListCache.mu.Lock()
|
||||
defer fileListCache.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
if fileListCache.dir == searchDir &&
|
||||
fileListCache.cwd == cwd &&
|
||||
now.Before(fileListCache.expireAt) {
|
||||
// Return a copy so callers can mutate (e.g. prepend baseDir).
|
||||
cp := make([]FileSuggestion, len(fileListCache.entries))
|
||||
copy(cp, fileListCache.entries)
|
||||
return cp
|
||||
}
|
||||
|
||||
// Cache miss or expired — run the real (potentially expensive) lookup.
|
||||
files := listFiles(searchDir, cwd)
|
||||
|
||||
fileListCache.dir = searchDir
|
||||
fileListCache.cwd = cwd
|
||||
fileListCache.entries = files
|
||||
fileListCache.expireAt = now.Add(fileListCacheTTL)
|
||||
|
||||
// Return a copy.
|
||||
cp := make([]FileSuggestion, len(files))
|
||||
copy(cp, files)
|
||||
return cp
|
||||
}
|
||||
|
||||
// ExtractAtPrefix checks the current line for an @-file trigger at cursorCol.
|
||||
// It returns:
|
||||
// - hasAt: true if a valid @ trigger was found
|
||||
@@ -90,7 +146,7 @@ func GetFileSuggestions(prefix string, cwd string) []FileSuggestion {
|
||||
}
|
||||
}
|
||||
|
||||
files := listFiles(searchDir, cwd)
|
||||
files := getCachedFileList(searchDir, cwd)
|
||||
if len(files) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
@@ -10,31 +12,75 @@ import (
|
||||
"github.com/mark3labs/kit/internal/fences"
|
||||
)
|
||||
|
||||
// FilePart represents a binary file attachment (image, audio, etc.) extracted
|
||||
// from an @file reference. Callers convert this to kit.LLMFilePart before
|
||||
// sending to the LLM. Defined here to avoid a circular dependency on pkg/kit.
|
||||
type FilePart struct {
|
||||
// Filename is the basename of the file (e.g. "photo.png").
|
||||
Filename string
|
||||
// Data is the raw file bytes.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/png", "audio/wav").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// MCPResourceReader is a callback function that reads an MCP resource by
|
||||
// server name and URI. Returns text content, binary data, MIME type, and error.
|
||||
// Used by ProcessFileAttachments to resolve @mcp:server:uri tokens.
|
||||
type MCPResourceReader func(serverName, uri string) (text string, blobData []byte, mimeType string, isBlob bool, err error)
|
||||
|
||||
// FileAttachmentResult is the result of processing @file references in user
|
||||
// input. Text files are inlined as XML in ProcessedText; binary files (images,
|
||||
// audio, video, PDFs) are returned as FileParts for multimodal submission.
|
||||
type FileAttachmentResult struct {
|
||||
// ProcessedText is the user's text with @file tokens replaced:
|
||||
// text files become XML-wrapped content, binary file tokens are removed.
|
||||
ProcessedText string
|
||||
// FileParts contains binary file attachments extracted from @file
|
||||
// references. Empty when all referenced files are text.
|
||||
FileParts []FilePart
|
||||
}
|
||||
|
||||
// fileTokenPattern matches @file references in user text. Supports:
|
||||
// - @"path with spaces.txt" (quoted)
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// ProcessFileAttachments scans the user's input text for @file references,
|
||||
// reads each referenced file, and returns the text with @tokens replaced by
|
||||
// XML-wrapped file content. Non-file @ tokens (like email addresses) are left
|
||||
// unchanged.
|
||||
// reads each referenced file, and returns a result containing the processed
|
||||
// text and any binary file attachments. Text files are XML-wrapped inline;
|
||||
// binary files (images, audio, etc.) are extracted as FileParts for multimodal
|
||||
// submission. Non-file @ tokens (like email addresses) are left unchanged.
|
||||
//
|
||||
// Returns the original text unchanged if no valid @file references are found.
|
||||
func ProcessFileAttachments(text string, cwd string) string {
|
||||
return fences.ReplaceOutside(text, func(segment string) string {
|
||||
return processFileTokens(segment, cwd)
|
||||
// MCP resources are supported via @mcp:server:uri tokens. The optional
|
||||
// mcpReader callback is used to resolve them; pass nil to skip MCP resources.
|
||||
func ProcessFileAttachments(text string, cwd string, mcpReader ...MCPResourceReader) FileAttachmentResult {
|
||||
var reader MCPResourceReader
|
||||
if len(mcpReader) > 0 {
|
||||
reader = mcpReader[0]
|
||||
}
|
||||
var allParts []FilePart
|
||||
processed := fences.ReplaceOutside(text, func(segment string) string {
|
||||
result, parts := processFileTokens(segment, cwd, reader)
|
||||
allParts = append(allParts, parts...)
|
||||
return result
|
||||
})
|
||||
return FileAttachmentResult{
|
||||
ProcessedText: processed,
|
||||
FileParts: allParts,
|
||||
}
|
||||
}
|
||||
|
||||
// processFileTokens handles @file replacement in a single text segment
|
||||
// that is known to be outside fenced code blocks.
|
||||
func processFileTokens(text string, cwd string) string {
|
||||
// that is known to be outside fenced code blocks. Returns the processed
|
||||
// text and any binary file parts extracted.
|
||||
func processFileTokens(text string, cwd string, mcpReader MCPResourceReader) (string, []FilePart) {
|
||||
tokens := fileTokenPattern.FindAllString(text, -1)
|
||||
if len(tokens) == 0 {
|
||||
return text
|
||||
return text, nil
|
||||
}
|
||||
|
||||
var parts []FilePart
|
||||
result := text
|
||||
for _, token := range tokens {
|
||||
path := tokenToPath(token)
|
||||
@@ -42,6 +88,43 @@ func processFileTokens(text string, cwd string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for MCP resource reference: @mcp:server:uri
|
||||
if strings.HasPrefix(path, "mcp:") {
|
||||
if mcpReader == nil {
|
||||
continue
|
||||
}
|
||||
mcpRef := path[4:] // strip "mcp:"
|
||||
// Split into server:uri (first colon separates server from URI)
|
||||
serverName, uri, ok := strings.Cut(mcpRef, ":")
|
||||
if !ok || serverName == "" || uri == "" {
|
||||
continue // invalid format
|
||||
}
|
||||
|
||||
textContent, blobData, mimeType, isBlob, err := mcpReader(serverName, uri)
|
||||
if err != nil {
|
||||
continue // skip on error, leave token as-is
|
||||
}
|
||||
|
||||
if isBlob {
|
||||
// Binary MCP resource → extract as FilePart.
|
||||
filename := filepath.Base(uri)
|
||||
if filename == "." || filename == "/" {
|
||||
filename = serverName + "_resource"
|
||||
}
|
||||
parts = append(parts, FilePart{
|
||||
Filename: filename,
|
||||
Data: blobData,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
result = strings.Replace(result, token, "", 1)
|
||||
} else {
|
||||
// Text MCP resource → inline as XML.
|
||||
wrapped := fmt.Sprintf("<resource uri=\"%s\" server=\"%s\">\n%s\n</resource>", uri, serverName, textContent)
|
||||
result = strings.Replace(result, token, wrapped, 1)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
absPath, err := resolvePath(path, cwd)
|
||||
if err != nil {
|
||||
// Not a valid file reference — leave the token as-is.
|
||||
@@ -69,12 +152,28 @@ func processFileTokens(text string, cwd string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build the XML-wrapped replacement.
|
||||
wrapped := wrapFileContent(absPath, content)
|
||||
result = strings.Replace(result, token, wrapped, 1)
|
||||
mediaType := detectMediaType(absPath, content)
|
||||
|
||||
if isBinaryMediaType(mediaType) {
|
||||
// Binary file → extract as a FilePart for multimodal submission.
|
||||
// Remove the @token from the text.
|
||||
parts = append(parts, FilePart{
|
||||
Filename: filepath.Base(absPath),
|
||||
Data: content,
|
||||
MediaType: mediaType,
|
||||
})
|
||||
result = strings.Replace(result, token, "", 1)
|
||||
} else {
|
||||
// Text file → inline as XML-wrapped content.
|
||||
wrapped := wrapFileContent(absPath, content)
|
||||
result = strings.Replace(result, token, wrapped, 1)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
// Clean up any extra whitespace left by removed binary tokens.
|
||||
result = strings.TrimSpace(result)
|
||||
|
||||
return result, parts
|
||||
}
|
||||
|
||||
// tokenToPath strips the @ prefix and optional quotes from a token,
|
||||
@@ -137,3 +236,86 @@ func resolvePath(path string, cwd string) (string, error) {
|
||||
func wrapFileContent(absPath string, content []byte) string {
|
||||
return fmt.Sprintf("<file path=\"%s\">\n%s\n</file>", absPath, string(content))
|
||||
}
|
||||
|
||||
// detectMediaType determines the MIME type of a file using extension-based
|
||||
// lookup first (more reliable for known types), then falls back to content
|
||||
// sniffing via net/http.DetectContentType.
|
||||
func detectMediaType(path string, content []byte) string {
|
||||
// Extension-based detection is more reliable for well-known types.
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
if mt := mime.TypeByExtension(ext); mt != "" {
|
||||
// mime.TypeByExtension returns types like "image/png; charset=utf-8"
|
||||
// — strip parameters.
|
||||
if base, _, ok := strings.Cut(mt, ";"); ok {
|
||||
return strings.TrimSpace(base)
|
||||
}
|
||||
return mt
|
||||
}
|
||||
|
||||
// Known extensions that mime package may miss.
|
||||
switch ext {
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
case ".avif":
|
||||
return "image/avif"
|
||||
case ".heic", ".heif":
|
||||
return "image/heif"
|
||||
case ".opus":
|
||||
return "audio/opus"
|
||||
case ".flac":
|
||||
return "audio/flac"
|
||||
case ".m4a":
|
||||
return "audio/mp4"
|
||||
case ".wasm":
|
||||
return "application/wasm"
|
||||
}
|
||||
|
||||
// Content sniffing fallback.
|
||||
if len(content) > 0 {
|
||||
detected := http.DetectContentType(content)
|
||||
if detected != "" && detected != "application/octet-stream" {
|
||||
if base, _, ok := strings.Cut(detected, ";"); ok {
|
||||
return strings.TrimSpace(base)
|
||||
}
|
||||
return detected
|
||||
}
|
||||
}
|
||||
|
||||
// Default: treat as plain text so it gets XML-wrapped.
|
||||
return "text/plain"
|
||||
}
|
||||
|
||||
// isBinaryMediaType returns true if the MIME type represents a binary file
|
||||
// that should be sent as a multimodal FilePart rather than XML-wrapped text.
|
||||
func isBinaryMediaType(mediaType string) bool {
|
||||
// Image types — always binary.
|
||||
if strings.HasPrefix(mediaType, "image/") {
|
||||
return true
|
||||
}
|
||||
// Audio types — always binary.
|
||||
if strings.HasPrefix(mediaType, "audio/") {
|
||||
return true
|
||||
}
|
||||
// Video types — always binary.
|
||||
if strings.HasPrefix(mediaType, "video/") {
|
||||
return true
|
||||
}
|
||||
// Specific application types that are binary.
|
||||
switch mediaType {
|
||||
case "application/pdf",
|
||||
"application/zip",
|
||||
"application/gzip",
|
||||
"application/x-tar",
|
||||
"application/octet-stream",
|
||||
"application/wasm",
|
||||
"application/x-executable",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessFileAttachments_TextFile(t *testing.T) {
|
||||
// Create a temp text file
|
||||
dir := t.TempDir()
|
||||
textFile := filepath.Join(dir, "hello.txt")
|
||||
if err := os.WriteFile(textFile, []byte("hello world"), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
text := "@" + textFile + " check this out"
|
||||
result := ProcessFileAttachments(text, dir)
|
||||
|
||||
if len(result.FileParts) != 0 {
|
||||
t.Errorf("expected 0 FileParts for text file, got %d", len(result.FileParts))
|
||||
}
|
||||
if result.ProcessedText == text {
|
||||
t.Error("expected text file to be XML-wrapped, but got original text unchanged")
|
||||
}
|
||||
// Should contain XML wrapping
|
||||
if !contains(result.ProcessedText, "<file path=") {
|
||||
t.Error("expected XML <file> wrapping in processed text")
|
||||
}
|
||||
if !contains(result.ProcessedText, "hello world") {
|
||||
t.Error("expected file content in processed text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessFileAttachments_BinaryFile(t *testing.T) {
|
||||
// Create a minimal PNG file (binary)
|
||||
dir := t.TempDir()
|
||||
pngFile := filepath.Join(dir, "image.png")
|
||||
// Minimal valid PNG header
|
||||
pngData := []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
|
||||
0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, // IHDR chunk
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, // 1x1
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, // 8bit RGB
|
||||
0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, // IDAT chunk
|
||||
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00,
|
||||
0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, 0x33,
|
||||
0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, // IEND chunk
|
||||
0xAE, 0x42, 0x60, 0x82,
|
||||
}
|
||||
if err := os.WriteFile(pngFile, pngData, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
text := "@" + pngFile + " what is this image?"
|
||||
result := ProcessFileAttachments(text, dir)
|
||||
|
||||
if len(result.FileParts) != 1 {
|
||||
t.Fatalf("expected 1 FilePart for binary file, got %d", len(result.FileParts))
|
||||
}
|
||||
if result.FileParts[0].MediaType != "image/png" {
|
||||
t.Errorf("expected media type image/png, got %s", result.FileParts[0].MediaType)
|
||||
}
|
||||
if result.FileParts[0].Filename != "image.png" {
|
||||
t.Errorf("expected filename image.png, got %s", result.FileParts[0].Filename)
|
||||
}
|
||||
// The @token should be removed from the text
|
||||
if contains(result.ProcessedText, "@") && contains(result.ProcessedText, pngFile) {
|
||||
t.Error("expected @token to be removed from processed text for binary file")
|
||||
}
|
||||
if contains(result.ProcessedText, "what is this image?") {
|
||||
// Good, the prompt text should remain
|
||||
} else {
|
||||
t.Error("expected prompt text to remain in processed text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessFileAttachments_MCPResource(t *testing.T) {
|
||||
// Test @mcp:server:uri token processing with a mock reader
|
||||
text := "@mcp:test-server:docs://readme tell me about this"
|
||||
reader := func(serverName, uri string) (string, []byte, string, bool, error) {
|
||||
if serverName != "test-server" || uri != "docs://readme" {
|
||||
t.Errorf("unexpected server/uri: %s/%s", serverName, uri)
|
||||
}
|
||||
return "Hello from MCP resource", nil, "text/plain", false, nil
|
||||
}
|
||||
|
||||
result := ProcessFileAttachments(text, "/tmp", reader)
|
||||
|
||||
if len(result.FileParts) != 0 {
|
||||
t.Errorf("expected 0 FileParts for text MCP resource, got %d", len(result.FileParts))
|
||||
}
|
||||
if !contains(result.ProcessedText, "<resource uri=\"docs://readme\" server=\"test-server\">") {
|
||||
t.Error("expected <resource> XML wrapping in processed text")
|
||||
}
|
||||
if !contains(result.ProcessedText, "Hello from MCP resource") {
|
||||
t.Error("expected MCP resource content in processed text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessFileAttachments_MCPResource_Binary(t *testing.T) {
|
||||
// Test @mcp:server:uri token processing for a binary resource
|
||||
text := "@mcp:test-server:images://logo describe this"
|
||||
reader := func(serverName, uri string) (string, []byte, string, bool, error) {
|
||||
if serverName != "test-server" || uri != "images://logo" {
|
||||
t.Errorf("unexpected server/uri: %s/%s", serverName, uri)
|
||||
}
|
||||
return "", []byte{0x89, 0x50, 0x4E, 0x47}, "image/png", true, nil
|
||||
}
|
||||
|
||||
result := ProcessFileAttachments(text, "/tmp", reader)
|
||||
|
||||
if len(result.FileParts) != 1 {
|
||||
t.Fatalf("expected 1 FilePart for binary MCP resource, got %d", len(result.FileParts))
|
||||
}
|
||||
if result.FileParts[0].MediaType != "image/png" {
|
||||
t.Errorf("expected media type image/png, got %s", result.FileParts[0].MediaType)
|
||||
}
|
||||
if result.FileParts[0].Filename != "logo" {
|
||||
t.Errorf("expected filename 'logo', got %s", result.FileParts[0].Filename)
|
||||
}
|
||||
// The @token should be removed from the text
|
||||
if contains(result.ProcessedText, "@mcp:") {
|
||||
t.Error("expected @mcp: token to be removed from processed text for binary resource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessFileAttachments_NoReader(t *testing.T) {
|
||||
// Without an MCP reader, @mcp: tokens should be left as-is
|
||||
text := "@mcp:server:resource this is a test"
|
||||
result := ProcessFileAttachments(text, "/tmp")
|
||||
|
||||
if len(result.FileParts) != 0 {
|
||||
t.Errorf("expected 0 FileParts, got %d", len(result.FileParts))
|
||||
}
|
||||
// The @mcp: token should remain unchanged since no reader was provided
|
||||
if result.ProcessedText != text {
|
||||
t.Errorf("expected text unchanged without reader, got: %s", result.ProcessedText)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMediaType(t *testing.T) {
|
||||
tests := []struct {
|
||||
ext string
|
||||
content []byte
|
||||
expected string
|
||||
}{
|
||||
// An intentionally-synthetic extension that is not registered
|
||||
// in any system MIME database. Exercises the "unknown ext +
|
||||
// no content" branch, which must return the text/plain default.
|
||||
// Do not use real extensions (e.g. .go) here: CI images often
|
||||
// ship /etc/mime.types with entries like ".go → text/x-go",
|
||||
// which would make the assertion environment-dependent.
|
||||
{".kitsyntheticext", nil, "text/plain"},
|
||||
{".png", []byte{0x89, 0x50, 0x4E, 0x47}, "image/png"},
|
||||
{".jpg", []byte{0xFF, 0xD8, 0xFF}, "image/jpeg"},
|
||||
{".pdf", []byte{0x25, 0x50, 0x44, 0x46}, "application/pdf"},
|
||||
{".txt", []byte("hello"), "text/plain"},
|
||||
{".wav", nil, "audio/wav"},
|
||||
{".webp", nil, "image/webp"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ext, func(t *testing.T) {
|
||||
got := detectMediaType("test"+tt.ext, tt.content)
|
||||
if got != tt.expected {
|
||||
t.Errorf("detectMediaType(%q) = %q, want %q", tt.ext, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBinaryMediaType(t *testing.T) {
|
||||
tests := []struct {
|
||||
mimeType string
|
||||
expected bool
|
||||
}{
|
||||
{"image/png", true},
|
||||
{"image/jpeg", true},
|
||||
{"audio/wav", true},
|
||||
{"video/mp4", true},
|
||||
{"application/pdf", true},
|
||||
{"text/plain", false},
|
||||
{"text/go", false},
|
||||
{"application/json", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.mimeType, func(t *testing.T) {
|
||||
got := isBinaryMediaType(tt.mimeType)
|
||||
if got != tt.expected {
|
||||
t.Errorf("isBinaryMediaType(%q) = %v, want %v", tt.mimeType, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsStr(s, substr))
|
||||
}
|
||||
|
||||
func containsStr(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -17,6 +17,7 @@ type Renderer interface {
|
||||
RenderReasoningBlock(content string, timestamp time.Time) UIMessage
|
||||
RenderToolMessage(toolName, toolArgs, toolResult string, isError bool) UIMessage
|
||||
RenderSystemMessage(content string, timestamp time.Time) UIMessage
|
||||
RenderCustomMessage(content, label string, timestamp time.Time) UIMessage
|
||||
RenderErrorMessage(errorMsg string, timestamp time.Time) UIMessage
|
||||
RenderDebugMessage(message string, timestamp time.Time) UIMessage
|
||||
RenderDebugConfigMessage(config map[string]any, timestamp time.Time) UIMessage
|
||||
|
||||
+120
-34
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
@@ -39,7 +40,6 @@ type InputComponent struct {
|
||||
width int
|
||||
lastValue string
|
||||
popupHeight int
|
||||
title string
|
||||
submitNext bool // defer submit one tick so popup dismisses cleanly
|
||||
|
||||
// Argument completion state. When the user types "/cmd " followed by
|
||||
@@ -61,6 +61,10 @@ type InputComponent struct {
|
||||
// autocomplete suggestions. Set by the parent via SetCwd.
|
||||
cwd string
|
||||
|
||||
// mcpResources is a callback that returns available MCP resources for
|
||||
// the @ autocomplete popup. Set by the parent via SetMCPResourceProvider.
|
||||
mcpResources func() []FileSuggestion
|
||||
|
||||
// appCtrl is used for slash commands that mutate app state.
|
||||
// May be nil in tests; nil-safe.
|
||||
appCtrl AppController
|
||||
@@ -101,17 +105,17 @@ type clipboardImageMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// NewInputComponent creates a new InputComponent with the given width, title,
|
||||
// and optional AppController. If appCtrl is nil the component still works but
|
||||
// NewInputComponent creates a new InputComponent with the given width and
|
||||
// optional AppController. If appCtrl is nil the component still works but
|
||||
// /clear and /clear-queue are no-ops.
|
||||
func NewInputComponent(width int, title string, appCtrl AppController) *InputComponent {
|
||||
func NewInputComponent(width int, appCtrl AppController) *InputComponent {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = "Type your message..."
|
||||
ta.ShowLineNumbers = false
|
||||
ta.Prompt = ""
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(3) // Default to 3 lines like huh
|
||||
ta.SetHeight(4) // 4 lines for comfortable multi-line input
|
||||
ta.Focus()
|
||||
|
||||
// Override InsertNewline so only ctrl+j and shift+enter insert newlines.
|
||||
@@ -136,8 +140,8 @@ func NewInputComponent(width int, title string, appCtrl AppController) *InputCom
|
||||
commands: commands.SlashCommands,
|
||||
width: width,
|
||||
popupHeight: 7,
|
||||
title: title,
|
||||
appCtrl: appCtrl,
|
||||
hideHint: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,6 +151,12 @@ func (s *InputComponent) SetCwd(cwd string) {
|
||||
s.cwd = cwd
|
||||
}
|
||||
|
||||
// SetMCPResourceProvider sets a callback that returns MCP resource suggestions
|
||||
// for the @ autocomplete popup. Called by the parent after construction.
|
||||
func (s *InputComponent) SetMCPResourceProvider(fn func() []FileSuggestion) {
|
||||
s.mcpResources = fn
|
||||
}
|
||||
|
||||
// Init implements tea.Model. Starts the cursor blink animation.
|
||||
func (s *InputComponent) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
@@ -190,7 +200,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.KeyPressMsg:
|
||||
if !s.showPopup {
|
||||
switch msg.String() {
|
||||
case "ctrl+d", "enter":
|
||||
case "enter":
|
||||
value := s.textarea.Value()
|
||||
s.pushHistory(value)
|
||||
s.textarea.SetValue("")
|
||||
@@ -332,9 +342,46 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// Check for @file trigger first.
|
||||
cursorCol := len(line) // approximate: cursor is at end after typing
|
||||
if hasAt, prefix, atIdx := ExtractAtPrefix(line, cursorCol); hasAt && s.cwd != "" {
|
||||
suggestions := GetFileSuggestions(prefix, s.cwd)
|
||||
if hasAt, prefix, atIdx := ExtractAtPrefix(line, cursorCol); hasAt {
|
||||
var suggestions []FileSuggestion
|
||||
|
||||
// Local file suggestions (only if cwd is set).
|
||||
if s.cwd != "" {
|
||||
suggestions = GetFileSuggestions(prefix, s.cwd)
|
||||
}
|
||||
|
||||
// MCP resource suggestions — merge with file suggestions.
|
||||
if s.mcpResources != nil {
|
||||
mcpSuggestions := s.mcpResources()
|
||||
if prefix != "" {
|
||||
// Fuzzy-filter MCP resources against the typed prefix.
|
||||
queryLower := strings.ToLower(prefix)
|
||||
var filtered []FileSuggestion
|
||||
for _, r := range mcpSuggestions {
|
||||
score := scoreFilePath(queryLower, r.RelPath)
|
||||
if score <= 0 {
|
||||
// Also try matching against the resource name without prefix.
|
||||
score = scoreFilePath(queryLower, r.MCPServerName+"/"+r.RelPath)
|
||||
}
|
||||
if score > 0 {
|
||||
r.Score = score
|
||||
filtered = append(filtered, r)
|
||||
}
|
||||
}
|
||||
mcpSuggestions = filtered
|
||||
}
|
||||
suggestions = append(suggestions, mcpSuggestions...)
|
||||
}
|
||||
|
||||
if len(suggestions) > 0 {
|
||||
// Sort by score descending, cap at maxFileSuggestions.
|
||||
sort.Slice(suggestions, func(i, j int) bool {
|
||||
return suggestions[i].Score > suggestions[j].Score
|
||||
})
|
||||
if len(suggestions) > maxFileSuggestions {
|
||||
suggestions = suggestions[:maxFileSuggestions]
|
||||
}
|
||||
|
||||
s.showPopup = true
|
||||
s.fileMode = true
|
||||
s.argMode = false
|
||||
@@ -348,6 +395,8 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
desc := ""
|
||||
if fs.IsDir {
|
||||
desc = "directory"
|
||||
} else if fs.IsMCPResource {
|
||||
desc = "mcp:" + fs.MCPServerName
|
||||
}
|
||||
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
|
||||
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
|
||||
@@ -470,19 +519,13 @@ func (s *InputComponent) resetHistoryBrowsing() {
|
||||
s.savedInput = ""
|
||||
}
|
||||
|
||||
// View implements tea.Model. Renders the title, textarea, autocomplete popup
|
||||
// View implements tea.Model. Renders the textarea, autocomplete popup
|
||||
// (if visible), and help text.
|
||||
func (s *InputComponent) View() tea.View {
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := style.GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
@@ -490,12 +533,12 @@ func (s *InputComponent) View() tea.View {
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
MarginTop(1).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(s.width - 1) // full width minus left border
|
||||
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render(s.title))
|
||||
view.WriteString("\n")
|
||||
view.WriteString(inputBoxStyle.Render(s.textarea.View()))
|
||||
|
||||
// Popup is now rendered as a centered overlay in AppModel.View()
|
||||
@@ -658,9 +701,25 @@ func (s *InputComponent) renderPopupWithOptions(centered bool) string {
|
||||
}
|
||||
content = indicator + displayName
|
||||
} else {
|
||||
nameWidth := 15
|
||||
if innerWidth < 25 {
|
||||
nameWidth = max(innerWidth*2/5+1, 8)
|
||||
// Compute nameWidth from the longest command name in the
|
||||
// visible slice so we never truncate unnecessarily.
|
||||
nameWidth := 0
|
||||
for _, fm := range s.filtered {
|
||||
if n := len([]rune(fm.Command.Name)); n > nameWidth {
|
||||
nameWidth = n
|
||||
}
|
||||
}
|
||||
nameWidth += 3 // account for indicator prefix (2) + gap before description (1)
|
||||
// Ensure descriptions still get at least 20 chars when possible.
|
||||
maxForName := innerWidth - 20
|
||||
if maxForName < 8 {
|
||||
maxForName = innerWidth * 2 / 3
|
||||
}
|
||||
if nameWidth > maxForName {
|
||||
nameWidth = maxForName
|
||||
}
|
||||
if nameWidth < 8 {
|
||||
nameWidth = 8
|
||||
}
|
||||
maxNameChars := nameWidth - 2
|
||||
displayName := sc.Name
|
||||
@@ -793,9 +852,25 @@ func (s *InputComponent) PendingImageCount() int {
|
||||
return len(s.pendingImages)
|
||||
}
|
||||
|
||||
// Clear clears the textarea content and resets related state. Returns true if
|
||||
// there was content to clear, false if the input was already empty.
|
||||
func (s *InputComponent) Clear() bool {
|
||||
hadContent := s.textarea.Value() != ""
|
||||
s.textarea.SetValue("")
|
||||
s.textarea.CursorEnd()
|
||||
s.lastValue = ""
|
||||
s.showPopup = false
|
||||
s.argMode = false
|
||||
s.fileMode = false
|
||||
s.browsingHistory = false
|
||||
s.savedInput = ""
|
||||
return hadContent
|
||||
}
|
||||
|
||||
// applyFileCompletion replaces the @prefix in the textarea with the selected
|
||||
// file suggestion. For directories, it keeps the popup open for further
|
||||
// drilling. For files, it closes the popup and adds a trailing space.
|
||||
// file or MCP resource suggestion. For directories, it keeps the popup open
|
||||
// for further drilling. For files and resources, it closes the popup and adds
|
||||
// a trailing space.
|
||||
func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
if idx >= len(s.fileSuggestions) {
|
||||
return
|
||||
@@ -812,19 +887,30 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
|
||||
// Reconstruct: everything before the @ on the last line + @<path>
|
||||
beforeAt := lastLine[:s.fileAtStartIdx]
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
|
||||
var replacement string
|
||||
if needsQuote {
|
||||
replacement = `@"` + suggestion.RelPath + `"`
|
||||
} else {
|
||||
replacement = "@" + suggestion.RelPath
|
||||
}
|
||||
|
||||
// For files, add a trailing space. For directories, don't — allow
|
||||
// continued drilling into the directory.
|
||||
if !suggestion.IsDir {
|
||||
if suggestion.IsMCPResource {
|
||||
// MCP resources use @mcp:server:uri format.
|
||||
// Quote if the URI contains spaces.
|
||||
ref := "mcp:" + suggestion.MCPServerName + ":" + suggestion.MCPResourceURI
|
||||
if strings.Contains(ref, " ") {
|
||||
replacement = `@"` + ref + `"`
|
||||
} else {
|
||||
replacement = "@" + ref
|
||||
}
|
||||
replacement += " "
|
||||
} else {
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
if needsQuote {
|
||||
replacement = `@"` + suggestion.RelPath + `"`
|
||||
} else {
|
||||
replacement = "@" + suggestion.RelPath
|
||||
}
|
||||
// For files, add a trailing space. For directories, don't — allow
|
||||
// continued drilling into the directory.
|
||||
if !suggestion.IsDir {
|
||||
replacement += " "
|
||||
}
|
||||
}
|
||||
|
||||
newLastLine := beforeAt + replacement
|
||||
@@ -836,7 +922,7 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
s.textarea.SetValue(newValue)
|
||||
s.textarea.CursorEnd()
|
||||
|
||||
if suggestion.IsDir {
|
||||
if suggestion.IsDir && !suggestion.IsMCPResource {
|
||||
// Keep popup open — trigger a refresh for the new directory.
|
||||
s.lastValue = "" // force re-evaluation on next update tick
|
||||
} else {
|
||||
|
||||
@@ -25,17 +25,6 @@ type TextMessageItem struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
@@ -109,8 +98,8 @@ func (m *TextMessageItem) renderContent(width int) string {
|
||||
// It accumulates content chunks and re-renders on each update for live display.
|
||||
type StreamingMessageItem struct {
|
||||
id string
|
||||
role string // "assistant" or "reasoning"
|
||||
content string // Accumulated streaming content
|
||||
role string // "assistant" or "reasoning"
|
||||
content strings.Builder // Accumulated streaming content
|
||||
timestamp time.Time
|
||||
startTime time.Time // When streaming started (for live duration counter)
|
||||
modelName string
|
||||
@@ -156,10 +145,10 @@ func (s *StreamingMessageItem) Render(width int) string {
|
||||
durationMs = time.Since(s.startTime).Milliseconds()
|
||||
}
|
||||
ty := createTypography(style.GetTheme())
|
||||
rendered = render.ReasoningBlock(s.content, durationMs, ty, style.GetTheme())
|
||||
rendered = render.ReasoningBlock(s.content.String(), durationMs, width, ty, style.GetTheme())
|
||||
} else {
|
||||
// Render as assistant message
|
||||
rendered = render.AssistantBlock(s.content, width, style.GetTheme())
|
||||
rendered = render.AssistantBlock(s.content.String(), width, style.GetTheme())
|
||||
}
|
||||
|
||||
// Cache and return (but reasoning is never cached due to live duration)
|
||||
@@ -187,7 +176,7 @@ func (s *StreamingMessageItem) Height() int {
|
||||
|
||||
// AppendChunk adds a content chunk and invalidates the render cache.
|
||||
func (s *StreamingMessageItem) AppendChunk(chunk string) {
|
||||
s.content += chunk
|
||||
s.content.WriteString(chunk)
|
||||
s.cachedWidth = 0 // Invalidate cache
|
||||
}
|
||||
|
||||
@@ -243,9 +232,7 @@ func (m *StreamingBashOutputItem) Render(width int) string {
|
||||
|
||||
// Header with command
|
||||
if m.command != "" {
|
||||
headerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
headerStyle := style.GetCachedStyles().BashHeader
|
||||
parts = append(parts, headerStyle.Render(fmt.Sprintf("▸ %s", m.command)))
|
||||
}
|
||||
|
||||
@@ -318,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
+39
-12
@@ -88,13 +88,9 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
"content": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
for key, val := range params {
|
||||
@@ -150,9 +146,26 @@ func (r *MessageRenderer) SetWidth(width int) {
|
||||
r.width = width
|
||||
}
|
||||
|
||||
// RenderUserMessage renders a user's input message using herald Tip alert
|
||||
// RenderUserMessage renders a user's input message with a colored left border.
|
||||
func (r *MessageRenderer) RenderUserMessage(content string, timestamp time.Time) UIMessage {
|
||||
rendered := render.UserBlock(content, r.width, r.ty, style.GetTheme())
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Highlight @file tokens with accent color.
|
||||
content = render.HighlightFileTokens(content, theme)
|
||||
|
||||
rendered := renderContentBlock(
|
||||
content,
|
||||
r.width,
|
||||
WithAlign(lipgloss.Left),
|
||||
WithBorderColor(theme.Success),
|
||||
WithPaddingTop(0),
|
||||
WithPaddingBottom(0),
|
||||
WithMarginBottom(1),
|
||||
)
|
||||
|
||||
return UIMessage{
|
||||
Type: UserMessage,
|
||||
@@ -178,7 +191,7 @@ func (r *MessageRenderer) RenderAssistantMessage(content string, timestamp time.
|
||||
// as live streaming: muted italic text with margin. This is used when resuming
|
||||
// sessions to display saved reasoning content.
|
||||
func (r *MessageRenderer) RenderReasoningBlock(content string, timestamp time.Time) UIMessage {
|
||||
rendered := render.ReasoningBlock(content, 0, r.ty, style.GetTheme())
|
||||
rendered := render.ReasoningBlock(content, 0, r.width, r.ty, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: AssistantMessage,
|
||||
@@ -200,6 +213,19 @@ func (r *MessageRenderer) RenderSystemMessage(content string, timestamp time.Tim
|
||||
}
|
||||
}
|
||||
|
||||
// RenderCustomMessage renders a message with a custom alert label (e.g. "Help").
|
||||
// Content is rendered as markdown.
|
||||
func (r *MessageRenderer) RenderCustomMessage(content, label string, timestamp time.Time) UIMessage {
|
||||
rendered := render.CustomBlock(content, label, r.width, style.GetTheme())
|
||||
|
||||
return UIMessage{
|
||||
Type: SystemMessage,
|
||||
Content: rendered,
|
||||
Height: lipgloss.Height(rendered),
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
// RenderDebugMessage renders diagnostic and debugging information
|
||||
func (r *MessageRenderer) RenderDebugMessage(message string, timestamp time.Time) UIMessage {
|
||||
header := r.ty.H6("🔍 Debug Output")
|
||||
@@ -308,7 +334,7 @@ func (r *MessageRenderer) RenderToolMessage(toolName, toolArgs, toolResult strin
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
headerLine += " " + style.GetCachedStyles().ToolMuted.Render(params)
|
||||
}
|
||||
|
||||
// Get body content
|
||||
@@ -399,7 +425,8 @@ func createTypography(theme style.Theme) *herald.Typography {
|
||||
herald.WithCodeLineNumbers(true),
|
||||
// Customize alert labels
|
||||
herald.WithAlertLabel(herald.AlertNote, "Info"),
|
||||
herald.WithAlertLabel(herald.AlertTip, "You"),
|
||||
herald.WithAlertLabel(herald.AlertTip, ""),
|
||||
herald.WithAlertIcon(herald.AlertTip, ""),
|
||||
herald.WithAlertLabel(herald.AlertWarning, "Working"),
|
||||
herald.WithAlertLabel(herald.AlertCaution, "Error"),
|
||||
)
|
||||
|
||||
+789
-62
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user