mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
55 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 58b8aa043c | |||
| 41e34f30f8 | |||
| 1e78153b50 | |||
| a613361969 | |||
| 67722b0c24 | |||
| ce1d7afe83 | |||
| 1a2f6da40f | |||
| 747f5be099 | |||
| d7c4565999 | |||
| bd24f3315c | |||
| 592f8dc84f | |||
| 66c4a1eb15 | |||
| 5104477631 | |||
| 394a4676a1 | |||
| 30f2bc243d | |||
| 922e246098 | |||
| 32b6376515 | |||
| cf194ff89a | |||
| 03006425fa | |||
| a322dfc59a | |||
| b1387d837e | |||
| f561f4cfd9 | |||
| 64caed57d4 | |||
| 975c30a773 | |||
| 35b9360d64 | |||
| 1b8373e133 | |||
| 1a5e4ce7c5 | |||
| 8823977612 | |||
| 24e2ea111c | |||
| 31ea80ec4f | |||
| 99f2680c2e | |||
| da7e05eb87 | |||
| a95714a22d | |||
| c4a2b0f1a3 | |||
| 2016570e2d | |||
| d557f4b870 | |||
| 65054fe3db | |||
| 97d2246375 | |||
| 1e12505741 | |||
| 6755597c9b | |||
| 45689cb30d | |||
| 78570d4188 | |||
| 7cf38b37ee | |||
| 4ef57eec4e | |||
| cbd828e190 | |||
| d304805106 | |||
| 6e36053856 | |||
| 92eaaf6a59 | |||
| e6084b7bd0 | |||
| 34d5abff9c | |||
| fc0ddd5f4f | |||
| 7aa6160c75 | |||
| e830bf87ca | |||
| 3881d1c28f | |||
| 53f6682bd0 |
@@ -1,268 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
diagnosticsTimeout = 20 * time.Second
|
||||
maxOutputBytes = 12_000
|
||||
)
|
||||
|
||||
type toolPathInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type lintResult struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
absPath, ok := resolveGoFilePath(e.Input, ctx.CWD)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
borderColor = "#f38ba8" // red
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: strings.Join(msgLines, "\n"),
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isEditOrWrite(toolName string) bool {
|
||||
return strings.EqualFold(toolName, "edit") || strings.EqualFold(toolName, "write")
|
||||
}
|
||||
|
||||
func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
var args toolPathInput
|
||||
if err := json.Unmarshal([]byte(inputJSON), &args); err != nil || args.Path == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
absPath := args.Path
|
||||
if !filepath.IsAbs(absPath) {
|
||||
absPath = filepath.Join(cwd, absPath)
|
||||
}
|
||||
|
||||
if strings.ToLower(filepath.Ext(absPath)) != ".go" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "gopls", "check", absPath)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes), Err: fmt.Errorf("failed to run gopls check: %w", err)}
|
||||
}
|
||||
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes)}
|
||||
}
|
||||
|
||||
func runGolangCILint(cwd, target string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []string{
|
||||
"run",
|
||||
target,
|
||||
"--show-stats=false",
|
||||
"--output.text.path", "stdout",
|
||||
"--output.text.colors=false",
|
||||
"--output.text.print-issued-lines=false",
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, "golangci-lint", args...)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
trimmed := truncate(string(out), maxOutputBytes)
|
||||
if err == nil {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
exitErr, ok := err.(*exec.ExitError)
|
||||
if ok && exitErr.ExitCode() == 1 {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
return lintResult{Output: trimmed, Err: fmt.Errorf("failed to run golangci-lint: %w", err)}
|
||||
}
|
||||
|
||||
func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
var lines []string
|
||||
if res.Err != nil {
|
||||
lines = append(lines, "ERROR: "+res.Err.Error())
|
||||
}
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return emptyFallback
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "\n... output truncated ..."
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
lintCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return goplsCount, lintCount
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
---
|
||||
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
|
||||
---
|
||||
|
||||
Perform a comprehensive **read-only** audit of this repository and report
|
||||
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
|
||||
hints from the user: $@
|
||||
|
||||
## Scope
|
||||
|
||||
If the user supplied focus hints above (a package path, a subsystem name, a
|
||||
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
|
||||
audit the whole repo, prioritising the highest-traffic packages first
|
||||
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Map the repo first**:
|
||||
- `ls` / `find` the top-level layout and list every Go package
|
||||
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
|
||||
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
|
||||
extension surface)
|
||||
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
|
||||
(e.g. "no dependency name leakage", "UI never imports extensions
|
||||
directly") — these define what counts as a violation
|
||||
|
||||
2. **Hunt for dead code**:
|
||||
- Run `go vet ./...` and capture warnings
|
||||
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
|
||||
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
|
||||
with zero non-test references inside the module are suspects
|
||||
- Check for unreferenced files, `// TODO: remove` markers, commented-out
|
||||
blocks, and `_ = x` discard patterns
|
||||
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
|
||||
them and include their output verbatim
|
||||
- **Do not delete anything** — list candidates with file:line and a
|
||||
confidence level (high / medium / low)
|
||||
|
||||
3. **Find unnecessary duplication**:
|
||||
- Look for near-identical function bodies, struct shapes, or switch
|
||||
statements across packages — `grep` for repeated function signatures
|
||||
and copy-pasted string literals / error messages is a fast first pass
|
||||
- Distinguish *coincidental* duplication (two things that happen to look
|
||||
alike but evolve independently) from *unnecessary* duplication (same
|
||||
intent, drifting in lockstep) — only flag the latter
|
||||
- For each cluster, propose where the extracted helper should live
|
||||
(which package, which file) and whether it crosses a boundary
|
||||
|
||||
4. **Check concerns / boundary violations**:
|
||||
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
|
||||
in exported signatures, and for dependency-name leakage in exported
|
||||
names / godoc (e.g. library jargon appearing in `LLM*` types)
|
||||
- **UI ↔ extensions**: grep `internal/ui/` for any import of
|
||||
`internal/extensions/` — per AGENTS.md the UI must not import
|
||||
extensions directly; converters in `cmd/root.go` should bridge them
|
||||
- **cmd vs internal**: business logic living in `cmd/` that should be
|
||||
in `internal/` (and vice versa)
|
||||
- **Cyclic risk**: packages that import each other transitively or that
|
||||
reach across sibling boundaries unexpectedly
|
||||
- For each violation, cite the offending import / signature with
|
||||
file:line
|
||||
|
||||
5. **Spot refactor opportunities**:
|
||||
- Long functions (>80 lines) doing multiple unrelated things
|
||||
- Deeply nested conditionals that flatten well with early returns
|
||||
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
|
||||
that could become helpers — but only where the wrapping context is
|
||||
genuinely uniform
|
||||
- Structs with too many fields that hint at split responsibilities
|
||||
- Exported APIs that would be cleaner with options structs / functional
|
||||
options
|
||||
- Tests that share setup boilerplate ripe for a helper
|
||||
- Flag each with: location, current shape (1-2 lines), proposed shape
|
||||
(1-2 lines), and estimated risk (low / medium / high)
|
||||
|
||||
6. **Cross-check against project rules**:
|
||||
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
|
||||
findings contradicts the documented gotchas (Yaegi interface ban,
|
||||
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
|
||||
"refactor" would reintroduce a known pitfall, drop it from the report
|
||||
and note why
|
||||
|
||||
7. **Write the report** as your final message (do not write it to disk)
|
||||
structured as:
|
||||
|
||||
```
|
||||
# Code Audit Report
|
||||
|
||||
## Summary
|
||||
- N dead-code candidates
|
||||
- N duplication clusters
|
||||
- N boundary violations
|
||||
- N refactor opportunities
|
||||
|
||||
## Dead Code
|
||||
### High confidence
|
||||
- path/to/file.go:LINE — symbol — reason
|
||||
|
||||
### Medium confidence
|
||||
...
|
||||
|
||||
## Duplication
|
||||
### Cluster: <short name>
|
||||
- Sites: file:line, file:line, …
|
||||
- Suggested home: package/path
|
||||
- Notes: …
|
||||
|
||||
## Boundary Violations
|
||||
- Rule: <which rule from AGENTS.md / project convention>
|
||||
- Offender: file:line
|
||||
- Fix sketch: …
|
||||
|
||||
## Refactor Opportunities
|
||||
- Location: file:line
|
||||
- Current: …
|
||||
- Proposed: …
|
||||
- Risk: low/medium/high
|
||||
- Why it's worth it: …
|
||||
|
||||
## Suggested Next Steps
|
||||
1. …
|
||||
2. …
|
||||
```
|
||||
|
||||
8. **End the report with an explicit reminder** that no files were modified,
|
||||
and recommend the user pick the highest-leverage items to act on
|
||||
manually (or via a follow-up `/fix-issue` style prompt) rather than
|
||||
running a sweeping refactor.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
|
||||
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
|
||||
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
|
||||
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
|
||||
straight to it
|
||||
- **Be honest about confidence**: false positives in a code audit are
|
||||
expensive — prefer "medium confidence, worth a look" over confidently
|
||||
wrong claims
|
||||
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
|
||||
anything that's purely stylistic unless it directly causes one of the
|
||||
four issue categories above
|
||||
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
|
||||
`vendor/`) and obvious third-party copies
|
||||
- **Don't propose architectural rewrites** — stay within the existing
|
||||
shape of the repo and recommend incremental, reviewable changes
|
||||
@@ -0,0 +1,47 @@
|
||||
---
|
||||
description: Open a GitHub PR for the current branch using the repo's PR template
|
||||
---
|
||||
|
||||
Open a GitHub pull request for the current branch, filling out the repository's PR template with a description grounded in the actual commits and diff.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Verify the branch is pushed**:
|
||||
- `git status -sb` and `git log @{u}..HEAD --oneline 2>/dev/null` — if there is no upstream or unpushed commits, run `git push -u origin "$(git branch --show-current)"` first
|
||||
- If the working tree is dirty, stop and tell the user to commit first (suggest `/commit-push`)
|
||||
2. **Gather context**:
|
||||
- `git log origin/main..HEAD --oneline` — list of commits going into the PR
|
||||
- `git diff origin/main...HEAD --stat` then `git diff origin/main...HEAD` — read the actual changes
|
||||
- Identify the linked issue (from commit messages, branch name, or extra user input: $@) — capture as `Fixes #N` if applicable
|
||||
3. **Locate the PR template**:
|
||||
- Check `.github/pull_request_template.md`, `.github/PULL_REQUEST_TEMPLATE.md`, or `docs/pull_request_template.md`
|
||||
- If none exists, use a minimal `## Description` / `## Type of Change` / `## Checklist` structure
|
||||
4. **Draft the PR body** by filling out the template:
|
||||
- **Description**: 1–3 short paragraphs explaining *what* changed and *why*, grounded in the diff. Include a brief before/after example for new APIs when useful.
|
||||
- **Fixes #N**: only if there is a real linked issue
|
||||
- **Type of Change**: tick the single most accurate box with `[x]` (leave others as `[ ]`)
|
||||
- **Checklist**: tick items that are genuinely true (style, self-review, tests added, docs updated)
|
||||
- **Additional Information**: bullet list of added / modified files and any backward-compatibility notes
|
||||
- Remove template sections explicitly marked "remove if not applicable" (e.g. MCP Spec Compliance) when they don't apply
|
||||
5. **Write the body to a temp file**: `/tmp/pr-body-<branch-or-issue>.md` — never inline a long body via `--body`, always use `--body-file`
|
||||
6. **Choose the title**: prefer the subject of the primary commit if it already follows Conventional Commits; otherwise craft one in the same style (`<type>(<scope>): <imperative summary>`, ≤72 chars)
|
||||
7. **Create the PR**:
|
||||
```
|
||||
gh pr create \
|
||||
--title "<title>" \
|
||||
--body-file /tmp/pr-body-<...>.md \
|
||||
--base main \
|
||||
--head "$(git branch --show-current)"
|
||||
```
|
||||
Use the repo's actual default branch if it isn't `main` (`gh repo view --json defaultBranchRef -q .defaultBranchRef.name`)
|
||||
8. **Report the PR URL** returned by `gh` and stop
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff and commit messages — do **not** invent features that aren't in the code
|
||||
- One PR per logical change; if the branch contains unrelated commits, surface that and ask before continuing
|
||||
- Keep the description focused on reviewer-relevant information (what / why), not a replay of the diff
|
||||
- Only check checklist boxes that are actually satisfied; leave the rest unchecked rather than lying
|
||||
- If `gh` is not authenticated (`gh auth status` fails), stop and tell the user
|
||||
|
||||
$@
|
||||
@@ -2,7 +2,7 @@
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
Create a feature request for the Kit repository. The user wants to request: $@
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
@@ -16,7 +16,7 @@ This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$+`
|
||||
1. **Understand the request** from the user input: $@
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
@@ -16,7 +16,7 @@ This repository has structured issue templates. You MUST use the appropriate tem
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$+`:
|
||||
1. **Determine the issue type** from the user input: $@
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
---
|
||||
description: Implement the fix/feature/docs change requested by a GitHub issue
|
||||
---
|
||||
|
||||
Resolve GitHub issue #$1 by reading it, classifying it, and producing the appropriate code or doc change. **Stop once the working tree contains the change** — committing, pushing, and opening a PR are handled by `/commit-push` and `/create-pr`.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch the issue**:
|
||||
- Run: gh issue view $1 --json number,title,body,labels,state,author,comments
|
||||
- If the issue is closed, stop and ask the user whether to proceed
|
||||
- Read the **entire** thread including comments — the latest comment often refines the ask
|
||||
|
||||
2. **Classify the issue** from labels, title prefix, and body content:
|
||||
- `bug` / `fix:` → reproduce, then fix
|
||||
- `enhancement` / `feature` / `feat:` → design, then implement
|
||||
- `documentation` / `docs:` → locate and update docs
|
||||
- `question` / `discussion` → answer in a comment, do **not** write code
|
||||
- Anything else → ask the user how to proceed
|
||||
|
||||
3. **Create a working branch** off the default branch:
|
||||
- `git checkout main && git pull --ff-only`
|
||||
- Branch name: <type>/$1-<slug> (e.g. `fix/42-borderColor-ignored`, `feat/57-keyboard-clear`, `docs/63-widget-lifecycle`)
|
||||
|
||||
4. **Do the work** based on type:
|
||||
|
||||
### Bug (`bug` label / `fix:` title)
|
||||
- Reproduce the failure first (write a failing test if feasible) — if you cannot reproduce, comment on the issue asking for clarification and stop
|
||||
- Locate the root cause; do not patch symptoms
|
||||
- Add or extend a regression test that fails before and passes after the fix
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Feature (`enhancement` / `feature` label / `feat:` title)
|
||||
- Re-read the motivation and proposed implementation in the issue body
|
||||
- For large, ambiguous, or breaking changes, sketch the design in a comment on the issue and wait for sign-off before writing code
|
||||
- Implement behind sensible defaults; add godoc on every exported symbol
|
||||
- Add unit tests covering the new behaviour and edge cases
|
||||
- Update `README.md` / `docs/` if the public surface changed
|
||||
- Run `go test ./... -race` and `golangci-lint run`
|
||||
|
||||
### Documentation (`documentation` label / `docs:` title)
|
||||
- Open the file/URL referenced in the issue's "Documentation Location"
|
||||
- Apply the suggested improvement; verify code samples compile (`go build ./...`)
|
||||
- No tests required, but run `golangci-lint run` if Go files were touched
|
||||
|
||||
5. **Report**:
|
||||
- Branch name (`git branch --show-current`)
|
||||
- Summary of files changed (`git status -s`) and the diff highlights
|
||||
- Test/lint results (pass/fail with key output)
|
||||
- Suggest the next step explicitly:
|
||||
- `/commit-push` to commit with a Conventional Commit subject (the message should reference `(#$1)` and include `Fixes #$1` so merge auto-closes)
|
||||
- then `/create-pr $1` to open the pull request
|
||||
|
||||
## Guidelines
|
||||
|
||||
- This prompt **stops at a clean working tree with the change applied** — do not run `git commit`, `git push`, or `gh pr create`
|
||||
- If the issue is unclear, post a clarifying comment on the issue and stop; do not guess
|
||||
- Keep the change scoped to the issue; surface unrelated cleanups separately
|
||||
- For breaking changes or architecture shifts, propose the design on the issue first and wait for maintainer sign-off
|
||||
- If the issue is a duplicate or already fixed on `main`, comment with the reference and stop
|
||||
- Do not close the issue manually — the eventual PR's `Fixes #$1` handles that on merge
|
||||
+48
-13
@@ -2,7 +2,7 @@
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
@@ -16,30 +16,64 @@ It becomes a `/slug` slash command in the kit input box — typed as `/filename`
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
Body text of the prompt. Reference user-supplied arguments
|
||||
with positional placeholders (see "Argument placeholders" below).
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
- **Required args**: kit infers required positional args from the highest `$N` it finds *outside* backtick/tilde code fences — a stray `$2` in active prose means kit will refuse to run without 2 arguments
|
||||
|
||||
## Argument placeholders
|
||||
|
||||
kit performs shell-style substitution before sending the prompt to the model:
|
||||
|
||||
- `$1`, `$2`, … — positional arguments (1-indexed)
|
||||
- `${1}`, `${2}`, … — same, brace form (use when followed by digits/letters: `${1}_suffix`)
|
||||
- `$@` — all arguments joined by spaces (zero or more, optional)
|
||||
- `$+` — all arguments, **at least one required**
|
||||
- `$ARGUMENTS` / `${ARGUMENTS}` — alias for `$@`
|
||||
- `${@:N}` — args from the Nth onwards (1-indexed, bash-style)
|
||||
- `${@:N:L}` — `L` args starting from the Nth
|
||||
|
||||
### ⚠️ Critical: code fences and inline code preserve placeholders verbatim
|
||||
|
||||
Anything inside triple-backtick fences, `~~~` fences, or single-backtick `inline` code spans is **left untouched** so example code samples don't get corrupted. That means:
|
||||
|
||||
- An inline-coded `gh issue view $1` stays literal `$1` in the model's input ❌
|
||||
- The same command without backticks: gh issue view $1 → expands to `gh issue view 42` ✓
|
||||
|
||||
**Rule of thumb:** if you want a placeholder to substitute, keep it outside backticks and fences. If you want a literal `$1` in the output (e.g. teaching the user shell syntax), put it inside backticks.
|
||||
|
||||
### Workarounds for "I want it to look like code AND substitute"
|
||||
|
||||
1. **Drop the backticks** around just the placeholder portion — the rest can still read as a command line in prose
|
||||
2. **Use a 4-space-indented code block** instead of a triple-backtick fence — kit only skips backtick/tilde fences, so indentation-style code blocks still get substitution:
|
||||
|
||||
git push -u origin "$(git branch --show-current)"
|
||||
gh pr create --title "fix: ... (#$1)" --base main
|
||||
|
||||
3. **Bind once, reference loosely**: put `Issue: $1` at the top in prose, then leave the backticked examples literal — the model will substitute mentally
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
1. **Understand the workflow** the user described in $@ — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
4. **Decide on arguments**:
|
||||
- No args needed → omit placeholders entirely
|
||||
- One required value (issue number, PR url, file path) → use `$1`
|
||||
- Free-form trailing context → end with a single `$@` line
|
||||
- Multiple distinct values → use `$1`, `$2`, … and document each at the top
|
||||
5. **Draft the body**:
|
||||
- Open with a single sentence stating the goal, weaving in `$1`/`$@` where the value belongs
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
- **Audit every backtick and code fence**: any `$N` or `$@` inside them will not expand — was that intentional? If not, apply one of the workarounds above
|
||||
6. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
7. **Verify substitution** by mentally (or actually) replacing `$1`/`$@` with a sample value and confirming every reference resolves — and that the prompt's *own* example snippets don't accidentally bump the required-arg count (wrap illustrative `$N` examples in triple-backtick fences, not 4-space indentation, so `RequiredArgs()` ignores them)
|
||||
8. **Confirm** by showing the final file content and the slash command that activates it (e.g. `/code-review 42`)
|
||||
|
||||
## Guidelines
|
||||
|
||||
@@ -47,3 +81,4 @@ $1 $2 etc. for positional arguments.
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
- When in doubt about substitution behaviour, write the file and run `/<slug> testvalue` once to confirm — wrong placement of backticks is the #1 failure mode
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
---
|
||||
description: Audit and update project documentation (README and docs site) for a recent change
|
||||
---
|
||||
|
||||
Review recent code changes, identify all documentation surfaces that should
|
||||
mention them, and update each one — grounded in the actual diff, not guesses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Identify the change**:
|
||||
- If the user input ($@) names a commit / PR / branch / topic, use that as the focus
|
||||
- Otherwise inspect `git log origin/main..HEAD --oneline` and `git diff origin/main...HEAD --stat` to discover what shipped on the current branch
|
||||
- Read the actual diff (`git diff origin/main...HEAD`) — never document features that aren't in the code
|
||||
|
||||
2. **Inventory the doc surfaces**:
|
||||
- `README.md` at the repo root
|
||||
- Any docs site (commonly `www/`, `docs/`, `site/`) — list its pages and identify the one(s) most thematically related to the change
|
||||
- Inline godoc / API reference comments on the new exported symbols
|
||||
- `CHANGELOG.md` if the project keeps one
|
||||
- Any `examples/` directory entries that demonstrate the affected area
|
||||
|
||||
3. **Audit each surface** with `grep`:
|
||||
- Search for the names of related existing APIs (e.g. if you added `IterTools`, grep for `ListTools`) to find every page that already discusses the area
|
||||
- Decide for each hit: does it need a cross-reference, a side-by-side comparison, or to stay untouched?
|
||||
|
||||
4. **Decide where new content lives**:
|
||||
- Prefer extending an existing page over creating a new one
|
||||
- For a docs site, place new sections near related content (check the page's `## Heading` outline first)
|
||||
- Skip surfaces that genuinely don't apply (e.g. a server-focused README for a client-only change) and say so explicitly
|
||||
|
||||
5. **Draft the updates**:
|
||||
- Lead with a one-sentence statement of what's new and why
|
||||
- Show concrete code examples copied from real signatures — verify against the source files
|
||||
- Include a comparison / "when to use which" table when adding an alternative to an existing API
|
||||
- Note backwards-compatibility behaviour if relevant
|
||||
|
||||
6. **Verify the docs build** before committing:
|
||||
- For vocs / docusaurus / mkdocs sites, run the local build command (e.g. `npx vocs build`, `mkdocs build`) and fix any MDX/markdown errors
|
||||
- For godoc, run `go vet ./...` and `go doc <pkg> <Symbol>` to sanity-check rendering
|
||||
|
||||
7. **Report**:
|
||||
- List every file changed and every file deliberately left alone (with a one-line reason)
|
||||
- Suggest the next step (typically `/commit-push`) — do not auto-commit unless asked
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the diff before writing anything — invented API names erode trust faster than missing docs
|
||||
- One change per doc commit; keep doc updates separate from code changes when possible
|
||||
- Match the existing voice and formatting of each surface (headings, code-fence languages, table styles)
|
||||
- Prefer linking between pages over duplicating content
|
||||
|
||||
$@
|
||||
@@ -162,6 +162,11 @@ mcpServers:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers that don't require auth
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
url: "https://builds.mcp.example.com"
|
||||
tasksMode: always # async task execution — see MCP Tasks below
|
||||
```
|
||||
|
||||
## CLI Reference
|
||||
@@ -626,6 +631,36 @@ in a custom `MCPTokenStoreFactory` for encrypted, DB-backed, or in-memory
|
||||
storage. See the [SDK options docs](/sdk/options#mcp-oauth-authorization) for
|
||||
the full matrix.
|
||||
|
||||
### MCP Tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`, so cooperating MCP servers can respond to `tools/call`
|
||||
with a `taskId` instead of blocking the connection. Kit then polls
|
||||
`tasks/get` / `tasks/result` until the task reaches a terminal state, and
|
||||
best-effort `tasks/cancel`s on context cancellation.
|
||||
|
||||
Defaults are safe — a server that doesn't advertise task capability runs
|
||||
synchronously, exactly as before. Opt in per server via `tasksMode` in
|
||||
`.kit.yml` (`auto` | `never` | `always`) or programmatically through the SDK:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskTimeout: 15 * time.Minute,
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s: %s", p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
```
|
||||
|
||||
See the [configuration docs](/configuration#mcp-tasks-long-running-tools) and
|
||||
[SDK options → MCP Tasks](/sdk/options#mcp-tasks) for the full surface.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
@@ -721,6 +756,45 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Runtime Skills & Context Files
|
||||
|
||||
For multi-tenant hosts (chatbots, per-user agents, web services), the SDK
|
||||
lets you swap skills and `AGENTS.md`-style context files **after** Kit
|
||||
construction. Every mutation recomposes the system prompt and applies it to
|
||||
the agent so the next turn picks up the new instructions — no restart needed.
|
||||
|
||||
```go
|
||||
// Programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Per-user AGENTS.md content pulled from a database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state on logout.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set atomically.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Skills dedupe by `Name`, context files dedupe by `Path` (which can be any
|
||||
opaque identifier — it doesn't have to be a real filesystem path). All
|
||||
mutators and readers (`GetSkills`, `GetContextFiles`) are safe to call
|
||||
concurrently from multiple goroutines. See the [SDK overview docs](/sdk/overview#runtime-skills-and-context-files)
|
||||
for the full reference.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Subagent Pattern
|
||||
|
||||
@@ -0,0 +1,473 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// extensionContextDeps groups the runtime dependencies needed to wire up
|
||||
// an extensions.Context for the interactive TUI mode.
|
||||
type extensionContextDeps struct {
|
||||
ctx context.Context
|
||||
cwd string
|
||||
modelName string
|
||||
interactive bool
|
||||
kitInstance *kit.Kit
|
||||
appInstance *app.App
|
||||
usageTracker *ui.UsageTracker
|
||||
}
|
||||
|
||||
// buildInteractiveExtensionContext returns an extensions.Context with every
|
||||
// field except Print / PrintInfo / PrintError populated. Callers must set
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
+100
-805
File diff suppressed because it is too large
Load Diff
@@ -13,8 +13,6 @@
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
// - The extension runner serializes handler calls per-extension, so
|
||||
// concurrent subagent events cannot race on this shared state.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -45,8 +43,7 @@ const (
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state — safe because the runner serializes all handler
|
||||
// invocations for the same extension (per-extension reentrant mutex).
|
||||
// Package-level state - all simple types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
@@ -285,8 +282,8 @@ func Init(api ext.API) {
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry — build a new slice to avoid aliasing bugs
|
||||
newEntries := make([]*submonEntry, 0, len(submonEntries))
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -26,7 +26,7 @@ func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -84,7 +84,7 @@ func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -134,7 +134,7 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -186,7 +186,7 @@ func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.6
|
||||
charm.land/fantasy v0.19.0
|
||||
charm.land/fantasy v0.25.0
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/alecthomas/chroma/v2 v2.26.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.12.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/coder/acp-go-sdk v0.13.0
|
||||
github.com/fsnotify/fsnotify v1.10.1
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.49.0
|
||||
github.com/mark3labs/mcp-go v0.54.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/term v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -37,21 +37,21 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0 // indirect
|
||||
github.com/aws/smithy-go v1.25.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 // indirect
|
||||
github.com/aws/smithy-go v1.26.0 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
@@ -59,44 +59,46 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/dlclark/regexp2 v1.12.0 // indirect
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.2 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.18 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.8 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.5.2 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.5 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.13 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/gjson v1.19.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
@@ -104,21 +106,21 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 // indirect
|
||||
go.opentelemetry.io/otel v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.44.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/crypto v0.52.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.276.0 // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/api v0.282.0 // indirect
|
||||
google.golang.org/genai v1.58.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
|
||||
google.golang.org/grpc v1.81.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -129,13 +131,13 @@ require (
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0
|
||||
)
|
||||
|
||||
@@ -2,8 +2,8 @@ charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
|
||||
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
|
||||
charm.land/fantasy v0.19.0 h1:fnNXkIJ/xcIW3sdVtWxjtQGpWWe8pDGhBCWSHkgbrd0=
|
||||
charm.land/fantasy v0.19.0/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/fantasy v0.25.0 h1:oXOWY1ivmTSnhYGzAolscF8zKtavWZyBWv0LHRSwN5Q=
|
||||
charm.land/fantasy v0.25.0/go.mod h1:8QrWUzIcKwZQP+aAnC9vLu3iID6hu9/Jt+rPMiieBkc=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
@@ -28,42 +28,42 @@ github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1 h1:nv2AVZdTyClGbVQkIzlDm/rnhk1E9bU9nXwmZ/Vk/iY=
|
||||
github.com/alecthomas/chroma/v2 v2.23.1/go.mod h1:NqVhfBR0lte5Ouh3DcthuUCTUpDC9cxBOfyMbMQPs3o=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1 h1:2X21EdxGZNv5GF9mG5u+uzc02GCFyGxbcBm3Grd9A78=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1/go.mod h1:lxhRRa9H4hPmRLOOdYga4zkQIQjq3dtrrdwQeCfu78Y=
|
||||
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6 h1:1AX0AthnBQzMx1vbmir3Y4WsnJgiydmnJjiLu+LvXOg=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.6/go.mod h1:dy0UzBIfwSeot4grGvY1AqFWN5zgziMmWGzysDnHFcQ=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9 h1:adBsCIIpLbLmYnkQU+nAChU5yhVTvu5PerROm+/Kq2A=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.9/go.mod h1:uOYhgfgThm/ZyAuJGNQ5YgNyOlYfqnGpTHXvk3cpykg=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16 h1:Q0iQ7quUgJP0F/SCRTieScnaMdXr9h/2+wze1u3cNeM=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.16/go.mod h1:duCCnJEFqpt2RC6no1iK6q+8HpwOAkiUua0pY507dQc=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15 h1:fyvgWTszojq8hEnMi8PPBTvZdTtEVmAVyo+NFLHBhH4=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.15/go.mod h1:gJiYyMOjNg8OEdRWOf3CrFQxM2a98qmrtjx1zuiQfB8=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22 h1:IOGsJ1xVWhsi+ZO7/NW8OuZZBtMJLZbk4P5HDjJO0jQ=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.22/go.mod h1:b+hYdbU+jGKfXE8kKM6g1+h+L/Go3vMvzlxBsiuGsxg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22 h1:GmLa5Kw1ESqtFpXsx5MmC84QWa/ZrLZvlJGa2y+4kcQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.22/go.mod h1:6sW9iWm9DK9YRpRGga/qzrzNLgKpT2cIxb7Vo2eNOp0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22 h1:dY4kWZiSaXIzxnKlj17nHnBcXXBfac6UlsAx2qL6XrU=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.22/go.mod h1:KIpEUx0JuRZLO7U6cbV204cWAEco2iC3l061IxlwLtI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23 h1:FPXsW9+gMuIeKmz7j6ENWcWtBGTe1kH8r9thNt5Uxx4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.23/go.mod h1:7J8iGMdRKk6lw2C+cMIphgAnT8uTwBwNOsGkyOCm80U=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8 h1:HtOTYcbVcGABLOVuPYaIihj6IlkqubBwFj10K5fxRek=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.8/go.mod h1:VsK9abqQeGlzPgUr+isNWzPlK2vKe9INMLWnY65f5Xs=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22 h1:PUmZeJU6Y1Lbvt9WFuJ0ugUK2xn6hIWUBBbKuOWF30s=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.22/go.mod h1:nO6egFBoAaoXze24a2C0NjQCvdpk8OueRoYimvEB9jo=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10 h1:a1Fq/KXn75wSzoJaPQTgZO0wHGqE9mjFnylnqEPTchA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.10/go.mod h1:p6+MXNxW7IA6dMgHfTAzljuwSKD0NCm/4lbS4t6+7vI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16 h1:x6bKbmDhsgSZwv6q19wY/u3rLk/3FGjJWyqKcIRufpE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.16/go.mod h1:CudnEVKRtLn0+3uMV0yEXZ+YZOKnAtUJ5DmDhilVnIw=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20 h1:oK/njaL8GtyEihkWMD4k3VgHCT64RQKkZwh0DG5j8ak=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.20/go.mod h1:JHs8/y1f3zY7U5WcuzoJ/yAYGYtNIVPKLIbp61euvmg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0 h1:ks8KBcZPh3PYISr5dAiXCM5/Thcuxk8l+PG4+A0exds=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.0/go.mod h1:pFw33T0WLvXU3rw1WBkpMlkgIn54eCB5FYLhjDc9Foo=
|
||||
github.com/aws/smithy-go v1.25.0 h1:Sz/XJ64rwuiKtB6j98nDIPyYrV1nVNJ4YU74gttcl5U=
|
||||
github.com/aws/smithy-go v1.25.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 h1:sRs7nG6/RiEBZ/K5UO2sNw0w40U02Nmz1VtARloTZXk=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 h1:qRhIJMbevHUvIE7X4TK8N8zye5+5AhapcslPrvB+qKE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19/go.mod h1:RbJ24nfoya63+Mf5VI+CGCGk9vEdv28xPeii+gojRYs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 h1:GcXQz2M/0ZvMo0v5DakUqbDBeBM1ZNaivkolEF4Esgw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18/go.mod h1:sHJ06tMGcD3ZpmMyJqV+VBsGilhSIZPIN+ZFy5Dg0C4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 h1:FQm5ApnyzkuJdXLGskPce83CK1CQKC4RUnIHKVe4BU4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24/go.mod h1:JsC7dqQc55MlZ5mvNsDMMge71u8pVcSzU3RNz2h/5yQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 h1:u6kJU2i0va1AgtJsH3RdWKWqHULlTh7zHwb35Womf74=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24/go.mod h1:7GY+xLcXOFUpCkNwDReft9qOAVg54A4/AnjHIU7sSAY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 h1:Xhbcf3KugX6vX7SDyUK205Oicyfg7EGuvoVNyP5L6DM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24/go.mod h1:rwDgb2HNOGZsnTHylOUedM7Vnl+bCfnXDqUNPsFWYfk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 h1:54CTMmlJ71Rk2dYvM9qZOob+39wjlVja2zDLxCu69Ew=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25/go.mod h1:BZaHqxsS9vN1fvV5EfEl0OBLOk5+AajWsMu6MjqnZB4=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 h1:CQW2FTrflfoslYWLf3fv7vG28Q219+v8YJS5QTQb2+Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24/go.mod h1:Xfx13T+u3nH6EEzgl9fBSO6nDRmze1FvnZNYkctQ2zw=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 h1:yQo3eZ5qFaL1sJWqs1nL6j3yPHA2/R7c6tQ4T+0IO10=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0/go.mod h1:3Zzou41Qt/ueXfIzHvTEjDNuR5IjCUBVF01SNhrt1e8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 h1:ApLTFdAZfDhZSiY5uskwECKHkSNNF83y2Ru2r7SezWA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18/go.mod h1:A9K9qx2l6nK89hp+a350FdGfRkrkH5HdiEjHbiy/Q/c=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 h1:4VD7TIZOGzehrgQ8vDE+1c6BQW4ErZPGY8ohZT5LXEE=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1/go.mod h1:er0SFJfdV89Rit5hIJu/EXtv+qC2XMnxoksLmcUFkqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 h1:XKnxlM4KZH1gktcsh3zSWc7GW4KivEv/OkifmHOhCUY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2/go.mod h1:KJYmkQaFB3SUW2j3aBkPsxNmAb4ZsSOvbvCpuxzHJA0=
|
||||
github.com/aws/smithy-go v1.26.0 h1:9ouqbi+NyKP7fV3Te7UElCwdAb6Y8uk7LGwPE5tVe/s=
|
||||
github.com/aws/smithy-go v1.26.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
@@ -86,8 +86,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7 h1:PbFxahSfyADcQOp+7WxbeqN3wX37KA/Rk+EXOW1xS9Q=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260420095748-421e4a7fa8d7/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654 h1:FpSYhY28ucg9ZRr+2wj67FAQ0Ey5yiK0072PmRDJNek=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260525132238-948f4557a654/go.mod h1:hFpumms29Smx3LStRfku8vcCTBe1Kq8aCXtHUJa3mjY=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -98,14 +98,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5 h1:3ElWZRQqSRqML2P/r2TmuSkdXPMDI+Jg3f0bGA6Ekg4=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260420102150-fe550f2efce5/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d h1:sMilwx1YIYTrQva6jsB522AoRYAerNaDIKP4ZPtUq0A=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260527151214-009e6338d40d/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5 h1:QqpW1CPNAnOpM3Nj0X7IT2IFlR90bLdAkO5+A3Hwbi4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260420102150-fe550f2efce5/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d h1:RxcAR+vJCoD8QqT1cqLtkQKw+1cqvjqnu5IpPqYzPco=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -124,8 +124,8 @@ github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJ
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.12.0 h1:GoIC6RrkPMBIVQ3ckSkl+bO/ERV/IRK6clBdZmx4Uf4=
|
||||
github.com/coder/acp-go-sdk v0.12.0/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/coder/acp-go-sdk v0.13.0 h1:IAKBDIbe/iBfKAGikeIndzb8fowt4ioD+gCtSU4HwMA=
|
||||
github.com/coder/acp-go-sdk v0.13.0/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
@@ -133,6 +133,8 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
|
||||
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 h1:LCUGyd9Wf+r+VVOl8Ny38JTpWJcAsdVnCIuhhtthmKw=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
@@ -146,10 +148,10 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 h1:vymEbVwYFP/L05h5TKQxvkXoKxNvTpjxYKdF1Nlwuao=
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 h1:NZBJxCpbHS1gzS6xAmyxbJznosZIIPk9IB42v62UvKA=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
|
||||
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -167,14 +169,14 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 h1:xolVQTEXusUcAA5UgtyRLjelpFFHWlPQ4XfWGc7MBas=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 h1:F/VPrx0YPBdksZJQdCAp0WUsqnNmZpUZszzfYt0M5Dw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16/go.mod h1:9Yb0eAkH/Xqhvv3zbeKf/+wMJqCeocWc6KIhDvEAuYE=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
@@ -187,14 +189,14 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.2 h1:52gGOx4ZwbLEiOyDMNA1ax2WktKlrKsmV6Ydf9Tw3/I=
|
||||
github.com/kaptinlin/go-i18n v0.4.2/go.mod h1:IACLIi+sHn3pGyryFMiqr2N1CJry4OKFD0MAEneEVQk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.18 h1:EDUXT4WKpOKguU7oaFv6VaNatN7uHFe6dEYHX0+OFxs=
|
||||
github.com/kaptinlin/jsonpointer v0.4.18/go.mod h1:ndmfvrqrEDSbV3F7yGaOuDvr29WrxYU1aqkvef9L2do=
|
||||
github.com/kaptinlin/jsonschema v0.7.8 h1:aHv28bYtfLfUXYI/10Phb1nvVyLXNz1lmu73vtKmlOY=
|
||||
github.com/kaptinlin/jsonschema v0.7.8/go.mod h1:cz7SK0jTHdabKdQp+SwBKKmOeZ55txuNo72Jx9Sbb2w=
|
||||
github.com/kaptinlin/messageformat-go v0.5.2 h1:E+D5oQVRepHgyMiLWRHnPXYFbqBDI4Sek7/CTIAByj4=
|
||||
github.com/kaptinlin/messageformat-go v0.5.2/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/kaptinlin/go-i18n v0.4.5 h1:9tIlo5A0RXth+yZJO2MG7Bhpu/X9PlzQnGz/qyYWNoY=
|
||||
github.com/kaptinlin/go-i18n v0.4.5/go.mod h1:mU/7BH4molY5lGZYBwBRKAaiJ70dWRHuqmQ0/pFLGno=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 h1:iJ197e8n+WwqaqBsa53FqG3rPJCg5oijyFXEXNWWC3E=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
|
||||
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 h1:D6jiXFsKW4/JG2CMddv/F6Rev9KVbCRKEzzV5QOAcpc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -203,10 +205,10 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.49.0 h1:7Ssx4d7/T86qnWoJIdye7wEEvUzv39UIbnZb/FqUZMY=
|
||||
github.com/mark3labs/mcp-go v0.49.0/go.mod h1:BflTAZAzXlrTpiO44gmjMu89n2FO56rJ9m31fp4zd5k=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mark3labs/mcp-go v0.54.1 h1:Ap/ptEB9FtWzFKM8NDsTA7QDxerQOC06eZigrTldVj0=
|
||||
github.com/mark3labs/mcp-go v0.54.1/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
@@ -223,8 +225,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
|
||||
github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo=
|
||||
@@ -238,6 +240,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -254,8 +258,8 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
@@ -274,54 +278,54 @@ github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 h1:2yEATaop1/a1I4psnSLgWVPLWwCzkqWakgJy7xTDVy0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0/go.mod h1:D7J12YRapIekYyPWgGPlA/23pRmpSEZC5xJC/TTLI9U=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 h1:8tvICD4vSTOOsNrsI4Ljf6C+6UKvpTEH5XY3JMoyPoo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0/go.mod h1:z9+yiacE0IHRqM4qFfkbt/JYlmYXgss8GY/jXoNuPJI=
|
||||
go.opentelemetry.io/otel v1.44.0 h1:JjwHmHpA4iZ3wBxluu2fbbE7j4kqlE8jXyAyPXH7HqU=
|
||||
go.opentelemetry.io/otel v1.44.0/go.mod h1:BMgjTHL9WPRlRjL2oZCBTL4whCGtXch2H4BhOPIAyYc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSYCdDdmc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0 h1:nHYwb9lK+fJPU/dnT6s7W7Z8itMWyqrnVfbheVYrZ58=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0/go.mod h1:Osuydd3Se74nqjAKxid74N5eC+jfEqfTegHRnq58oK0=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0 h1:3LlKgI+VjbVsjNRFZJZAJ30WjXC5VkNRks6si09iEfI=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0/go.mod h1:5B5pMARnXxKhltooO4xUuCBorl65a4EpnTalObqOigA=
|
||||
go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk=
|
||||
go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7 h1:cHpkPjp4TILjdZxz/O4ykwCpeS+dDqNuDGse4zgQDCk=
|
||||
golang.org/x/exp v0.0.0-20260528193900-50dc527dd6c7/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY=
|
||||
google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 h1:XF8+t6QQiS0o9ArVan/HW8Q7cycNPGsJf6GA2nXxYAg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/api v0.282.0 h1:WmJiSVqUnKqJCpJOx7YADbXaC+9DDsnGSfllFSj7R2I=
|
||||
google.golang.org/api v0.282.0/go.mod h1:6Wssta4c5n9qHq5CBhmlai5h/PUa1djdDAIhYEHyvcM=
|
||||
google.golang.org/genai v1.58.0 h1:MNA3ZkRyr7MnRwZ9RNZ60p4+UMKV3yYRw6pyHq4pp0U=
|
||||
google.golang.org/genai v1.58.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348 h1:JjVGDZYWkJWZcxveJGzfkXC5myDVWAd4dZdgbzrDUv8=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348/go.mod h1:95PqD4xM+AdOcBGsmgfaofXsiA37uXDtDufVbntT3TU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348 h1:U8orV30l6KpDsi9dxU0CoJZGbjS8EEpw+6ba+XwGPQA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348/go.mod h1:Yzdzr5OOZFgSsEV2D/Xi9NL3bszpXFAg0hFJiRohcD8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -185,6 +185,26 @@ func (a *Agent) ListSessions(_ context.Context, _ acp.ListSessionsRequest) (acp.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CloseSession cancels any ongoing work for the session and frees its resources.
|
||||
func (a *Agent) CloseSession(_ context.Context, params acp.CloseSessionRequest) (acp.CloseSessionResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
log.Debug("acp: close session", "session", sessionID)
|
||||
sess.cancelPrompt()
|
||||
a.registry.remove(sessionID)
|
||||
return acp.CloseSessionResponse{}, nil
|
||||
}
|
||||
|
||||
// ResumeSession is not supported — Kit doesn't persist sessions across
|
||||
// restarts in ACP mode. Clients should use NewSession instead.
|
||||
func (a *Agent) ResumeSession(_ context.Context, _ acp.ResumeSessionRequest) (acp.ResumeSessionResponse, error) {
|
||||
return acp.ResumeSessionResponse{}, fmt.Errorf("resume session not supported")
|
||||
}
|
||||
|
||||
// SetSessionConfigOption handles session configuration changes. Currently
|
||||
// supports the "model" config option to change the active model for a session.
|
||||
func (a *Agent) SetSessionConfigOption(ctx context.Context, params acp.SetSessionConfigOptionRequest) (acp.SetSessionConfigOptionResponse, error) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -152,38 +153,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
@@ -232,6 +202,20 @@ func (r *sessionRegistry) closeAll() {
|
||||
}
|
||||
}
|
||||
|
||||
// remove closes and removes a single session by ID.
|
||||
func (r *sessionRegistry) remove(sessionID string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
sess, ok := r.sessions[sessionID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if sess.kit != nil {
|
||||
_ = sess.kit.Close()
|
||||
}
|
||||
delete(r.sessions, sessionID)
|
||||
}
|
||||
|
||||
// cancelPrompt cancels the current prompt for a session, if any.
|
||||
func (s *acpSession) cancelPrompt() {
|
||||
s.cancelMu.Lock()
|
||||
@@ -255,40 +239,3 @@ func (s *acpSession) clearCancel() {
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
+57
-5
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -59,6 +60,11 @@ type AgentConfig struct {
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour for any
|
||||
// server that didn't advertise task support during initialize.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -231,11 +237,21 @@ type Agent struct {
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpTaskConfig is stored from AgentConfig so AddMCPServer() can
|
||||
// propagate it to a lazily-created MCPToolManager.
|
||||
mcpTaskConfig tools.MCPTaskConfig
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
|
||||
// promptMu serializes runtime updates to systemPrompt and the
|
||||
// accompanying fantasy agent rebuild so concurrent SetSystemPrompt
|
||||
// callers (e.g. Kit.applyComposedSystemPrompt invoked from multiple
|
||||
// goroutines) don't race on a.systemPrompt / a.fantasyAgent.
|
||||
promptMu sync.Mutex
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -329,6 +345,7 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
mcpTaskConfig: agentConfig.MCPTaskConfig,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
@@ -348,6 +365,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
// Apply task-augmented tool execution config (zero value = no-op).
|
||||
toolManager.SetTaskConfig(agentConfig.MCPTaskConfig)
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
@@ -573,8 +592,13 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
// Track tool call args per-ToolCallID so parallel tool calls in a single
|
||||
// step don't clobber each other. Without this, OnToolResult callbacks would
|
||||
// all see the args of the last OnToolCall in the step. The mutex guards
|
||||
// against the possibility that the underlying streaming layer dispatches
|
||||
// callbacks from multiple goroutines.
|
||||
toolCallArgs := make(map[string]string)
|
||||
var toolCallArgsMu sync.Mutex
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
@@ -761,7 +785,9 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolArgs = tc.Input
|
||||
toolCallArgsMu.Lock()
|
||||
toolCallArgs[tc.ToolCallID] = tc.Input
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify about the tool call
|
||||
if cb.OnToolCall != nil {
|
||||
@@ -781,15 +807,22 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Look up the args recorded for this specific tool call. Delete
|
||||
// the entry so the map doesn't accumulate across steps.
|
||||
toolCallArgsMu.Lock()
|
||||
args := toolCallArgs[tr.ToolCallID]
|
||||
delete(toolCallArgs, tr.ToolCallID)
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify tool execution finished
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, args, false)
|
||||
}
|
||||
|
||||
if cb.OnToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, args, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1134,6 +1167,7 @@ func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPSer
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetTaskConfig(a.mcpTaskConfig)
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
@@ -1290,6 +1324,24 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// SetSystemPrompt updates the agent's system prompt and rebuilds the underlying
|
||||
// fantasy agent so subsequent turns use the new prompt. Safe to call while the
|
||||
// agent is idle; if invoked during an in-flight turn the new prompt takes
|
||||
// effect on the next LLM call.
|
||||
func (a *Agent) SetSystemPrompt(prompt string) {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
a.systemPrompt = prompt
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the agent's current system prompt.
|
||||
func (a *Agent) GetSystemPrompt() string {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
return a.systemPrompt
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the effective max output tokens the agent currently
|
||||
// sends to the LLM provider, after per-model defaults, right-sizing, and any
|
||||
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// fakeParallelAgent simulates a provider that emits two parallel tool_use
|
||||
// blocks in a single step. It invokes the streaming callbacks in the order:
|
||||
//
|
||||
// OnToolCall(A) -> OnToolCall(B) -> OnToolResult(A) -> OnToolResult(B)
|
||||
//
|
||||
// Before the fix in #33 the agent-layer wrapper recorded a single
|
||||
// `currentToolArgs` variable that was clobbered by the second OnToolCall, so
|
||||
// both OnToolResult callbacks received B's args instead of their own.
|
||||
type fakeParallelAgent struct {
|
||||
calls []fantasy.ToolCallContent
|
||||
results []fantasy.ToolResultContent
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Stream(_ context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
|
||||
for _, tc := range f.calls {
|
||||
if opts.OnToolCall != nil {
|
||||
if err := opts.OnToolCall(tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tr := range f.results {
|
||||
if opts.OnToolResult != nil {
|
||||
if err := opts.OnToolResult(tr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
// TestGenerateWithCallbacks_ParallelToolArgs is the regression test for #33.
|
||||
// It drives the streaming-callback wiring inside GenerateWithCallbacks with a
|
||||
// fake fantasy.Agent that emits two parallel tool calls before either result.
|
||||
// Each OnToolResult must receive the args of its own tool call (matched by
|
||||
// ToolCallID), not the args of the last OnToolCall in the step.
|
||||
func TestGenerateWithCallbacks_ParallelToolArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
argsA := `{"name":"scheduled_jobs"}`
|
||||
argsB := `{"name":"gmail_trigger"}`
|
||||
|
||||
fake := &fakeParallelAgent{
|
||||
calls: []fantasy.ToolCallContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Input: argsA},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Input: argsB},
|
||||
},
|
||||
results: []fantasy.ToolResultContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-A"}},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-B"}},
|
||||
},
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fake,
|
||||
streamingEnabled: false, // exercise the "hasCallbacks" branch
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
resultArgs := map[string]string{}
|
||||
executionArgs := map[string]string{} // captured when running == false
|
||||
|
||||
cb := GenerateCallbacks{
|
||||
OnToolExecution: func(id, _, args string, running bool) {
|
||||
if running {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
executionArgs[id] = args
|
||||
},
|
||||
OnToolResult: func(id, _, args, _, _ string, _ bool) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
resultArgs[id] = args
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := a.GenerateWithCallbacks(context.Background(), nil, cb); err != nil {
|
||||
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
|
||||
}
|
||||
|
||||
if got, want := resultArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolResult for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := resultArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolResult for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -56,6 +56,8 @@ type AgentCreationOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -76,6 +78,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
@@ -9,12 +9,19 @@ import (
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpExecutor is the subset of *tools.MCPToolManager that the adapter
|
||||
// actually uses. Extracted as an interface so the adapter is unit-testable
|
||||
// without constructing a full manager + connection pool.
|
||||
type mcpExecutor interface {
|
||||
ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error)
|
||||
}
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
manager *tools.MCPToolManager
|
||||
exec mcpExecutor
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
@@ -29,10 +36,26 @@ func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
//
|
||||
// MCP-side failures (JSON-RPC protocol errors, transport failures, schema
|
||||
// validation rejections from the server) are surfaced to the model as soft
|
||||
// tool errors rather than escalated to a critical agent error. This matches
|
||||
// the contract that native Kit tools follow via kit.ErrorResult(...) and
|
||||
// lets the model self-correct (e.g. retry with a fixed argument shape) or
|
||||
// give up gracefully rather than aborting the turn mid-run.
|
||||
//
|
||||
// Context cancellation is the one exception: if the caller cancelled the
|
||||
// context the turn was aborted intentionally, so we propagate the ctx error
|
||||
// to let the agent loop unwind cleanly.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
result, err := t.exec.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return fantasy.ToolResponse{}, ctxErr
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("MCP tool %q failed: %s", t.tool.Name, err.Error()),
|
||||
), nil
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
@@ -57,8 +80,8 @@ func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManage
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
manager: manager,
|
||||
tool: t,
|
||||
exec: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// stubExecutor lets each test script the (result, err) pair returned by
|
||||
// ExecuteTool. The adapter holds an mcpExecutor interface, so this is the
|
||||
// only seam the tests need.
|
||||
type stubExecutor struct {
|
||||
result *tools.MCPToolResult
|
||||
err error
|
||||
// called records the last invocation for assertion.
|
||||
called bool
|
||||
name string
|
||||
input string
|
||||
}
|
||||
|
||||
func (s *stubExecutor) ExecuteTool(_ context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error) {
|
||||
s.called = true
|
||||
s.name = prefixedName
|
||||
s.input = inputJSON
|
||||
return s.result, s.err
|
||||
}
|
||||
|
||||
func newMCPAgentTool(exec mcpExecutor, name string) *mcpAgentTool {
|
||||
return &mcpAgentTool{
|
||||
tool: tools.MCPTool{Name: name},
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager-side Go errors (JSON-RPC protocol errors, transport failures,
|
||||
// schema validation rejections from the MCP server) must be surfaced to
|
||||
// the model as soft tool errors so the agent loop can keep going. Aborting
|
||||
// the turn would discard all prior tool results — see issue #N.
|
||||
func TestMCPAgentTool_RPCErrorBecomesSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
err: errors.New("MCP error -32602: Invalid params: missing field \"task\""),
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "pubmed__search",
|
||||
Input: `{"query":"foo"}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft), got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "pubmed__search") {
|
||||
t.Errorf("expected tool name in error content, got %q", resp.Content)
|
||||
}
|
||||
if !strings.Contains(resp.Content, "-32602") {
|
||||
t.Errorf("expected underlying error text in content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Context cancellation is the one error that must remain critical: it
|
||||
// means the caller intentionally aborted, and the agent loop needs to
|
||||
// unwind cleanly rather than burning more steps.
|
||||
func TestMCPAgentTool_CtxCancelStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
// Real managers typically return ctx.Err() (or a wrapper) when the
|
||||
// context is cancelled mid-call.
|
||||
err: context.Canceled,
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Deadline-exceeded behaves the same as cancellation: ctx.Err() is
|
||||
// non-nil, so the adapter must propagate the critical error rather than
|
||||
// converting the executor's error into a soft response.
|
||||
func TestMCPAgentTool_CtxDeadlineStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{err: context.DeadlineExceeded}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side soft errors (CallToolResult{ isError: true }) must continue
|
||||
// to flow through as soft errors — this was the existing behavior and
|
||||
// must not regress.
|
||||
func TestMCPAgentTool_ServerIsErrorRemainsSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: true,
|
||||
Content: "search service is rate limited; try again in 30s",
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if resp.Content != "search service is rate limited; try again in 30s" {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Happy path: ordinary successful tool result is passed through unchanged.
|
||||
func TestMCPAgentTool_SuccessIsPassthrough(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: false,
|
||||
Content: `{"hits":3}`,
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected IsError=false")
|
||||
}
|
||||
if resp.Content != `{"hits":3}` {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
+132
-33
@@ -70,14 +70,24 @@ type App struct {
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
// widgetUpdatePending is set to true while a WidgetUpdateEvent burst is
|
||||
// being coalesced. The leading edge fires immediately; subsequent calls
|
||||
// within the debounce window set widgetUpdateTrailing so a final event
|
||||
// is delivered with the latest runner state at the end of the window.
|
||||
// Without the trailing send, a rapid SetWidget→RemoveWidget pair (e.g.
|
||||
// SubagentEnd pushing a final frame then removing the widget) would let
|
||||
// the second call get silently dropped, leaving the TUI's layout stuck
|
||||
// on the pre-removal widget height — visible as empty rows below the
|
||||
// status bar after the widget disappears.
|
||||
widgetUpdatePending atomic.Bool
|
||||
widgetUpdateTrailing atomic.Bool
|
||||
|
||||
// steerDrainFn is the test seam used by releaseBusyAfterCompact to pull
|
||||
// any steer messages that arrived during compaction. In production it is
|
||||
// nil and the helper falls back to a.opts.Kit.DrainSteer(); tests that
|
||||
// need to exercise the steer-drain path without standing up a full
|
||||
// *kit.Kit can set this field directly to inject fake items.
|
||||
steerDrainFn func() []queueItem
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
@@ -356,6 +366,10 @@ func (a *App) AddContextMessage(text string) {
|
||||
// tea.Program. customInstructions is optional text appended to the summary
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
//
|
||||
// Any prompts queued via Run/RunWithFiles or steering messages injected via
|
||||
// Steer/SteerWithFiles while compaction is running are flushed automatically
|
||||
// once compaction completes (see releaseBusyAfterCompact).
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Lock()
|
||||
@@ -377,11 +391,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -420,6 +430,9 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
//
|
||||
// Like CompactConversation, any prompts/steer messages received during
|
||||
// compaction are flushed automatically once compaction finishes.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
@@ -440,11 +453,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -489,6 +498,81 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseBusyAfterCompact is the deferred tail that runs at the end of every
|
||||
// compaction goroutine (success, error, or panic-after-recover paths). It
|
||||
// flips a.busy back to false, but before doing so it checks whether any
|
||||
// prompts piled up while compaction was running:
|
||||
//
|
||||
// - Run/RunWithFiles append to a.queue when a.busy is set.
|
||||
// - Steer/SteerWithFiles deposit messages into the SDK steer channel via
|
||||
// Kit.InjectSteerWithFiles when a.busy is set.
|
||||
//
|
||||
// Without this hand-off the queue would sit idle until the user submits
|
||||
// another prompt — see issue #27. If we find anything pending we keep busy
|
||||
// set, splice the steer messages to the front of the queue, and start a
|
||||
// fresh drainQueue goroutine to deliver them as a single batched turn.
|
||||
func (a *App) releaseBusyAfterCompact() {
|
||||
// Pull steer messages outside the app mutex; DrainSteer takes its own
|
||||
// internal lock and we don't want to nest the two. The test seam
|
||||
// (a.steerDrainFn) takes precedence so unit tests can inject fake
|
||||
// steer items without a real *kit.Kit.
|
||||
var steerItems []queueItem
|
||||
switch {
|
||||
case a.steerDrainFn != nil:
|
||||
steerItems = a.steerDrainFn()
|
||||
case a.opts.Kit != nil:
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
steerItems = make([]queueItem, len(leftover))
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
// If the app was closed while compaction was running, drop everything
|
||||
// and just clear busy. Run/Steer would have rejected new items already
|
||||
// after Close(), but this guards against in-flight items that slipped
|
||||
// in just before closed was set.
|
||||
if a.closed {
|
||||
a.queue = a.queue[:0]
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Combine steer-channel items (front) with the in-memory queue (back).
|
||||
// Steer messages are placed first so they retain their "act now"
|
||||
// semantics relative to ordinary queued prompts that arrived later.
|
||||
pending := append(steerItems, a.queue...)
|
||||
a.queue = a.queue[:0]
|
||||
|
||||
if len(pending) == 0 {
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hand off to drainQueue: it will pick up the first item directly and
|
||||
// scoop the rest from a.queue on its first iteration.
|
||||
first := pending[0]
|
||||
if len(pending) > 1 {
|
||||
a.queue = append(a.queue, pending[1:]...)
|
||||
}
|
||||
// Stay busy across the goroutine swap.
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Notify the UI that steer-channel messages were consumed so the
|
||||
// steering badge can clear; ordinary queued prompts will be reflected
|
||||
// by the QueueUpdatedEvent that drainQueue emits as it picks them up.
|
||||
if len(steerItems) > 0 {
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
|
||||
go a.drainQueue(first)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1076,32 +1160,47 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
// Coalescing (leading + trailing edge): the first call in an idle period
|
||||
// fires immediately for responsiveness. Subsequent calls within a ~16 ms
|
||||
// debounce window are batched into a single trailing event delivered at
|
||||
// the end of the window. The trailing send is essential for correctness:
|
||||
// extensions routinely make tight SetWidget→RemoveWidget pairs (e.g. on
|
||||
// SubagentEnd) and silently dropping the second call would leave the TUI's
|
||||
// layout stuck on stale widget dimensions until some other event happens
|
||||
// to trigger a re-render.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
// A leading-edge event is already in flight — mark that the runner
|
||||
// state has changed again so the trailing send below picks it up.
|
||||
a.widgetUpdateTrailing.Store(true)
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
if prog == nil {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
return
|
||||
}
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
// If any extra calls came in during the debounce window, deliver
|
||||
// one trailing event so the TUI sees the latest widget state. We
|
||||
// swap-and-test instead of plain-load so concurrent calls after
|
||||
// the trailing send still race correctly with the pending reset.
|
||||
if a.widgetUpdateTrailing.Swap(false) {
|
||||
a.mu.Lock()
|
||||
p := a.program
|
||||
a.mu.Unlock()
|
||||
if p != nil {
|
||||
p.Send(WidgetUpdateEvent{})
|
||||
}
|
||||
}
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
|
||||
@@ -763,3 +763,209 @@ func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// releaseBusyAfterCompact (issue #27)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestReleaseBusyAfterCompact_flushesQueuedMessages is a regression test for
|
||||
// issue #27: messages queued via Run() while /compact is running used to sit
|
||||
// in a.queue indefinitely until the user typed another prompt. After the fix
|
||||
// the deferred releaseBusyAfterCompact tail picks up any pending items and
|
||||
// dispatches drainQueue automatically.
|
||||
//
|
||||
// We simulate the compaction completion path directly (bypassing the SDK)
|
||||
// by toggling busy=true, populating the queue exactly as Run() would have
|
||||
// during compaction, and then invoking releaseBusyAfterCompact.
|
||||
func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("compacted then drained"), nil
|
||||
},
|
||||
)
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate the state at the start of the compaction tail: busy is set
|
||||
// and a couple of prompts have piled up in the queue while we were
|
||||
// summarising. (Run() would have appended them and returned a queue
|
||||
// length > 0 to the caller.)
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued during compact #1"},
|
||||
queueItem{Prompt: "queued during compact #2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
// Invoke the deferred tail directly. It should kick off drainQueue.
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// drainQueue runs in a goroutine. Wait for the app to come back to idle.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after releaseBusyAfterCompact: queue not drained")
|
||||
}
|
||||
|
||||
// Wait for any in-flight goroutine to finish before reading state.
|
||||
app.wg.Wait()
|
||||
|
||||
if got := app.QueueLength(); got != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d", got)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatalf("expected stub PromptFunc to fire at least once after compact, got %d calls", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_idleWhenQueueEmpty verifies that with no
|
||||
// pending messages the helper just clears busy and does NOT spawn a
|
||||
// drainQueue goroutine (no spurious agent turn).
|
||||
func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false after releaseBusyAfterCompact with empty queue")
|
||||
}
|
||||
|
||||
// Give any rogue goroutine a moment to (incorrectly) call PromptFunc.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls when queue empty, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue exercises the SDK
|
||||
// steer-drain branch of releaseBusyAfterCompact (issue #27 follow-up).
|
||||
//
|
||||
// Production wires a.opts.Kit.DrainSteer() to pull messages that arrived via
|
||||
// Steer/SteerWithFiles during compaction, but Options.Kit is *kit.Kit (a
|
||||
// concrete struct) so unit tests cannot stand up a real instance without a
|
||||
// full LLM backend. The test uses the unexported steerDrainFn seam to inject
|
||||
// fake steer items, then asserts that:
|
||||
//
|
||||
// - Steer items are dispatched ahead of any prompts that piled up in
|
||||
// a.queue (steer retains "act now" priority over ordinary queued
|
||||
// prompts), and
|
||||
// - the helper still hands off to drainQueue so the steer item actually
|
||||
// fires (the previous behaviour left them stranded — see #27).
|
||||
func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
|
||||
var pmu sync.Mutex
|
||||
var firstPrompt string
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("steer dispatched"), nil
|
||||
},
|
||||
)
|
||||
// Wrap PromptFunc so we can capture the prompt text the stub receives
|
||||
// (newStubWithFuncs's fns ignore prompt; we need it to verify ordering).
|
||||
capturingPrompt := func(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
pmu.Lock()
|
||||
if firstPrompt == "" {
|
||||
firstPrompt = prompt
|
||||
}
|
||||
pmu.Unlock()
|
||||
return stub.fn(ctx, prompt)
|
||||
}
|
||||
app := New(Options{PromptFunc: capturingPrompt}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Inject fake steer items via the test seam. In production the same
|
||||
// items would have been delivered through Kit.InjectSteerWithFiles
|
||||
// during /compact and pulled by DrainSteer here.
|
||||
app.steerDrainFn = func() []queueItem {
|
||||
return []queueItem{
|
||||
{Prompt: "steer-1"},
|
||||
{Prompt: "steer-2"},
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate the state at the end of compaction: busy is set and a couple
|
||||
// of regular Run() prompts have piled up after the steer messages.
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued-1"},
|
||||
queueItem{Prompt: "queued-2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// Wait for the dispatched batch to complete.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after steer-spliced releaseBusyAfterCompact")
|
||||
}
|
||||
app.wg.Wait()
|
||||
|
||||
// drainQueue picks up `first` directly and batches the rest. With
|
||||
// PromptFunc set, executeBatch invokes us with items[0] only — that
|
||||
// item must be the first steer message, proving steer items were
|
||||
// spliced ahead of the previously queued prompts.
|
||||
pmu.Lock()
|
||||
got := firstPrompt
|
||||
pmu.Unlock()
|
||||
if got != "steer-1" {
|
||||
t.Fatalf("expected first dispatched prompt to be steer item %q (steer items must come before queued prompts), got %q",
|
||||
"steer-1", got)
|
||||
}
|
||||
|
||||
// Queue should be fully drained and PromptFunc must have actually fired.
|
||||
if n := app.QueueLength(); n != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d entries", n)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatal("expected stub PromptFunc to fire at least once after splice")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_dropsQueueWhenClosed verifies that if the app
|
||||
// was closed during compaction the helper discards any pending items rather
|
||||
// than spawning drainQueue against a torn-down App.
|
||||
func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
|
||||
app.closed = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
qLen := len(app.queue)
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false even when closed")
|
||||
}
|
||||
if qLen != 0 {
|
||||
t.Fatalf("expected queue cleared on closed app, got %d entries", qLen)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -255,29 +255,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
@@ -417,26 +394,6 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
// 1. Command-line flag value (highest priority)
|
||||
// 2. Stored credentials (OAuth or API key)
|
||||
|
||||
@@ -38,6 +38,23 @@ type MCPServerConfig struct {
|
||||
// servers that don't support it.
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
|
||||
// TasksMode controls when this server's tools/call requests are augmented
|
||||
// with MCP task metadata (turning a synchronous call into an asynchronous,
|
||||
// pollable job — see https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks).
|
||||
//
|
||||
// Valid values:
|
||||
// - "" or "auto": (default) augment requests with task metadata only
|
||||
// when the server advertises tasks/toolCalls capability during initialize.
|
||||
// - "never": never augment — every tool call is synchronous, regardless
|
||||
// of server capability.
|
||||
// - "always": always augment, even when the server didn't advertise
|
||||
// task support. The server may still respond synchronously; this just
|
||||
// opts in unconditionally on the client side.
|
||||
//
|
||||
// In all modes, when the server returns a CreateTaskResult the client polls
|
||||
// tasks/get / tasks/result until the task reaches a terminal state.
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
|
||||
// InProcessServer holds a live *server.MCPServer for in-process transport.
|
||||
// When set (and Type is "inprocess"), the connection pool creates an
|
||||
// in-process client instead of spawning a subprocess or making HTTP calls.
|
||||
@@ -68,6 +85,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
NoOAuth bool `json:"noOAuth,omitempty" yaml:"noOAuth,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -80,6 +98,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
TasksMode string `json:"tasksMode,omitempty" yaml:"tasksMode,omitempty"`
|
||||
}
|
||||
|
||||
// Try new format first
|
||||
@@ -96,6 +115,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
s.NoOAuth = newConfig.NoOAuth
|
||||
s.TasksMode = newConfig.TasksMode
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,6 +136,7 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = legacyConfig.Headers
|
||||
s.AllowedTools = legacyConfig.AllowedTools
|
||||
s.ExcludedTools = legacyConfig.ExcludedTools
|
||||
s.TasksMode = legacyConfig.TasksMode
|
||||
|
||||
// Infer type from legacy format for better compatibility
|
||||
// Only set Type when it doesn't change existing transport behavior
|
||||
@@ -324,6 +345,17 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("server %s: allowedTools and excludedTools are mutually exclusive", serverName)
|
||||
}
|
||||
|
||||
// Reject unknown tasksMode values up front so a typo (e.g. "alwasy")
|
||||
// fails loud here instead of being silently downgraded to "auto" by
|
||||
// the runtime parser. Comparison is case-insensitive to match
|
||||
// tools.ParseTaskMode.
|
||||
switch strings.ToLower(strings.TrimSpace(serverConfig.TasksMode)) {
|
||||
case "", "auto", "never", "always":
|
||||
// ok
|
||||
default:
|
||||
return fmt.Errorf("server %s: invalid tasksMode %q (expected one of: auto, never, always)", serverName, serverConfig.TasksMode)
|
||||
}
|
||||
|
||||
transport := serverConfig.GetTransportType()
|
||||
switch transport {
|
||||
case "stdio":
|
||||
|
||||
@@ -627,3 +627,92 @@ func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_NewFormat(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://my-mcp-server.com",
|
||||
"tasksMode": "always"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "always" {
|
||||
t.Errorf("expected TasksMode 'always', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_LegacyFormat(t *testing.T) {
|
||||
// tasksMode also recognised in the legacy unmarshal path so users on
|
||||
// the older command/args shape can opt in without migrating.
|
||||
jsonData := `{
|
||||
"command": "npx",
|
||||
"args": ["@modelcontextprotocol/server-filesystem", "/path"],
|
||||
"tasksMode": "never"
|
||||
}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "never" {
|
||||
t.Errorf("expected TasksMode 'never', got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_TasksMode_DefaultEmpty(t *testing.T) {
|
||||
// When tasksMode is not set the field stays empty, which downstream
|
||||
// resolves to "auto" via tools.ParseTaskMode.
|
||||
jsonData := `{"type":"remote","url":"https://x.example"}`
|
||||
var cfg MCPServerConfig
|
||||
if err := json.Unmarshal([]byte(jsonData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
if cfg.TasksMode != "" {
|
||||
t.Errorf("expected default TasksMode to be empty, got %q", cfg.TasksMode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_TasksMode(t *testing.T) {
|
||||
t.Run("empty is valid", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example"},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("empty TasksMode should validate, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("known values are valid", func(t *testing.T) {
|
||||
for _, mode := range []string{"auto", "never", "always", "AUTO", " always "} {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"a": {Type: "remote", URL: "https://x.example", TasksMode: mode},
|
||||
},
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
t.Errorf("TasksMode=%q should validate, got %v", mode, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("typo is rejected with a clear error", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
MCPServers: map[string]MCPServerConfig{
|
||||
"buildbot": {Type: "remote", URL: "https://x.example", TasksMode: "alwasy"},
|
||||
},
|
||||
}
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("expected validation error for invalid TasksMode")
|
||||
}
|
||||
// Error must mention the server name AND the bad value so the
|
||||
// user knows where to look.
|
||||
msg := err.Error()
|
||||
if !strings.Contains(msg, "buildbot") || !strings.Contains(msg, `"alwasy"`) {
|
||||
t.Errorf("error %q should mention both server name and bad value", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -160,15 +160,6 @@ func rewriteSudoForStdin(command string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
|
||||
// This is stored in tool response metadata to signal the TUI to prompt for password.
|
||||
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
|
||||
|
||||
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
|
||||
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
|
||||
return resp.Metadata == SudoPasswordRequiredMetadata
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
|
||||
+6
-42
@@ -21,12 +21,9 @@ type Edit struct {
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
Path string `json:"path"`
|
||||
Edits []Edit `json:"edits"`
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
@@ -52,20 +49,12 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Description: "Edit a file by replacing exact text. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Path to the file to edit (relative or absolute)",
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
@@ -85,7 +74,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
Required: []string{"path", "edits"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -163,36 +152,11 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
if len(args.Edits) == 0 {
|
||||
return nil, fmt.Errorf("edits array is required and must not be empty")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
|
||||
+62
-44
@@ -389,9 +389,11 @@ func TestExecuteEdit_ExactMatch(t *testing.T) {
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "fmt.Println(\"hello\")",
|
||||
NewText: "fmt.Println(\"world\")",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -426,9 +428,11 @@ func TestExecuteEdit_ExactMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
target := lines[49]
|
||||
replacement := "REPLACED_LINE_50"
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: target,
|
||||
NewText: replacement,
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -470,9 +474,11 @@ func TestExecuteEdit_FuzzyMatch_TrailingWhitespace(t *testing.T) {
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "func foo() {\n\treturn 1\n}",
|
||||
NewText: "func foo() {\n\treturn 2\n}",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -519,9 +525,11 @@ func TestExecuteEdit_FuzzyMatch_DoesNotCorruptRest(t *testing.T) {
|
||||
search := strings.Repeat("x", 10) + "\n" + strings.Repeat("x", 10)
|
||||
// But this matches lines 1-2, 2-3, etc. — should fail due to ambiguity.
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: search,
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -546,9 +554,11 @@ func TestExecuteEdit_MultipleMatches_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "hello",
|
||||
NewText: "goodbye",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -575,9 +585,11 @@ func TestExecuteEdit_NoMatch_Fails(t *testing.T) {
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "nonexistent text",
|
||||
NewText: "replacement",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -601,9 +613,11 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
writeFileOrFail(t, path, "line1\r\nline2\r\nline3\r\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line2",
|
||||
NewText: "LINE2",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -622,8 +636,10 @@ func TestExecuteEdit_CRLFNormalization(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -636,9 +652,11 @@ func TestExecuteEdit_MissingPath(t *testing.T) {
|
||||
|
||||
func TestExecuteEdit_NonexistentFile(t *testing.T) {
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
Path: "/tmp/nonexistent_edit_test_file_12345.go",
|
||||
Edits: []Edit{{
|
||||
OldText: "x",
|
||||
NewText: "y",
|
||||
}},
|
||||
})
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, "")
|
||||
if err != nil {
|
||||
@@ -661,9 +679,11 @@ func TestExecuteEdit_DiffContainsHunkHeader(t *testing.T) {
|
||||
writeFileOrFail(t, path, strings.Join(lines, "\n")+"\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "line_10_content",
|
||||
NewText: "REPLACED",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -684,9 +704,11 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
writeFileOrFail(t, path, "old content\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
Path: path,
|
||||
Edits: []Edit{{
|
||||
OldText: "old content",
|
||||
NewText: "new content",
|
||||
}},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -905,18 +927,14 @@ func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
func TestExecuteEdit_EmptyEditsArray_Fails(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
@@ -924,10 +942,10 @@ func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
if !strings.Contains(resp.Content, "required") {
|
||||
t.Errorf("expected 'required' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// Package extbridge wires the public Kit SDK to the internal extensions
|
||||
// package. It exists so that cmd/ and internal/acpserver/ don't both
|
||||
// reimplement the same SDK→extension event/subagent conversions.
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
|
||||
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
|
||||
// (Type=="") for events that don't map to anything useful — callers should
|
||||
// drop those.
|
||||
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
|
||||
// the result/events back into the extension-facing types. The returned
|
||||
// handle is always nil — the SDK path runs synchronously and does not
|
||||
// expose a separate process handle. Callers that need non-blocking
|
||||
// behaviour should run this in their own goroutine.
|
||||
//
|
||||
// This function consolidates the previously-duplicated wiring in
|
||||
// cmd/root.go (interactive + runtime contexts) and
|
||||
// internal/acpserver/session.go.
|
||||
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: cfg.Prompt,
|
||||
Model: cfg.Model,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := SDKEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
cfg.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := k.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
}
|
||||
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
|
||||
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
|
||||
// saveManifest round-trip against a temp directory, ensuring an absent
|
||||
// manifest loads as empty and a saved manifest reads back identically.
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
err = installer.saveManifest(manifest, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
t.Fatalf("saveManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was written to expected path
|
||||
if _, err := os.Stat(manifestPath); err != nil {
|
||||
t.Fatalf("manifest file not created: %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
loaded, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
|
||||
// followed by removeFromManifest leaves the manifest in its original
|
||||
// (empty) state, using a temp-directory installer scope.
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
|
||||
t.Fatalf("addToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
|
||||
t.Fatalf("removeFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
manifest, err = installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindInManifest writes a manifest file directly to the path
|
||||
// resolved by the package-level manifestPathForScope helper and then
|
||||
// confirms FindInManifest locates the entry by identity (and returns
|
||||
// nil for a non-existent identity).
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
t.Setenv("XDG_DATA_HOME", tempDir)
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
// Write a manifest entry directly via the package-level path resolver
|
||||
// so FindInManifest (which uses manifestPathForScope) can read it back.
|
||||
manifestPath := manifestPathForScope(ScopeGlobal)
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
|
||||
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
|
||||
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
|
||||
@@ -2,22 +2,15 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
|
||||
@@ -72,6 +72,9 @@ type AgentSetupOptions struct {
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -229,6 +232,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: opts.MCPTaskConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package models
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"maps"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@@ -111,13 +112,30 @@ func NewModelsRegistry() *ModelsRegistry {
|
||||
}
|
||||
|
||||
// buildFromModelsDB converts models.dev provider data into our internal format.
|
||||
// It tries the on-disk cache first and falls back to the embedded database.
|
||||
// It starts from the compile-time embedded database and merges on-disk cached
|
||||
// data from `kit update-models` on top. Cached provider metadata replaces
|
||||
// embedded metadata, and model entries are merged with cached models taking
|
||||
// precedence. This means newly synced models are available while embedded
|
||||
// models that haven't been synced yet are still reachable.
|
||||
func buildFromModelsDB() map[string]ProviderInfo {
|
||||
// Try cached data first (from `kit update-models`)
|
||||
dbProviders, _ := LoadCachedProviders()
|
||||
if len(dbProviders) == 0 {
|
||||
// Fall back to compile-time embedded data
|
||||
dbProviders = loadEmbeddedProviders()
|
||||
// Start with compile-time embedded data as the base.
|
||||
dbProviders := loadEmbeddedProviders()
|
||||
if dbProviders == nil {
|
||||
dbProviders = make(ModelsDBProviders)
|
||||
}
|
||||
|
||||
// Merge on-disk cached data on top (cached takes precedence).
|
||||
if cached, _ := LoadCachedProviders(); len(cached) > 0 {
|
||||
for providerID, cp := range cached {
|
||||
if existing, ok := dbProviders[providerID]; ok {
|
||||
// Merge models: embedded base + cached overrides.
|
||||
mergedModels := make(map[string]modelsDBModel, len(existing.Models)+len(cp.Models))
|
||||
maps.Copy(mergedModels, existing.Models)
|
||||
maps.Copy(mergedModels, cp.Models)
|
||||
cp.Models = mergedModels
|
||||
}
|
||||
dbProviders[providerID] = cp
|
||||
}
|
||||
}
|
||||
|
||||
providers := make(map[string]ProviderInfo, len(dbProviders))
|
||||
|
||||
+39
-35
@@ -36,15 +36,17 @@ type Diagnostic struct {
|
||||
}
|
||||
|
||||
// LoadAll discovers and loads all prompt templates from standard locations
|
||||
// and any extra paths. Templates are loaded in order of precedence (lowest
|
||||
// to highest), with later templates overriding earlier ones of the same name.
|
||||
// and any extra paths. Templates are loaded in order of precedence (highest
|
||||
// to lowest); the first source to define a given name wins, later definitions
|
||||
// of the same name are dropped with a diagnostic.
|
||||
//
|
||||
// Discovery paths searched in order:
|
||||
// 1. Default templates (if IncludeDefaults)
|
||||
// 2. ~/.kit/prompts/ (global user templates)
|
||||
// 3. .kit/prompts/ (project-local templates)
|
||||
// 4. ConfigPaths (from configuration)
|
||||
// 5. ExtraPaths (explicit paths, highest precedence)
|
||||
// 2. ~/.kit/prompts/ (legacy global)
|
||||
// 3. $XDG_CONFIG_HOME/kit/prompts/ (XDG global, default ~/.config/kit/prompts/)
|
||||
// 4. <cwd>/.kit/prompts/ (project-local templates)
|
||||
// 5. ConfigPaths (from configuration)
|
||||
// 6. ExtraPaths (explicit paths, lowest precedence)
|
||||
func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
if opts.Cwd == "" {
|
||||
opts.Cwd, _ = os.Getwd()
|
||||
@@ -88,13 +90,21 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
addTemplates(defaults, "default")
|
||||
}
|
||||
|
||||
// 2. Global user templates: ~/.kit/prompts/
|
||||
globalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(globalDir); err == nil {
|
||||
// 2. Legacy global user templates: ~/.kit/prompts/
|
||||
legacyGlobalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(legacyGlobalDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
|
||||
// 3. Project-local templates: .kit/prompts/
|
||||
// 3. XDG global user templates: $XDG_CONFIG_HOME/kit/prompts/
|
||||
// Default: ~/.config/kit/prompts/. Aligns with extensions and skills.
|
||||
if xdgDir := GlobalDir(); xdgDir != "" && xdgDir != legacyGlobalDir {
|
||||
if templates, err := LoadFromDir(xdgDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Project-local templates: .kit/prompts/
|
||||
localDir := filepath.Join(opts.Cwd, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(localDir); err == nil {
|
||||
addTemplates(templates, "local")
|
||||
@@ -179,31 +189,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
|
||||
// It returns the deduplicated list and diagnostics for any collisions.
|
||||
// This is a standalone function for when you need to deduplicate an existing list.
|
||||
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
|
||||
seen := make(map[string]*PromptTemplate)
|
||||
var result []*PromptTemplate
|
||||
var diagnostics []Diagnostic
|
||||
|
||||
for _, tpl := range templates {
|
||||
if existing, ok := seen[tpl.Name]; ok {
|
||||
diagnostics = append(diagnostics, Diagnostic{
|
||||
Name: tpl.Name,
|
||||
KeptPath: existing.FilePath,
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: "duplicate template name (first-match-wins)",
|
||||
})
|
||||
} else {
|
||||
seen[tpl.Name] = tpl
|
||||
result = append(result, tpl)
|
||||
}
|
||||
}
|
||||
|
||||
return result, diagnostics
|
||||
}
|
||||
|
||||
// loadDefaultTemplates returns the built-in default templates.
|
||||
// These are embedded templates that ship with Kit.
|
||||
func loadDefaultTemplates() []*PromptTemplate {
|
||||
@@ -211,3 +196,22 @@ func loadDefaultTemplates() []*PromptTemplate {
|
||||
// For now, return an empty slice - users can define their own templates
|
||||
return nil
|
||||
}
|
||||
|
||||
// GlobalDir returns the XDG-aligned global prompts directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/prompts/. Returns an empty
|
||||
// string if the user's home directory cannot be resolved.
|
||||
//
|
||||
// This is the canonical location for user-wide prompt templates and aligns
|
||||
// with the discovery paths used for extensions ($XDG_CONFIG_HOME/kit/extensions/)
|
||||
// and skills ($XDG_CONFIG_HOME/kit/skills/).
|
||||
func GlobalDir() string {
|
||||
base := os.Getenv("XDG_CONFIG_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(base, "kit", "prompts")
|
||||
}
|
||||
|
||||
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
|
||||
// Kept messages must appear BEFORE post-compaction messages so the LLM
|
||||
// sees the conversation in chronological order. Otherwise the latest
|
||||
// post-compaction user message would be followed by an older kept user
|
||||
// message, breaking user/assistant alternation and causing the model to
|
||||
// respond as if the post-compaction turn never happened.
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
// Verify order: summary, M3 (kept), M4 (new)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeCwdForDir verifies the working-directory → session-directory
|
||||
// name encoding strips characters that are illegal on Windows (notably the
|
||||
// drive-letter colon, see issue #18) while preserving the previous output
|
||||
// for the typical Unix paths.
|
||||
func TestEncodeCwdForDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "unix absolute path",
|
||||
cwd: "/home/user/proj",
|
||||
want: "home--user--proj",
|
||||
},
|
||||
{
|
||||
name: "unix relative path",
|
||||
cwd: "proj/sub",
|
||||
want: "proj--sub",
|
||||
},
|
||||
{
|
||||
name: "windows drive root",
|
||||
cwd: `C:\test`,
|
||||
want: "C--test",
|
||||
},
|
||||
{
|
||||
name: "windows nested path",
|
||||
cwd: `C:\Users\User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows secondary drive",
|
||||
cwd: `S:\work\repo`,
|
||||
want: "S--work--repo",
|
||||
},
|
||||
{
|
||||
name: "windows mixed separators",
|
||||
cwd: `C:\Users/User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows other illegal chars stripped",
|
||||
cwd: `C:\a<b>c|d?e*f"g`,
|
||||
want: "C--abcdefg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := encodeCwdForDir(tc.cwd)
|
||||
if got != tc.want {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q, want %q", tc.cwd, got, tc.want)
|
||||
}
|
||||
// Encoded directory must never contain characters that are
|
||||
// illegal in Windows directory names.
|
||||
for _, bad := range []string{":", "<", ">", "\"", "|", "?", "*", "\\", "/"} {
|
||||
if strings.Contains(got, bad) {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q contains illegal char %q", tc.cwd, got, bad)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -97,6 +99,11 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
//
|
||||
// Per-file extraction is parallelized across a small worker pool because each
|
||||
// file requires a full JSONL scan to compute MessageCount and FirstMessage —
|
||||
// for users with many sessions this is the dominant cost of opening the
|
||||
// session picker.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -107,20 +114,47 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
return nil, fmt.Errorf("failed to read directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
var sessions []SessionInfo
|
||||
// Collect candidate paths first so we can parallelize the heavy work.
|
||||
paths := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
|
||||
continue
|
||||
}
|
||||
paths = append(paths, filepath.Join(dir, entry.Name()))
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
info, err := extractSessionInfo(path)
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
results := make([]*SessionInfo, len(paths))
|
||||
|
||||
// Worker pool sized to GOMAXPROCS, capped to avoid thrashing for tiny lists.
|
||||
workers := max(min(runtime.GOMAXPROCS(0), len(paths)), 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan int, len(paths))
|
||||
for range workers {
|
||||
wg.Go(func() {
|
||||
for i := range jobs {
|
||||
info, err := extractSessionInfo(paths[i])
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
results[i] = info
|
||||
}
|
||||
})
|
||||
}
|
||||
for i := range paths {
|
||||
jobs <- i
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
|
||||
sessions := make([]SessionInfo, 0, len(results))
|
||||
for i, info := range results {
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
// Clean up and skip empty sessions (no messages).
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
_ = os.Remove(paths[i])
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
|
||||
@@ -755,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
// If there is a compaction, inject the summary first, then the
|
||||
// preserved "kept" messages (chronologically before the compaction),
|
||||
// then the post-compaction messages (chronologically after).
|
||||
//
|
||||
// Order matters: the kept messages must come BEFORE the post-compaction
|
||||
// branch so the LLM sees the conversation in chronological order. If the
|
||||
// kept messages were appended last, the latest user message in the
|
||||
// current branch would be followed by an older kept user message,
|
||||
// breaking the strict user/assistant alternation that providers expect
|
||||
// and causing the model to respond as if the previous turn never
|
||||
// happened.
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -768,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
})
|
||||
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
// Step 1: collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not on the current branch (the compaction entry is a
|
||||
// new root with no parent), so we iterate tm.entries in append order
|
||||
// and stop when we reach the compaction entry itself.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
@@ -825,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
// Stop when we reach the compaction entry itself; messages
|
||||
// after it are collected from the branch walk below.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -860,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: collect entries on the current branch after the compaction
|
||||
// entry (these are post-compaction messages). The compaction entry
|
||||
// itself is skipped — its summary was already injected above.
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Summary already injected above.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
@@ -1030,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
// If there's a compaction, we collect IDs in the same order as
|
||||
// BuildContext: [summary placeholder, kept messages, post-compaction
|
||||
// messages]. This ordering must stay in sync with BuildContext so a
|
||||
// cut-point index can be mapped back to the correct entry ID.
|
||||
if lastCompaction != nil {
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
|
||||
// Iterate tm.entries in append order and stop at the compaction
|
||||
// entry to avoid double-counting post-compaction messages.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
@@ -1076,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
@@ -1100,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: IDs of entries after the compaction entry on the current
|
||||
// branch (post-compaction messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
@@ -1350,15 +1353,44 @@ func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[stri
|
||||
// --- Path conventions ---
|
||||
|
||||
// DefaultSessionDir returns the default session storage directory for a cwd.
|
||||
// Convention: ~/.kit/sessions/--<cwd-path>--/
|
||||
// Convention: ~/.kit/sessions/<encoded-cwd>, where path separators are
|
||||
// encoded as "--" with no leading or trailing dashes — e.g.
|
||||
// /home/user/proj becomes home--user--proj. See encodeCwdForDir for the
|
||||
// full encoding rules (including Windows path handling).
|
||||
func DefaultSessionDir(cwd string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
// Convert path separators to double dashes.
|
||||
safeCwd := strings.ReplaceAll(cwd, string(filepath.Separator), "--")
|
||||
return filepath.Join(home, ".kit", "sessions", encodeCwdForDir(cwd))
|
||||
}
|
||||
|
||||
// encodeCwdForDir converts a working-directory path into a single, filesystem-
|
||||
// safe directory name. Path separators are replaced with double dashes and
|
||||
// characters that are illegal in Windows directory names — most importantly
|
||||
// the colon that follows the drive letter (e.g. `C:\foo` → `C--foo`) — are
|
||||
// stripped. The result is identical to the previous Unix-only encoding for
|
||||
// paths that do not contain such characters, so existing session directories
|
||||
// are preserved.
|
||||
func encodeCwdForDir(cwd string) string {
|
||||
// Convert both `/` and `\` to double dashes so encoding is stable across
|
||||
// platforms and remains correct on Windows where `filepath.Separator`
|
||||
// would otherwise miss forward-slash style paths.
|
||||
safeCwd := strings.ReplaceAll(cwd, "\\", "--")
|
||||
safeCwd = strings.ReplaceAll(safeCwd, "/", "--")
|
||||
// Remove leading separator replacement.
|
||||
safeCwd = strings.TrimPrefix(safeCwd, "--")
|
||||
return filepath.Join(home, ".kit", "sessions", safeCwd)
|
||||
// Strip characters that are illegal in directory names on Windows
|
||||
// (`< > : " | ? *`). On Unix these characters are legal but rare in
|
||||
// practice; stripping them keeps the encoding portable.
|
||||
replacer := strings.NewReplacer(
|
||||
":", "",
|
||||
"<", "",
|
||||
">", "",
|
||||
"\"", "",
|
||||
"|", "",
|
||||
"?", "",
|
||||
"*", "",
|
||||
)
|
||||
return replacer.Replace(safeCwd)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type MCPConnection struct {
|
||||
client client.MCPClient
|
||||
serverName string
|
||||
serverConfig config.MCPServerConfig
|
||||
initResult *mcp.InitializeResult // captured at handshake; nil before initialize
|
||||
lastUsed time.Time
|
||||
isHealthy bool
|
||||
errorCount int
|
||||
@@ -262,7 +263,9 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
conn := &MCPConnection{}
|
||||
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if oauthEnabled && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
@@ -270,7 +273,7 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
if err := p.initializeClient(ctx, mcpClient, conn); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
@@ -280,15 +283,11 @@ func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName str
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
isHealthy: true,
|
||||
errorCount: 0,
|
||||
lastError: nil,
|
||||
}
|
||||
conn.client = mcpClient
|
||||
conn.serverName = serverName
|
||||
conn.serverConfig = serverConfig
|
||||
conn.lastUsed = time.Now()
|
||||
conn.isHealthy = true
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Created connection for %s", serverName))
|
||||
@@ -346,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig
|
||||
return stdioClient, nil
|
||||
}
|
||||
|
||||
// createSSEClient creates an SSE client
|
||||
// parseHeaders parses "Key: Value" header strings into a map.
|
||||
func parseHeaders(raw []string) map[string]string {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
headers := make(map[string]string)
|
||||
for _, header := range raw {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// buildOAuthConfig constructs a transport.OAuthConfig from the server config
|
||||
// and the pool's OAuth flow. Returns nil if OAuth is not applicable.
|
||||
func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) {
|
||||
if p.oauthFlow == nil || serverConfig.NoOAuth {
|
||||
return nil, nil
|
||||
}
|
||||
tokenStore, err := p.createTokenStore(serverConfig.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", err)
|
||||
}
|
||||
cfg := &transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
cfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
cfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
cfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.ClientOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
|
||||
// PubMed) set NoOAuth to skip dynamic client registration and token
|
||||
// exchange, which would otherwise fail with a 404.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
@@ -407,43 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.StreamableHTTPCOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithHTTPOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
@@ -484,8 +479,10 @@ func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenS
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
// initializeClient initializes the client and captures the server's
|
||||
// initialize result on the supplied connection so callers can later
|
||||
// inspect advertised capabilities (e.g. task support).
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, c client.MCPClient, conn *MCPConnection) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -495,12 +492,21 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
Name: "kit",
|
||||
Version: "1.0.0",
|
||||
}
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
|
||||
// Advertise task support so servers may return CreateTaskResult for
|
||||
// long-running tools/call requests instead of blocking the connection
|
||||
// until completion. The client is responsible for polling tasks/get and
|
||||
// tasks/result until the task reaches a terminal state.
|
||||
initRequest.Params.Capabilities = mcp.ClientCapabilities{
|
||||
Tasks: mcp.NewTasksCapability(),
|
||||
}
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
initResult, err := c.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
if conn != nil {
|
||||
conn.initResult = initResult
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug("[POOL] Initialized MCP client")
|
||||
@@ -615,6 +621,54 @@ func (c *MCPConnection) ServerName() string {
|
||||
return c.serverName
|
||||
}
|
||||
|
||||
// InitializeResult returns the result captured from the server's initialize
|
||||
// response, or nil if the connection was created before initialize completed.
|
||||
// Callers can inspect ServerCapabilities.Tasks to discover task-related
|
||||
// capability advertisements.
|
||||
func (c *MCPConnection) InitializeResult() *mcp.InitializeResult {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.initResult
|
||||
}
|
||||
|
||||
// SupportsToolTasks reports whether the server advertised support for
|
||||
// task-augmented tools/call requests. Returns false when the connection has
|
||||
// not yet completed initialization or when the server omitted task
|
||||
// capabilities.
|
||||
func (c *MCPConnection) SupportsToolTasks() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return supportsToolTasksFromInit(c.initResult)
|
||||
}
|
||||
|
||||
// supportsToolTasksFromInit reports whether the supplied InitializeResult
|
||||
// advertises task-augmented tools/call support. Extracted to a free function
|
||||
// for unit testing without standing up a connection.
|
||||
func supportsToolTasksFromInit(init *mcp.InitializeResult) bool {
|
||||
if init == nil || init.Capabilities.Tasks == nil {
|
||||
return false
|
||||
}
|
||||
req := init.Capabilities.Tasks.Requests
|
||||
if req == nil || req.Tools == nil {
|
||||
return false
|
||||
}
|
||||
return req.Tools.Call != nil
|
||||
}
|
||||
|
||||
// ServerSupportsToolTasks reports whether the named server's connection
|
||||
// advertises task-augmented tools/call support. Returns false when no
|
||||
// connection exists for the server or when the server didn't advertise the
|
||||
// capability.
|
||||
func (p *MCPConnectionPool) ServerSupportsToolTasks(serverName string) bool {
|
||||
p.mu.RLock()
|
||||
conn, ok := p.connections[serverName]
|
||||
p.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return conn.SupportsToolTasks()
|
||||
}
|
||||
|
||||
// GetClients returns a map of all MCP clients currently in the pool.
|
||||
// The map keys are server names and values are the corresponding MCP client instances.
|
||||
// The returned map is a copy and modifications won't affect the pool.
|
||||
|
||||
+229
-27
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"slices"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
log "github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
@@ -141,6 +143,11 @@ type MCPToolManager struct {
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
|
||||
// taskCfg controls task-augmented tools/call execution. The zero value
|
||||
// means: auto-detect server capability, no progress callback, default
|
||||
// poll/timeout.
|
||||
taskCfg MCPTaskConfig
|
||||
|
||||
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
||||
// Called with server name, tool count, and error (nil on success).
|
||||
onServerLoaded func(serverName string, toolCount int, err error)
|
||||
@@ -220,6 +227,21 @@ func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
||||
m.onToolsChanged = cb
|
||||
}
|
||||
|
||||
// SetTaskConfig sets the task-augmented tools/call configuration. Call
|
||||
// this before LoadTools / AddServer if you want the per-server mode
|
||||
// override and progress handler to take effect for the very first call.
|
||||
// Subsequent calls replace the previous configuration wholesale.
|
||||
func (m *MCPToolManager) SetTaskConfig(cfg MCPTaskConfig) {
|
||||
m.taskCfg = cfg
|
||||
}
|
||||
|
||||
// TaskConfig returns the manager's current task-augmented tools/call
|
||||
// configuration. The zero value means: defer to per-server config and
|
||||
// auto-detected capability, with no progress callback and default polling.
|
||||
func (m *MCPToolManager) TaskConfig() MCPTaskConfig {
|
||||
return m.taskCfg
|
||||
}
|
||||
|
||||
// AddServer connects to a new MCP server at runtime and loads its tools.
|
||||
// The server's tools are immediately available to the agent after this call.
|
||||
// Returns the number of tools loaded from the server.
|
||||
@@ -551,6 +573,14 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
|
||||
// checks, OAuth re-authorization, and connection error tracking.
|
||||
// The inputJSON parameter is the raw JSON arguments from the LLM.
|
||||
// Returns the result content, error flag, and any execution error.
|
||||
//
|
||||
// When the per-server TasksMode resolves to "always", or to "auto" and the
|
||||
// server advertised tasks/toolCalls capability during initialize, the call
|
||||
// is augmented with TaskParams. If the server elects to respond with a
|
||||
// CreateTaskResult the manager polls tasks/get / tasks/result until the
|
||||
// task reaches a terminal state, transparently presenting the final
|
||||
// CallToolResult-equivalent content to the agent layer. Context
|
||||
// cancellation triggers a best-effort tasks/cancel.
|
||||
func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*MCPToolResult, error) {
|
||||
m.mu.Lock()
|
||||
mapping, ok := m.toolMap[prefixedName]
|
||||
@@ -582,49 +612,221 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
return nil, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
||||
}
|
||||
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
callParams := mcp.CallToolParams{
|
||||
Name: mapping.originalName,
|
||||
Arguments: arguments,
|
||||
}
|
||||
|
||||
// Call the MCP tool using the original (unprefixed) name
|
||||
result, err := conn.client.CallTool(ctx, callRequest)
|
||||
if err != nil {
|
||||
// Handle OAuth re-authorization: token may have expired mid-session.
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, err); flowErr != nil {
|
||||
// Decide whether to augment the request with TaskParams. Modes:
|
||||
// never — never augment (synchronous-only).
|
||||
// always — always augment, even without server capability.
|
||||
// auto — augment only when the server advertised tasks/toolCalls.
|
||||
mode := m.resolveTaskMode(mapping.serverName, mapping.serverConfig)
|
||||
useTask := mode == MCPTaskModeAlways ||
|
||||
(mode == MCPTaskModeAuto && conn.SupportsToolTasks())
|
||||
if useTask {
|
||||
var ttl *int64
|
||||
if m.taskCfg.DefaultTTL > 0 {
|
||||
ms := m.taskCfg.DefaultTTL.Milliseconds()
|
||||
ttl = &ms
|
||||
}
|
||||
callParams.Task = &mcp.TaskParams{TTL: ttl}
|
||||
}
|
||||
|
||||
// Synchronous fast path: no task augmentation. Use the upstream client
|
||||
// helper which keeps content-block typing identical to historical
|
||||
// behaviour.
|
||||
if !useTask {
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Task-augmented path. Bypass the upstream CallTool helper because its
|
||||
// ParseCallToolResult requires a "content" field that is absent from a
|
||||
// CreateTaskResult.
|
||||
rawClient, ok := conn.client.(*client.Client)
|
||||
if !ok {
|
||||
// Older client implementations — fall back to the synchronous shape.
|
||||
callParams.Task = nil
|
||||
callRequest := mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
||||
}
|
||||
|
||||
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
// Retry the tool call after successful re-auth.
|
||||
result, err = conn.client.CallTool(ctx, callRequest)
|
||||
if err != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, err)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", err)
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, err)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
marshaledResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
// Server chose to answer synchronously — same shape as the no-task path.
|
||||
if callResult != nil {
|
||||
marshaledResult, mErr := json.Marshal(callResult)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: callResult.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Asynchronous task path: poll until terminal, then return the result.
|
||||
if taskResult == nil {
|
||||
return nil, errors.New("mcp tools/call returned neither result nor task")
|
||||
}
|
||||
final, pollErr := pollTaskUntilTerminal(
|
||||
ctx, rawClient, mapping.serverName, taskResult.Task,
|
||||
m.taskCfg, m.taskCfg.Progress,
|
||||
)
|
||||
if pollErr != nil {
|
||||
return nil, fmt.Errorf("task execution failed: %w", pollErr)
|
||||
}
|
||||
|
||||
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
||||
adapted := &mcp.CallToolResult{
|
||||
Content: final.Content,
|
||||
StructuredContent: final.StructuredContent,
|
||||
IsError: final.IsError,
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(adapted)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
IsError: final.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveTaskMode resolves the effective task mode for a given server.
|
||||
// Programmatic overrides via SetTaskConfig take precedence over the
|
||||
// per-server TasksMode in MCPServerConfig. Empty / unknown values map to
|
||||
// MCPTaskModeAuto.
|
||||
func (m *MCPToolManager) resolveTaskMode(name string, cfg config.MCPServerConfig) MCPTaskMode {
|
||||
if m.taskCfg.PerServerMode != nil {
|
||||
if v, ok := m.taskCfg.PerServerMode[name]; ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ParseTaskMode(cfg.TasksMode)
|
||||
}
|
||||
|
||||
// ListServerTasks queries tasks/list on the named server and returns the
|
||||
// active and recent tasks the server is willing to disclose. Errors are
|
||||
// returned untouched (callers commonly ignore METHOD_NOT_FOUND when the
|
||||
// server didn't advertise tasks/list capability).
|
||||
func (m *MCPToolManager) ListServerTasks(ctx context.Context, serverName string) ([]MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := c.ListTasks(ctx, mcp.ListTasksRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tasks/list on %s: %w", serverName, err)
|
||||
}
|
||||
out := make([]MCPTaskInfo, 0, len(res.Tasks))
|
||||
for _, t := range res.Tasks {
|
||||
out = append(out, taskFromMCP(serverName, t))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetServerTask queries tasks/get for a single task on the named server.
|
||||
func (m *MCPToolManager) GetServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, err
|
||||
}
|
||||
res, err := c.GetTask(ctx, mcp.GetTaskRequest{Params: mcp.GetTaskParams{TaskId: taskID}})
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, fmt.Errorf("tasks/get on %s: %w", serverName, err)
|
||||
}
|
||||
return taskFromMCP(serverName, res.Task), nil
|
||||
}
|
||||
|
||||
// CancelServerTask issues tasks/cancel for a task on the named server.
|
||||
// Returns the post-cancel task state when the server responded with one.
|
||||
func (m *MCPToolManager) CancelServerTask(ctx context.Context, serverName, taskID string) (MCPTaskInfo, error) {
|
||||
c, err := m.taskClient(serverName)
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, err
|
||||
}
|
||||
res, err := c.CancelTask(ctx, mcp.CancelTaskRequest{Params: mcp.CancelTaskParams{TaskId: taskID}})
|
||||
if err != nil {
|
||||
return MCPTaskInfo{}, fmt.Errorf("tasks/cancel on %s: %w", serverName, err)
|
||||
}
|
||||
return taskFromMCP(serverName, res.Task), nil
|
||||
}
|
||||
|
||||
// taskClient returns the *client.Client for a server. Tasks endpoints are
|
||||
// not part of the upstream MCPClient interface so callers must work with
|
||||
// the concrete client. Returns an error when the connection is missing
|
||||
// or backed by a non-standard client type.
|
||||
func (m *MCPToolManager) taskClient(serverName string) (*client.Client, error) {
|
||||
if m.connectionPool == nil {
|
||||
return nil, fmt.Errorf("no connection pool available")
|
||||
}
|
||||
clients := m.connectionPool.GetClients()
|
||||
raw, ok := clients[serverName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("MCP server %q not loaded", serverName)
|
||||
}
|
||||
c, ok := raw.(*client.Client)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("MCP server %q does not support task RPCs", serverName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetTools returns all loaded MCP tools from all configured MCP servers.
|
||||
// Tools are returned with their prefixed names (serverName__toolName) to ensure uniqueness.
|
||||
func (m *MCPToolManager) GetTools() []MCPTool {
|
||||
|
||||
@@ -0,0 +1,404 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// MCPTaskMode controls when the connection pool augments tools/call requests
|
||||
// with MCP task metadata. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks.
|
||||
type MCPTaskMode string
|
||||
|
||||
const (
|
||||
// MCPTaskModeAuto augments tools/call with task metadata only when the
|
||||
// server advertises tasks/toolCalls capability during initialize.
|
||||
MCPTaskModeAuto MCPTaskMode = "auto"
|
||||
// MCPTaskModeNever forces every tools/call to be issued synchronously
|
||||
// (no Task field in the request), regardless of server capability.
|
||||
MCPTaskModeNever MCPTaskMode = "never"
|
||||
// MCPTaskModeAlways always sets a Task field on the tools/call request,
|
||||
// even when the server didn't advertise task support. The server may
|
||||
// still respond synchronously; this just opts in unconditionally on
|
||||
// the client side.
|
||||
MCPTaskModeAlways MCPTaskMode = "always"
|
||||
)
|
||||
|
||||
// ParseTaskMode normalises a per-server tasks-mode string from
|
||||
// configuration. Empty input maps to MCPTaskModeAuto. Unknown values are
|
||||
// also treated as MCPTaskModeAuto so a stray config typo never breaks
|
||||
// existing flows.
|
||||
func ParseTaskMode(s string) MCPTaskMode {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "", "auto":
|
||||
return MCPTaskModeAuto
|
||||
case "never", "off", "disabled":
|
||||
return MCPTaskModeNever
|
||||
case "always", "force":
|
||||
return MCPTaskModeAlways
|
||||
default:
|
||||
return MCPTaskModeAuto
|
||||
}
|
||||
}
|
||||
|
||||
// MCPTaskInfo is the connection-layer view of an MCP Task. It mirrors the
|
||||
// upstream mcp.Task but exposes Go-native types and includes the originating
|
||||
// server name. SDK-level wrappers re-export this under public-facing names.
|
||||
type MCPTaskInfo struct {
|
||||
// Server is the configured MCP server name this task lives on.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the task.
|
||||
TaskID string
|
||||
// Status is the current task lifecycle state.
|
||||
Status mcp.TaskStatus
|
||||
// StatusMessage is an optional human-readable description.
|
||||
StatusMessage string
|
||||
// CreatedAt is the wall-clock time the task was created (best-effort
|
||||
// parsed from the server's ISO-8601 timestamp; zero on parse failure).
|
||||
CreatedAt time.Time
|
||||
// UpdatedAt is the wall-clock time the task was last updated (best-
|
||||
// effort parsed; zero on parse failure).
|
||||
UpdatedAt time.Time
|
||||
// TTL is the time-to-live the server intends to retain the task after
|
||||
// creation. Zero means the server did not advertise a TTL.
|
||||
TTL time.Duration
|
||||
// PollInterval is the suggested polling interval. Zero means use the
|
||||
// client's default.
|
||||
PollInterval time.Duration
|
||||
}
|
||||
|
||||
// MCPTaskProgress is emitted while the connection pool is waiting on a
|
||||
// task-augmented tool call. It provides minimal feedback for SDK consumers
|
||||
// that want to render progress widgets without subscribing to the full
|
||||
// notifications/tasks/status channel (Phase 2).
|
||||
type MCPTaskProgress struct {
|
||||
Server string
|
||||
TaskID string
|
||||
Status mcp.TaskStatus
|
||||
Message string
|
||||
}
|
||||
|
||||
// MCPTaskProgressHandler is invoked once after a task is accepted and on
|
||||
// every status transition observed by the polling loop. The final
|
||||
// invocation always carries a terminal status. Implementations must not
|
||||
// block; long work should be queued on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// MCPTaskConfig configures task-aware tool execution on the manager.
|
||||
// All fields are optional; the zero value disables progress callbacks and
|
||||
// applies sensible defaults.
|
||||
type MCPTaskConfig struct {
|
||||
// PerServerMode overrides the per-server TasksMode resolved from
|
||||
// MCPServerConfig. Keys are server names. Missing entries fall back
|
||||
// to the value from config. Used by SDK consumers that want to set
|
||||
// modes programmatically.
|
||||
PerServerMode map[string]MCPTaskMode
|
||||
|
||||
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
|
||||
// tools/call. Zero means omit the TTL — let the server pick its own.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// PollInterval is the fallback interval between tasks/get requests
|
||||
// when the server does not suggest one. Zero defaults to 1 second.
|
||||
PollInterval time.Duration
|
||||
|
||||
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
|
||||
MaxPollInterval time.Duration
|
||||
|
||||
// Timeout is the maximum wall-clock duration to wait for a task to
|
||||
// reach a terminal state. Zero defaults to 15 minutes. Independent
|
||||
// of the per-call context deadline; whichever fires first wins.
|
||||
Timeout time.Duration
|
||||
|
||||
// Progress, if non-nil, receives every status transition observed by
|
||||
// the polling loop.
|
||||
Progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
func (c MCPTaskConfig) resolved() MCPTaskConfig {
|
||||
if c.PollInterval <= 0 {
|
||||
c.PollInterval = 1 * time.Second
|
||||
}
|
||||
if c.MaxPollInterval <= 0 {
|
||||
c.MaxPollInterval = 5 * time.Second
|
||||
}
|
||||
if c.Timeout <= 0 {
|
||||
c.Timeout = 15 * time.Minute
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// requestIDCounter generates monotonically increasing JSON-RPC request IDs
|
||||
// for low-level tools/call invocations that bypass the upstream client's
|
||||
// ParseCallToolResult helper (necessary because that helper rejects task
|
||||
// responses for lacking a "content" field).
|
||||
//
|
||||
// The counter is process-wide rather than per-manager so multiple managers
|
||||
// or repeated calls within the same connection produce unique IDs.
|
||||
var requestIDCounter atomic.Int64
|
||||
|
||||
func nextRequestID() mcp.RequestId {
|
||||
return mcp.NewRequestId(requestIDCounter.Add(1))
|
||||
}
|
||||
|
||||
// callToolWithTask issues tools/call directly on the transport so we can
|
||||
// observe both response shapes:
|
||||
//
|
||||
// - {"content": [...], ...} — synchronous CallToolResult.
|
||||
// - {"task": {...}, ...} — asynchronous CreateTaskResult.
|
||||
//
|
||||
// On success exactly one of (callResult, taskResult) is non-nil. The
|
||||
// upstream client.CallTool helper parses the response with
|
||||
// mcp.ParseCallToolResult which requires a "content" field, so it cannot
|
||||
// be used for task-augmented calls.
|
||||
func callToolWithTask(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
params mcp.CallToolParams,
|
||||
) (callResult *mcp.CallToolResult, taskResult *mcp.CreateTaskResult, err error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodToolsCall),
|
||||
Params: params,
|
||||
}
|
||||
|
||||
resp, sendErr := tr.SendRequest(ctx, req)
|
||||
if sendErr != nil {
|
||||
return nil, nil, sendErr
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Peek at the raw result to decide which shape we got.
|
||||
var probe struct {
|
||||
Task json.RawMessage `json:"task"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
}
|
||||
raw := resp.Result
|
||||
if len(raw) == 0 {
|
||||
return nil, nil, errors.New("empty tools/call result")
|
||||
}
|
||||
if uErr := json.Unmarshal(raw, &probe); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode tools/call result: %w", uErr)
|
||||
}
|
||||
|
||||
if len(probe.Task) > 0 && string(probe.Task) != "null" {
|
||||
// Task-augmented response.
|
||||
var ct mcp.CreateTaskResult
|
||||
if uErr := json.Unmarshal(raw, &ct); uErr != nil {
|
||||
return nil, nil, fmt.Errorf("decode CreateTaskResult: %w", uErr)
|
||||
}
|
||||
return nil, &ct, nil
|
||||
}
|
||||
|
||||
// Synchronous response — defer to the upstream parser so content blocks
|
||||
// are typed correctly (TextContent, ImageContent, ResourceLink, etc.).
|
||||
cr, pErr := mcp.ParseCallToolResult(&raw)
|
||||
if pErr != nil {
|
||||
return nil, nil, fmt.Errorf("parse CallToolResult: %w", pErr)
|
||||
}
|
||||
return cr, nil, nil
|
||||
}
|
||||
|
||||
// pollTaskUntilTerminal blocks until the task reaches a terminal status,
|
||||
// the context is cancelled, or the configured timeout elapses. On
|
||||
// cancellation it best-effort issues tasks/cancel before returning.
|
||||
func pollTaskUntilTerminal(
|
||||
ctx context.Context,
|
||||
c *client.Client,
|
||||
serverName string,
|
||||
task mcp.Task,
|
||||
cfg MCPTaskConfig,
|
||||
progress MCPTaskProgressHandler,
|
||||
) (*mcp.TaskResultResult, error) {
|
||||
cfg = cfg.resolved()
|
||||
deadline := time.Now().Add(cfg.Timeout)
|
||||
|
||||
emit := func(status mcp.TaskStatus, msg string) {
|
||||
if progress != nil {
|
||||
progress(MCPTaskProgress{Server: serverName, TaskID: task.TaskId, Status: status, Message: msg})
|
||||
}
|
||||
}
|
||||
|
||||
emit(task.Status, task.StatusMessage)
|
||||
|
||||
current := task
|
||||
interval := cfg.PollInterval
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = time.Duration(*current.PollInterval) * time.Millisecond
|
||||
}
|
||||
if interval > cfg.MaxPollInterval {
|
||||
interval = cfg.MaxPollInterval
|
||||
}
|
||||
|
||||
for !current.Status.IsTerminal() {
|
||||
if time.Now().After(deadline) {
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, fmt.Errorf("task %s timed out after %s", current.TaskId, cfg.Timeout)
|
||||
}
|
||||
|
||||
// Wait between polls or abort early on context cancellation.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
cancelTaskBestEffort(c, current.TaskId)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
}
|
||||
|
||||
got, err := c.GetTask(ctx, mcp.GetTaskRequest{
|
||||
Params: mcp.GetTaskParams{TaskId: current.TaskId},
|
||||
})
|
||||
if err != nil {
|
||||
// Transient transport hiccup — propagate immediately. The
|
||||
// upstream agent layer treats this like any other tool error.
|
||||
return nil, fmt.Errorf("tasks/get failed: %w", err)
|
||||
}
|
||||
current = got.Task
|
||||
if current.Status != task.Status || current.StatusMessage != task.StatusMessage {
|
||||
emit(current.Status, current.StatusMessage)
|
||||
task = current
|
||||
}
|
||||
|
||||
// Honour any updated suggested poll interval, capped at the limit.
|
||||
if current.PollInterval != nil && *current.PollInterval > 0 {
|
||||
interval = min(time.Duration(*current.PollInterval)*time.Millisecond, cfg.MaxPollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// Terminal state reached. Emit one last progress event and fetch the
|
||||
// definitive tool result.
|
||||
emit(current.Status, current.StatusMessage)
|
||||
|
||||
if current.Status == mcp.TaskStatusCancelled {
|
||||
return nil, fmt.Errorf("task %s was cancelled", current.TaskId)
|
||||
}
|
||||
|
||||
res, err := fetchTaskResult(ctx, c, current.TaskId)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tasks/result failed: %w", err)
|
||||
}
|
||||
if current.Status == mcp.TaskStatusFailed && res != nil && !res.IsError {
|
||||
// The server flagged the task as failed but didn't decorate the
|
||||
// result. Surface the status message so the caller still sees a
|
||||
// useful tool-error.
|
||||
return nil, fmt.Errorf("task %s failed: %s", current.TaskId, current.StatusMessage)
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// fetchTaskResult issues tasks/result on the transport and parses the raw
|
||||
// response. The upstream client.TaskResult helper delegates to
|
||||
// mcp.ParseTaskResultResult which (as of mcp-go v0.51.0) looks for the
|
||||
// content array under a nested "result" key that never exists in the
|
||||
// wire format — leading to systematically empty Content. Doing the
|
||||
// parse here keeps the polling path working until that is fixed upstream.
|
||||
func fetchTaskResult(ctx context.Context, c *client.Client, taskID string) (*mcp.TaskResultResult, error) {
|
||||
tr := c.GetTransport()
|
||||
if tr == nil {
|
||||
return nil, errors.New("mcp client has no transport")
|
||||
}
|
||||
req := transport.JSONRPCRequest{
|
||||
JSONRPC: mcp.JSONRPC_VERSION,
|
||||
ID: nextRequestID(),
|
||||
Method: string(mcp.MethodTasksResult),
|
||||
Params: mcp.TaskResultParams{TaskId: taskID},
|
||||
}
|
||||
resp, err := tr.SendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Error != nil {
|
||||
return nil, resp.Error.AsError()
|
||||
}
|
||||
|
||||
// Manually decode the wire shape: {"_meta": {...}, "content": [...],
|
||||
// "structuredContent": ..., "isError": bool}.
|
||||
var shape struct {
|
||||
Meta json.RawMessage `json:"_meta"`
|
||||
Content []json.RawMessage `json:"content"`
|
||||
StructuredContent any `json:"structuredContent"`
|
||||
IsError bool `json:"isError"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Result, &shape); err != nil {
|
||||
return nil, fmt.Errorf("decode tasks/result: %w", err)
|
||||
}
|
||||
|
||||
out := &mcp.TaskResultResult{
|
||||
StructuredContent: shape.StructuredContent,
|
||||
IsError: shape.IsError,
|
||||
}
|
||||
if len(shape.Meta) > 0 && string(shape.Meta) != "null" {
|
||||
var metaMap map[string]any
|
||||
if err := json.Unmarshal(shape.Meta, &metaMap); err == nil {
|
||||
out.Meta = mcp.NewMetaFromMap(metaMap)
|
||||
}
|
||||
}
|
||||
for _, raw := range shape.Content {
|
||||
var contentMap map[string]any
|
||||
if err := json.Unmarshal(raw, &contentMap); err != nil {
|
||||
return nil, fmt.Errorf("decode content block: %w", err)
|
||||
}
|
||||
parsed, err := mcp.ParseContent(contentMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse content block: %w", err)
|
||||
}
|
||||
out.Content = append(out.Content, parsed)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// cancelTaskBestEffort issues tasks/cancel and ignores any error. Used on
|
||||
// context cancellation paths where the connection is already going away.
|
||||
func cancelTaskBestEffort(c *client.Client, taskID string) {
|
||||
if c == nil || taskID == "" {
|
||||
return
|
||||
}
|
||||
cancelCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_, _ = c.CancelTask(cancelCtx, mcp.CancelTaskRequest{
|
||||
Params: mcp.CancelTaskParams{TaskId: taskID},
|
||||
})
|
||||
}
|
||||
|
||||
// taskFromMCP converts a wire-format mcp.Task to our richer connection-
|
||||
// layer view. Unparseable timestamps surface as the zero time.
|
||||
func taskFromMCP(serverName string, t mcp.Task) MCPTaskInfo {
|
||||
out := MCPTaskInfo{
|
||||
Server: serverName,
|
||||
TaskID: t.TaskId,
|
||||
Status: t.Status,
|
||||
StatusMessage: t.StatusMessage,
|
||||
}
|
||||
if t.CreatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.CreatedAt); err == nil {
|
||||
out.CreatedAt = v
|
||||
}
|
||||
}
|
||||
if t.LastUpdatedAt != "" {
|
||||
if v, err := time.Parse(time.RFC3339, t.LastUpdatedAt); err == nil {
|
||||
out.UpdatedAt = v
|
||||
}
|
||||
}
|
||||
if t.TTL != nil {
|
||||
out.TTL = time.Duration(*t.TTL) * time.Millisecond
|
||||
}
|
||||
if t.PollInterval != nil {
|
||||
out.PollInterval = time.Duration(*t.PollInterval) * time.Millisecond
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,294 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTaskTestInProcessServer builds an in-process MCP server with a
|
||||
// task-augmented tool. The handler simulates work by sleeping briefly
|
||||
// before completing.
|
||||
//
|
||||
// Important: the upstream mcp-go server cancels the request context as
|
||||
// soon as the synchronous part of the tools/call returns (see
|
||||
// request_handler.go:85, `defer cancel()`). Task goroutines spawned by
|
||||
// AddTaskTool inherit that context and therefore see context.Canceled
|
||||
// the instant they start. Real-world transports (stdio, SSE, streamable
|
||||
// HTTP) don't trip this because they keep the connection — and a
|
||||
// background context — alive across the async work, but the in-process
|
||||
// transport runs entirely on the request goroutine. To test the polling
|
||||
// path realistically we detach from the request context here.
|
||||
func newTaskTestInProcessServer(t *testing.T, workDuration time.Duration) *server.MCPServer {
|
||||
t.Helper()
|
||||
srv := server.NewMCPServer("task-test", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
// list=true, cancel=true, toolCallTasks=true so capability detection,
|
||||
// cancellation, and tool augmentation all flow through.
|
||||
server.WithTaskCapabilities(true, true, true),
|
||||
)
|
||||
srv.AddTaskTool(
|
||||
mcp.Tool{
|
||||
Name: "long_running",
|
||||
Description: "Sleep, then echo the input string.",
|
||||
InputSchema: mcp.ToolInputSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]any{
|
||||
"msg": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
Execution: &mcp.ToolExecution{
|
||||
TaskSupport: mcp.TaskSupportRequired,
|
||||
},
|
||||
},
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CreateTaskResult, error) {
|
||||
msg, _ := req.GetArguments()["msg"].(string)
|
||||
// Detach from the request context so the task handler can
|
||||
// outlive the synchronous request — see comment above.
|
||||
time.Sleep(workDuration)
|
||||
_ = ctx
|
||||
return &mcp.CreateTaskResult{
|
||||
Content: []mcp.Content{
|
||||
mcp.TextContent{Type: "text", Text: "echo:" + msg},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
// newSyncOnlyServer is a server that does NOT advertise task capability.
|
||||
// Used to verify the auto-detect path keeps the sync semantics.
|
||||
func newSyncOnlyServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("sync-only", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required()),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("hi " + name), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestConnectionPoolAdvertisesTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
srv := newTaskTestInProcessServer(t, 0)
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: srv}
|
||||
|
||||
conn, err := pool.GetConnection(context.Background(), "tasks", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
|
||||
init := conn.InitializeResult()
|
||||
if init == nil {
|
||||
t.Fatal("InitializeResult is nil after GetConnection")
|
||||
}
|
||||
if init.Capabilities.Tasks == nil {
|
||||
t.Fatal("server did not advertise Tasks capability — initialize handshake regressed")
|
||||
}
|
||||
if !conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be true for a server with toolCallTasks=true")
|
||||
}
|
||||
if !pool.ServerSupportsToolTasks("tasks") {
|
||||
t.Error("ServerSupportsToolTasks should mirror the connection's value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolDetectsAbsentTaskCapability(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
cfg := config.MCPServerConfig{Type: "inprocess", InProcessServer: newSyncOnlyServer()}
|
||||
conn, err := pool.GetConnection(context.Background(), "sync", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection: %v", err)
|
||||
}
|
||||
if conn.SupportsToolTasks() {
|
||||
t.Error("SupportsToolTasks should be false for a server that didn't advertise the capability")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportsToolTasksFromInit(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in *mcp.InitializeResult
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"no tasks", &mcp.InitializeResult{}, false},
|
||||
{"tasks no requests", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: &mcp.TasksCapability{}},
|
||||
}, false},
|
||||
{"tasks with toolCalls", &mcp.InitializeResult{
|
||||
Capabilities: mcp.ServerCapabilities{Tasks: mcp.NewTasksCapability()},
|
||||
}, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := supportsToolTasksFromInit(tc.in); got != tc.want {
|
||||
t.Errorf("supportsToolTasksFromInit() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTaskMode(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want MCPTaskMode
|
||||
}{
|
||||
{"", MCPTaskModeAuto},
|
||||
{"auto", MCPTaskModeAuto},
|
||||
{"AUTO", MCPTaskModeAuto},
|
||||
{"never", MCPTaskModeNever},
|
||||
{"off", MCPTaskModeNever},
|
||||
{"always", MCPTaskModeAlways},
|
||||
{"force", MCPTaskModeAlways},
|
||||
{"bogus", MCPTaskModeAuto},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := ParseTaskMode(tc.in); got != tc.want {
|
||||
t.Errorf("ParseTaskMode(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolPollsTaskToCompletion(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 20 * time.Millisecond,
|
||||
MaxPollInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 50*time.Millisecond),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
res, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hello"}`)
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if res.IsError {
|
||||
t.Fatalf("expected non-error result, got %s", res.Content)
|
||||
}
|
||||
if !strings.Contains(res.Content, "echo:hello") {
|
||||
t.Errorf("expected result to contain 'echo:hello', got %s", res.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolHonorsNeverMode(t *testing.T) {
|
||||
// Even though the server advertises tasks/toolCalls, "never" should
|
||||
// keep the call synchronous. Since the tool is TaskSupportRequired,
|
||||
// the server returns an error rather than running it sync — we just
|
||||
// verify the error surfaces (not a poll-loop hang).
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PerServerMode: map[string]MCPTaskMode{"tasks": MCPTaskModeNever},
|
||||
Timeout: 2 * time.Second,
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// We don't care which way the server fails the sync call; we just want
|
||||
// to confirm we didn't hang in the polling loop and didn't panic.
|
||||
_, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"x"}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected an error when forcing sync execution of a task-required tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolEmitsProgress(t *testing.T) {
|
||||
var statuses []mcp.TaskStatus
|
||||
mgr := NewMCPToolManager()
|
||||
mgr.SetTaskConfig(MCPTaskConfig{
|
||||
PollInterval: 10 * time.Millisecond,
|
||||
MaxPollInterval: 25 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
Progress: func(p MCPTaskProgress) {
|
||||
statuses = append(statuses, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 30*time.Millisecond),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if _, err := mgr.ExecuteTool(ctx, "tasks__long_running", `{"msg":"hi"}`); err != nil {
|
||||
t.Fatalf("ExecuteTool: %v", err)
|
||||
}
|
||||
if len(statuses) == 0 {
|
||||
t.Fatal("expected at least one progress event")
|
||||
}
|
||||
last := statuses[len(statuses)-1]
|
||||
if !last.IsTerminal() {
|
||||
t.Errorf("last progress event should be terminal, got %q", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListGetCancelMCPTasksOnLoadedServer(t *testing.T) {
|
||||
mgr := NewMCPToolManager()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTaskTestInProcessServer(t, 0),
|
||||
}
|
||||
if _, err := mgr.AddServer(context.Background(), "tasks", cfg); err != nil {
|
||||
t.Fatalf("AddServer: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// tasks/list — no in-flight tasks yet, so we just verify the call
|
||||
// succeeds and returns an empty slice (or any slice; the exact length
|
||||
// depends on server retention policy).
|
||||
if _, err := mgr.ListServerTasks(ctx, "tasks"); err != nil {
|
||||
t.Errorf("ListServerTasks: %v", err)
|
||||
}
|
||||
|
||||
// Unknown server should error cleanly without panicking.
|
||||
if _, err := mgr.GetServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("GetServerTask on unknown server should error")
|
||||
}
|
||||
if _, err := mgr.CancelServerTask(ctx, "unknown", "abc"); err == nil {
|
||||
t.Error("CancelServerTask on unknown server should error")
|
||||
}
|
||||
}
|
||||
@@ -28,15 +28,6 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoBorder returns a renderingOption that disables all borders on the
|
||||
// block, rendering content with only padding.
|
||||
func WithNoBorder() renderingOption {
|
||||
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackground returns a renderingOption that sets the background color
|
||||
// for the entire block. The color parameter accepts any color.Color value,
|
||||
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
|
||||
func WithBackground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.background = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
}
|
||||
}
|
||||
|
||||
// renderContentBlock renders content with configurable styling options
|
||||
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
|
||||
renderer := &blockRenderer{
|
||||
|
||||
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
|
||||
@@ -161,6 +161,12 @@ var SlashCommands = []SlashCommand{
|
||||
Category: "Navigation",
|
||||
Aliases: []string{"/r"},
|
||||
},
|
||||
{
|
||||
Name: "/copy",
|
||||
Description: "Copy the last message to the system clipboard",
|
||||
Category: "System",
|
||||
Aliases: []string{"/cp"},
|
||||
},
|
||||
{
|
||||
Name: "/export",
|
||||
Description: "Export session (JSONL by default, or /export path.jsonl)",
|
||||
@@ -199,18 +205,6 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
names = append(names, cmd.Name)
|
||||
names = append(names, cmd.Aliases...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ExtensionCommand is a slash command registered by an extension. Unlike
|
||||
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
|
||||
// extension commands carry their own Execute callback.
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
}
|
||||
|
||||
// Format the message to include all the debug info in a structured way
|
||||
var formattedMessage string
|
||||
|
||||
// Check if this is a multi-line debug output (like connection info)
|
||||
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
|
||||
// Extract the tag and content
|
||||
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
|
||||
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Add appropriate emoji based on the message content
|
||||
if strings.Contains(content, "Creating new connection") {
|
||||
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
|
||||
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
|
||||
} else if strings.Contains(content, "Reusing") {
|
||||
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
|
||||
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
|
||||
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
|
||||
} else if strings.Contains(content, "closed") {
|
||||
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Failed to close") {
|
||||
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
|
||||
} else {
|
||||
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
|
||||
// Use the CLI's debug message rendering
|
||||
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
@@ -25,17 +25,6 @@ type TextMessageItem struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
@@ -316,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -88,13 +88,9 @@ func formatToolParams(toolArgs string, maxWidth int) string {
|
||||
}
|
||||
|
||||
bodyKeys := map[string]bool{
|
||||
"content": true,
|
||||
"old_text": true,
|
||||
"new_text": true,
|
||||
"oldText": true,
|
||||
"newText": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
"content": true,
|
||||
"edits": true,
|
||||
"todos": true,
|
||||
}
|
||||
var remaining []string
|
||||
for key, val := range params {
|
||||
|
||||
+225
-15
@@ -129,8 +129,18 @@ type AppController interface {
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
// [Skills] section. Built by the CLI layer from the SDK's []*kit.Skill.
|
||||
type SkillItem struct {
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Source string // "project" or "user" (global).
|
||||
Description string // Short summary used in autocomplete and help.
|
||||
}
|
||||
|
||||
// ExtensionItem holds display metadata about a loaded extension for the
|
||||
// startup [Extensions] section. Built by the CLI layer from the SDK's
|
||||
// []kit.ExtensionInfo.
|
||||
type ExtensionItem struct {
|
||||
Name string // Extension display name (filename without .go extension).
|
||||
Path string // Absolute path to the extension's .go file.
|
||||
Source string // "project" or "user" (global).
|
||||
}
|
||||
|
||||
@@ -363,6 +373,16 @@ type AppModelOptions struct {
|
||||
// watcher detects changes. May be nil if skill hot-reload is not needed.
|
||||
GetSkillItems func() []SkillItem
|
||||
|
||||
// ExtensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section. Each entry shows the filename of an extension that was
|
||||
// discovered and loaded (global, project-local, or explicit).
|
||||
ExtensionItems []ExtensionItem
|
||||
|
||||
// GetExtensionItems, if non-nil, returns the current extension items.
|
||||
// Called on extension hot-reload to refresh the list. May be nil if no
|
||||
// extensions are loaded.
|
||||
GetExtensionItems func() []ExtensionItem
|
||||
|
||||
// MCPToolCount is the number of tools loaded from external MCP servers.
|
||||
MCPToolCount int
|
||||
|
||||
@@ -607,6 +627,14 @@ type AppModel struct {
|
||||
// skill list after content hot-reload. May be nil.
|
||||
getSkillItems func() []SkillItem
|
||||
|
||||
// extensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section (filenames only).
|
||||
extensionItems []ExtensionItem
|
||||
|
||||
// getExtensionItems returns the current extension items. Used to refresh
|
||||
// the list after extension hot-reload. May be nil.
|
||||
getExtensionItems func() []ExtensionItem
|
||||
|
||||
// mcpToolCount and extensionToolCount track tool counts by source for
|
||||
// the startup info display.
|
||||
mcpToolCount int
|
||||
@@ -860,6 +888,8 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
m.contextPaths = opts.ContextPaths
|
||||
m.skillItems = opts.SkillItems
|
||||
m.getSkillItems = opts.GetSkillItems
|
||||
m.extensionItems = opts.ExtensionItems
|
||||
m.getExtensionItems = opts.GetExtensionItems
|
||||
m.mcpToolCount = opts.MCPToolCount
|
||||
m.extensionToolCount = opts.ExtensionToolCount
|
||||
m.startupExtensionMessages = opts.StartupExtensionMessages
|
||||
@@ -912,6 +942,20 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
}
|
||||
}
|
||||
|
||||
// Merge skills into autocomplete as /skill:<name> commands. Skills accept
|
||||
// optional trailing args, so HasArgs is true — Enter populates the input
|
||||
// with "/skill:name " rather than auto-submitting.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.SkillItems) > 0 {
|
||||
for _, s := range opts.SkillItems {
|
||||
ic.commands = append(ic.commands, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Merge MCP prompts into autocomplete as /<server>:<prompt> commands.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.MCPPrompts) > 0 {
|
||||
for _, p := range opts.MCPPrompts {
|
||||
@@ -1014,8 +1058,21 @@ func (m *AppModel) AddStartupMessageToScrollList() {
|
||||
pairs = append(pairs, [2]string{"Skills", strings.Join(names, ", ")})
|
||||
}
|
||||
|
||||
// Extension tool count (only shown when > 0).
|
||||
if m.extensionToolCount > 0 {
|
||||
// Extensions — listed by filename. Each extension shows its basename
|
||||
// without the .go suffix, matching the [Skills] section's style.
|
||||
if len(m.extensionItems) > 0 {
|
||||
names := make([]string, len(m.extensionItems))
|
||||
for i, ei := range m.extensionItems {
|
||||
names[i] = ei.Name
|
||||
}
|
||||
value := strings.Join(names, ", ")
|
||||
if m.extensionToolCount > 0 {
|
||||
value += fmt.Sprintf(" (%d tools)", m.extensionToolCount)
|
||||
}
|
||||
pairs = append(pairs, [2]string{"Extensions", value})
|
||||
} else if m.extensionToolCount > 0 {
|
||||
// Fallback: tool count only (extensions registered tools but the CLI
|
||||
// did not provide ExtensionItems for some reason).
|
||||
pairs = append(pairs, [2]string{"Extensions", fmt.Sprintf("%d tools", m.extensionToolCount)})
|
||||
}
|
||||
|
||||
@@ -1251,7 +1308,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.scrollList.autoScroll = false
|
||||
case tea.MouseWheelDown:
|
||||
m.scrollList.ScrollBy(scrollLines)
|
||||
if m.scrollList.AtBottom() {
|
||||
// Only re-enable auto-scroll when the user is not actively
|
||||
// selecting text. Otherwise a wheel-down during a drag-select
|
||||
// would re-arm GotoBottom on the next stream chunk, shifting
|
||||
// the highlighted row out from under the cursor.
|
||||
if m.scrollList.AtBottom() && !m.scrollList.IsMouseDown() {
|
||||
m.scrollList.autoScroll = true
|
||||
}
|
||||
}
|
||||
@@ -1259,9 +1320,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// ── Mouse click selection (crush-style character-level) ──────────────────
|
||||
case tea.MouseClickMsg:
|
||||
if msg.Button == tea.MouseLeft {
|
||||
// Calculate viewport-relative coordinates.
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
// Compute the scrollback origin from the current frame's layout
|
||||
// rather than the stale cached value from the previous View().
|
||||
// scrollbackYOffset/scrollList.height are only refreshed inside
|
||||
// View() and lag behind any state change that resized the header
|
||||
// (extension widgets, warning rows, etc.) since the last render.
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
// Clear any previous selection on a new click.
|
||||
// HandleMouseDown will set up new selection state.
|
||||
if m.scrollList.HandleMouseDown(msg.X, viewY) {
|
||||
@@ -1272,8 +1338,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Mouse motion/drag for character-level selection ──────────────────────
|
||||
case tea.MouseMotionMsg:
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
m.scrollList.HandleMouseDrag(msg.X, viewY)
|
||||
}
|
||||
|
||||
@@ -1603,10 +1670,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Cancel timer expired ─────────────────────────────────────────────────
|
||||
case uicore.CancelTimerExpiredMsg:
|
||||
if m.canceling {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.canceling = false
|
||||
|
||||
// ── Ctrl+C reset timer expired ────────────────────────────────────────────
|
||||
case uicore.CtrlCResetMsg:
|
||||
if m.ctrlCPressedOnce {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.ctrlCPressedOnce = false
|
||||
|
||||
// ── Input submitted ──────────────────────────────────────────────────────
|
||||
@@ -2328,6 +2401,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
} else {
|
||||
m.refreshExtensionItems()
|
||||
m.printSystemMessage("Extensions reloaded.")
|
||||
}
|
||||
|
||||
@@ -3095,6 +3169,8 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
|
||||
return m.handleResumeCommand()
|
||||
case "/export":
|
||||
return m.handleExportCommand(args)
|
||||
case "/copy":
|
||||
return m.handleCopyCommand()
|
||||
case "/share":
|
||||
return m.handleShareCommand()
|
||||
case "/import":
|
||||
@@ -3395,13 +3471,56 @@ func (m *AppModel) refreshPromptTemplates() {
|
||||
}
|
||||
}
|
||||
|
||||
// refreshSkillItems reloads skill items from the provider callback.
|
||||
// Called on ContentReloadEvent.
|
||||
// refreshSkillItems reloads skill items from the provider callback and
|
||||
// updates the autocomplete entries. Called on ContentReloadEvent.
|
||||
func (m *AppModel) refreshSkillItems() {
|
||||
if m.getSkillItems == nil {
|
||||
return
|
||||
}
|
||||
m.skillItems = m.getSkillItems()
|
||||
newItems := m.getSkillItems()
|
||||
m.skillItems = newItems
|
||||
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
// Remove old Skills commands and add fresh ones.
|
||||
var kept []commands.SlashCommand
|
||||
for _, sc := range ic.commands {
|
||||
if sc.Category != "Skills" {
|
||||
kept = append(kept, sc)
|
||||
}
|
||||
}
|
||||
for _, s := range newItems {
|
||||
kept = append(kept, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
ic.commands = kept
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExtensionItems reloads extension items from the provider callback
|
||||
// so the [Extensions] startup section reflects the current set after a
|
||||
// hot-reload. Called from the extReloadResultMsg handler.
|
||||
func (m *AppModel) refreshExtensionItems() {
|
||||
if m.getExtensionItems == nil {
|
||||
return
|
||||
}
|
||||
m.extensionItems = m.getExtensionItems()
|
||||
}
|
||||
|
||||
// formatSkillDescription returns the autocomplete description for a skill,
|
||||
// prefixed with [project] or [user] so users can tell colliding names apart.
|
||||
func formatSkillDescription(s SkillItem) string {
|
||||
prefix := "[user]"
|
||||
if s.Source == "project" {
|
||||
prefix = "[project]"
|
||||
}
|
||||
if s.Description == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + " " + s.Description
|
||||
}
|
||||
|
||||
// refreshMCPPrompts reloads MCP prompts from the provider callback and
|
||||
@@ -3476,6 +3595,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**System:**\n" +
|
||||
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
|
||||
"- `/clear`: Clear message history\n" +
|
||||
"- `/copy`: Copy the last message to the system clipboard\n" +
|
||||
"- `/export [path]`: Export session as JSONL\n" +
|
||||
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
|
||||
"- `/reset-usage`: Reset usage statistics\n" +
|
||||
@@ -3712,7 +3832,12 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
}
|
||||
// Auto-scroll to bottom if enabled (iteratr pattern)
|
||||
// Don't call SetItems() - the slice reference hasn't changed
|
||||
if m.scrollList != nil {
|
||||
//
|
||||
// CRITICAL: never scroll the viewport while the user is actively
|
||||
// selecting text (mouse button held). Doing so shifts the
|
||||
// highlighted content out from under the cursor and produces the
|
||||
// off-by-N-row drift users see when copy-selecting during streaming.
|
||||
if m.scrollList != nil && !m.scrollList.IsMouseDown() {
|
||||
if m.scrollList.autoScroll {
|
||||
m.scrollList.GotoBottom()
|
||||
} else if m.scrollList.AtBottom() {
|
||||
@@ -3740,6 +3865,36 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
m.refreshContent()
|
||||
}
|
||||
|
||||
// currentScrollbackBounds returns the live (yOffset, viewportHeight) for the
|
||||
// scrollback region, computed from the current state — not from the cached
|
||||
// values populated inside View().
|
||||
//
|
||||
// scrollbackYOffset and scrollList.height are refreshed once per render, so
|
||||
// any state change that resizes the header (extension widget toggles,
|
||||
// warning rows, queued messages, etc.) leaves the cached values one frame
|
||||
// stale. Mouse click handlers in Update() can then place the cursor on the
|
||||
// wrong line, producing the off-by-N-row drift seen during copy-selection.
|
||||
//
|
||||
// This recomputes the header height by rendering it (cheap — the renderer
|
||||
// returns "" when no extension header is set) and recomputes the viewport
|
||||
// height the same way distributeHeight() does, so both inputs to the
|
||||
// y → (item, line) mapping are always current.
|
||||
func (m *AppModel) currentScrollbackBounds() (yOffset, viewportHeight int) {
|
||||
// Force a fresh layout if anything in Update() marked the state dirty;
|
||||
// otherwise scrollList.height still reflects the previous frame.
|
||||
if m.layoutDirty {
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = false
|
||||
}
|
||||
if headerView := m.renderHeaderFooter(m.getHeader); headerView != "" {
|
||||
yOffset = lipgloss.Height(headerView)
|
||||
}
|
||||
if m.scrollList != nil {
|
||||
viewportHeight = m.scrollList.height
|
||||
}
|
||||
return yOffset, viewportHeight
|
||||
}
|
||||
|
||||
// distributeHeight recalculates child component heights after a window resize,
|
||||
// queue change, widget update, or state transition, and propagates the computed
|
||||
// stream height to the StreamComponent.
|
||||
@@ -3812,7 +3967,20 @@ func (m *AppModel) distributeHeight() {
|
||||
headerFooterLines += lipgloss.Height(footerView)
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines, 0)
|
||||
// Account for transient warning rows that View() injects between the
|
||||
// scrollback and the separator. These flags are toggled by ESC/Ctrl+C
|
||||
// handlers; without subtracting them here the joined view exceeds
|
||||
// m.height by one line per active warning and the bottom of the screen
|
||||
// gets silently clipped — which in turn invalidates scrollbackYOffset.
|
||||
var warningLines int
|
||||
if m.canceling {
|
||||
warningLines++
|
||||
}
|
||||
if m.ctrlCPressedOnce {
|
||||
warningLines++
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines-warningLines, 0)
|
||||
|
||||
// In alt screen mode, give the calculated height to ScrollList instead of stream.
|
||||
// The stream component still exists but is embedded as the last item in scrollList.
|
||||
@@ -4236,6 +4404,48 @@ func (m *AppModel) handleNameCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCopyCommand copies the last user or assistant message to the system
|
||||
// clipboard. Skips transient system messages (e.g. /help output) so the user
|
||||
// gets the actual last conversational message.
|
||||
func (m *AppModel) handleCopyCommand() tea.Cmd {
|
||||
if len(m.messages) == 0 {
|
||||
m.printSystemMessage("No messages to copy.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
text string
|
||||
role string
|
||||
)
|
||||
for i := len(m.messages) - 1; i >= 0; i-- {
|
||||
switch msg := m.messages[i].(type) {
|
||||
case *TextMessageItem:
|
||||
if msg.role == "user" || msg.role == "assistant" {
|
||||
text = msg.content
|
||||
role = msg.role
|
||||
}
|
||||
case *StreamingMessageItem:
|
||||
if msg.role == "assistant" || msg.role == "reasoning" {
|
||||
text = msg.content.String()
|
||||
role = msg.role
|
||||
}
|
||||
}
|
||||
if text != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(text) == "" {
|
||||
m.printSystemMessage("No copyable message found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Copied last %s message to clipboard (%d chars).", role, len(text),
|
||||
))
|
||||
return clipboard.CopyToClipboard(text)
|
||||
}
|
||||
|
||||
// handleExportCommand exports the current session to a file.
|
||||
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
|
||||
//
|
||||
|
||||
@@ -19,28 +19,7 @@ import (
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
// Any @file tokens in the content are highlighted with the theme accent color.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
// Highlight @file tokens with accent color so file references are
|
||||
// visually distinct from surrounding prompt text.
|
||||
content = HighlightFileTokens(content, theme)
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
// UserBlock-related rendering helpers and herald typography.
|
||||
|
||||
// HighlightFileTokens wraps @file tokens in the given text with the theme
|
||||
// accent color so they stand out visually in rendered user messages.
|
||||
@@ -154,44 +133,6 @@ func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) strin
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ToolBlock renders a tool execution result with header and body.
|
||||
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
return styleMarginBottom(theme, fullContent)
|
||||
}
|
||||
|
||||
// styleMarginBottom applies a 1-line margin bottom using the theme.
|
||||
func styleMarginBottom(theme style.Theme, content string) string {
|
||||
return style.GetCachedStyles().MarginBottom1.Render(content)
|
||||
|
||||
@@ -4,30 +4,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// testTypography creates a herald Typography for tests.
|
||||
func testTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertLabel(herald.AlertTip, ""),
|
||||
herald.WithAlertIcon(herald.AlertTip, ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHighlightFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
|
||||
@@ -88,24 +67,25 @@ func TestHighlightFileTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserBlockHighlightsFileTokens(t *testing.T) {
|
||||
// TestHighlightFileTokensInjectsANSI verifies that HighlightFileTokens
|
||||
// preserves the original @file references in the output and wraps each
|
||||
// token with ANSI escape codes for the theme accent color.
|
||||
func TestHighlightFileTokensInjectsANSI(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
ty := testTypography(theme)
|
||||
|
||||
// A user message with @file tokens should contain ANSI escapes around the token.
|
||||
content := "refactor @main.go and @utils.go"
|
||||
result := UserBlock(content, 80, ty, theme)
|
||||
result := HighlightFileTokens(content, theme)
|
||||
|
||||
// The rendered output should contain both file references.
|
||||
// The output should still contain both file references.
|
||||
if !strings.Contains(result, "@main.go") {
|
||||
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @main.go, got:\n%s", result)
|
||||
}
|
||||
if !strings.Contains(result, "@utils.go") {
|
||||
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @utils.go, got:\n%s", result)
|
||||
}
|
||||
|
||||
// Verify ANSI codes are present (the tokens are styled).
|
||||
if !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
|
||||
t.Errorf("HighlightFileTokens output should contain ANSI escape codes for styled @file tokens")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +60,13 @@ func NewScrollList(width, height int) *ScrollList {
|
||||
}
|
||||
|
||||
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
|
||||
// the viewport will scroll to the bottom to show the latest content.
|
||||
// the viewport will scroll to the bottom to show the latest content — EXCEPT
|
||||
// when the user is actively selecting text (mouse button held), in which case
|
||||
// the scroll position is locked so the highlighted content stays under the
|
||||
// cursor. The pending bottom-scroll is deferred to MouseUp.
|
||||
func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
s.items = items
|
||||
if s.autoScroll {
|
||||
if s.autoScroll && !s.sel.MouseDown {
|
||||
s.GotoBottom()
|
||||
}
|
||||
}
|
||||
@@ -157,6 +160,10 @@ func (s *ScrollList) HandleMouseDown(x, y int) bool {
|
||||
// HandleMouseDrag handles mouse motion while button is held.
|
||||
// Updates the selection endpoint for character-level precision.
|
||||
// Returns true if selection was updated.
|
||||
//
|
||||
// Defensively disables auto-scroll on every drag update — even if the
|
||||
// MouseDown handler missed (e.g. click landed in viewport padding), any
|
||||
// active drag means the user is selecting and the viewport must not jump.
|
||||
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
if !s.sel.MouseDown {
|
||||
return false
|
||||
@@ -171,6 +178,9 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Hard-lock the viewport while dragging.
|
||||
s.autoScroll = false
|
||||
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
@@ -178,6 +188,13 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsMouseDown reports whether the user currently has the mouse button held
|
||||
// (i.e. a selection drag is in progress). Used by the parent model to avoid
|
||||
// re-enabling auto-scroll during streaming while the user is selecting.
|
||||
func (s *ScrollList) IsMouseDown() bool {
|
||||
return s.sel.MouseDown
|
||||
}
|
||||
|
||||
// HandleMouseUp handles mouse button release.
|
||||
// Returns true if there was an active selection.
|
||||
func (s *ScrollList) HandleMouseUp() bool {
|
||||
@@ -521,6 +538,21 @@ func (s *ScrollList) View() string {
|
||||
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
|
||||
item := s.items[idx]
|
||||
content := item.Render(s.width)
|
||||
|
||||
// Items that render to an empty string contribute zero height to
|
||||
// the viewport. This MUST match renderedHeight()'s semantics —
|
||||
// otherwise getItemAndLineAtY (which uses renderedHeight) treats
|
||||
// the item as 0 lines while View() emits one blank line via
|
||||
// strings.Split("", "\n") = [""], producing a 1-row downward
|
||||
// drift in mouse hit-testing per empty item between offsetIdx
|
||||
// and the cursor (most visibly streaming-reasoning items before
|
||||
// any reasoning has streamed, which extension widgets surface by
|
||||
// shrinking the scrollback).
|
||||
if content == "" {
|
||||
s.heightCache[item.ID()] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
// Refresh height cache from the actual render (authoritative).
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fakeItem is a deterministic MessageItem for ScrollList tests.
|
||||
type fakeItem struct {
|
||||
id string
|
||||
lines int
|
||||
}
|
||||
|
||||
func (f *fakeItem) ID() string { return f.id }
|
||||
func (f *fakeItem) Render(_ int) string {
|
||||
if f.lines <= 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, f.lines)
|
||||
for i := range parts {
|
||||
parts[i] = fmt.Sprintf("%s-line-%d", f.id, i)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
func (f *fakeItem) Height() int { return f.lines }
|
||||
|
||||
// makeItems builds n fake items of `lines` height each.
|
||||
func makeItems(n, lines int) []MessageItem {
|
||||
out := make([]MessageItem, n)
|
||||
for i := range out {
|
||||
out[i] = &fakeItem{id: fmt.Sprintf("item-%d", i), lines: lines}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TestScrollList_MouseDownPreventsAutoScroll verifies the core fix for the
|
||||
// copy-selection drift bug: while the user has the mouse button held
|
||||
// (drag-selecting), incoming content updates must NOT shift the viewport,
|
||||
// because doing so moves the highlighted content out from under the cursor.
|
||||
func TestScrollList_MouseDownPreventsAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2)) // 40 lines of content into a 10-line viewport
|
||||
// Capture the auto-scrolled-to-bottom position.
|
||||
startOffsetIdx := sl.offsetIdx
|
||||
startOffsetLine := sl.offsetLine
|
||||
|
||||
// User clicks somewhere in the visible area, starting a drag-select.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown should accept a click inside the viewport")
|
||||
}
|
||||
if !sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be true after HandleMouseDown")
|
||||
}
|
||||
|
||||
// New content arrives. With autoScroll still true, SetItems would
|
||||
// normally call GotoBottom() and shift the viewport. The fix should
|
||||
// suppress that while MouseDown is held.
|
||||
sl.SetItems(makeItems(30, 2)) // 60 lines now
|
||||
if sl.offsetIdx != startOffsetIdx || sl.offsetLine != startOffsetLine {
|
||||
t.Errorf("viewport scrolled during active drag: was (%d,%d), now (%d,%d)",
|
||||
startOffsetIdx, startOffsetLine, sl.offsetIdx, sl.offsetLine)
|
||||
}
|
||||
|
||||
// User releases the mouse — drag is over.
|
||||
sl.HandleMouseUp()
|
||||
if sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be false after HandleMouseUp")
|
||||
}
|
||||
|
||||
// After release, a fresh content update should resume auto-scrolling
|
||||
// (move the offset to track the new bottom).
|
||||
afterReleaseIdx := sl.offsetIdx
|
||||
afterReleaseLine := sl.offsetLine
|
||||
sl.SetItems(makeItems(50, 2))
|
||||
if sl.offsetIdx == afterReleaseIdx && sl.offsetLine == afterReleaseLine {
|
||||
t.Errorf("autoscroll did not resume after MouseUp: offset stuck at (%d,%d)",
|
||||
afterReleaseIdx, afterReleaseLine)
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_DragDisablesAutoScroll verifies that any successful
|
||||
// HandleMouseDrag call clears autoScroll, even when HandleMouseDown didn't
|
||||
// observe it (e.g. a stale wheel-down event set it back to true mid-stream).
|
||||
func TestScrollList_DragDisablesAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2))
|
||||
|
||||
// Begin a selection.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
// Simulate an external code path that re-enabled autoScroll while
|
||||
// MouseDown is still held (the precise condition that caused drift).
|
||||
sl.autoScroll = true
|
||||
|
||||
// Drag motion should hard-lock the viewport again.
|
||||
if !sl.HandleMouseDrag(10, 4) {
|
||||
t.Fatalf("HandleMouseDrag failed")
|
||||
}
|
||||
if sl.autoScroll {
|
||||
t.Errorf("HandleMouseDrag must clear autoScroll to prevent mid-drag jumps")
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_SetItemsRespectsMouseDown is the most direct regression
|
||||
// test: even with autoScroll enabled and new content appended at the
|
||||
// bottom, SetItems must not move the viewport while a mouse drag is in
|
||||
// progress. This is what caused the "highlighting shifts by 1+ rows
|
||||
// during streaming" symptom reported by the user.
|
||||
func TestScrollList_SetItemsRespectsMouseDown(t *testing.T) {
|
||||
sl := NewScrollList(80, 5)
|
||||
sl.SetItems(makeItems(10, 2)) // 20 lines into a 5-line viewport
|
||||
// At bottom.
|
||||
preIdx, preLine := sl.offsetIdx, sl.offsetLine
|
||||
|
||||
// Hold mouse down (no actual drag needed).
|
||||
if !sl.HandleMouseDown(0, 0) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
|
||||
// Append several more items as if streaming. With the bug, each
|
||||
// SetItems would call GotoBottom and shift the offset.
|
||||
for n := 11; n <= 15; n++ {
|
||||
sl.SetItems(makeItems(n, 2))
|
||||
if sl.offsetIdx != preIdx || sl.offsetLine != preLine {
|
||||
t.Fatalf("viewport drifted during streaming with mouse held: "+
|
||||
"start=(%d,%d) now=(%d,%d) after adding item %d",
|
||||
preIdx, preLine, sl.offsetIdx, sl.offsetLine, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_EmptyItemsDoNotShiftMouseMapping is the regression test
|
||||
// for the second drift bug: items that render to "" must contribute the
|
||||
// same number of rows in View() (zero) as in renderedHeight(), or mouse
|
||||
// hit-testing drifts by one row per empty item between offsetIdx and the
|
||||
// cursor. This was surfaced by extension widgets (e.g. subagent-monitor)
|
||||
// that shrink the scrollback so empty streaming-reasoning items end up
|
||||
// in the visible window.
|
||||
//
|
||||
// Setup: 1 normal item + 1 empty item + 1 normal item. Click on the line
|
||||
// where the third item begins. With the bug, getItemAndLineAtY skips the
|
||||
// empty item (renderedHeight=0) and reports lineIdx pointing one row
|
||||
// past where View() actually painted that line.
|
||||
func TestScrollList_EmptyItemsDoNotShiftMouseMapping(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems([]MessageItem{
|
||||
&fakeItem{id: "a", lines: 2}, // viewY 0–1
|
||||
&fakeItem{id: "empty", lines: 0}, // renders "" — contributes 0 rows
|
||||
&fakeItem{id: "b", lines: 2}, // viewY 2–3
|
||||
})
|
||||
|
||||
// Render the viewport once so the cache reflects what View() actually
|
||||
// emits (this is the path that previously diverged from renderedHeight
|
||||
// for empty items).
|
||||
rendered := sl.View()
|
||||
lines := strings.Split(rendered, "\n")
|
||||
|
||||
// Sanity: View() must emit exactly height lines.
|
||||
if len(lines) != 10 {
|
||||
t.Fatalf("View() returned %d lines, want 10", len(lines))
|
||||
}
|
||||
// Item b's first line should appear at viewY=2, NOT viewY=3.
|
||||
if !strings.Contains(lines[2], "b-line-0") {
|
||||
t.Errorf("viewY=2 should render b-line-0 (empty item contributes 0 rows), got %q", lines[2])
|
||||
}
|
||||
|
||||
// Now the actual hit-test contract: clicking on viewY=2 must map to
|
||||
// item b line 0 — the same coordinate View() rendered there.
|
||||
idx, line := sl.getItemAndLineAtY(2)
|
||||
if idx != 2 || line != 0 {
|
||||
t.Errorf("getItemAndLineAtY(2) = (%d,%d), want (2,0)", idx, line)
|
||||
}
|
||||
|
||||
// And clicking on the second line of b (viewY=3) must map to b line 1.
|
||||
idx, line = sl.getItemAndLineAtY(3)
|
||||
if idx != 2 || line != 1 {
|
||||
t.Errorf("getItemAndLineAtY(3) = (%d,%d), want (2,1)", idx, line)
|
||||
}
|
||||
}
|
||||
@@ -230,8 +230,10 @@ func FindWordBoundaries(line string, col int) (startCol, endCol int) {
|
||||
|
||||
// HighlightLine applies reverse-video highlighting to a portion of a rendered
|
||||
// line (which may contain ANSI escape codes). startCol/endCol are in display
|
||||
// columns. If startCol == -1, the entire line is highlighted. If startCol ==
|
||||
// endCol, returns the line unchanged.
|
||||
// columns. If startCol == -1, the entire line is highlighted. If endCol ==
|
||||
// -1, the highlight runs from startCol to the end of the line (the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection).
|
||||
// If startCol == endCol, returns the line unchanged.
|
||||
//
|
||||
// Uses ultraviolet ScreenBuffer for cell-level ANSI manipulation.
|
||||
func HighlightLine(line string, startCol, endCol int) string {
|
||||
@@ -250,6 +252,16 @@ func HighlightLine(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (returned by IsLineInRange
|
||||
// for the first line of a multi-line selection). Without this branch,
|
||||
// the start line of a multi-line drag would never be highlighted —
|
||||
// the user perceives this as the selection being shifted one row down
|
||||
// from the cursor, especially when extension widgets shrink the
|
||||
// scrollback and make the start line land on a tall styled block.
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return line
|
||||
}
|
||||
@@ -296,6 +308,11 @@ func ExtractText(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (see HighlightLine).
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -357,6 +357,54 @@ func TestHighlightLine_NoSelection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlightLine_EndOfLineSentinel verifies that endCol=-1 is interpreted
|
||||
// as "highlight from startCol to end of line", matching the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection.
|
||||
//
|
||||
// Regression: without this contract, the start line of any multi-line drag
|
||||
// would silently fall through HighlightLine's startCol >= endCol guard and
|
||||
// render unstyled, making the selection appear to begin one row below the
|
||||
// cursor — the exact "tracking gets shifted" symptom users reported when
|
||||
// extension widgets shrank the scrollback enough that the click landed on a
|
||||
// styled tool-result block.
|
||||
func TestHighlightLine_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
result := HighlightLine(line, 0, -1)
|
||||
if result == line {
|
||||
t.Errorf("endCol=-1 should highlight from startCol to end of line; got unchanged input")
|
||||
}
|
||||
if len(result) <= len(line) {
|
||||
t.Errorf("highlighted result should be longer than plain input (ANSI codes added); got len=%d want > %d", len(result), len(line))
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractText_EndOfLineSentinel mirrors TestHighlightLine_EndOfLineSentinel
|
||||
// for the extraction path used by the clipboard copy.
|
||||
func TestExtractText_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
got := ExtractText(line, 7, -1)
|
||||
want := "World!"
|
||||
if got != want {
|
||||
t.Errorf("ExtractText(line, 7, -1) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLineInRange_StartLineSentinelHighlights composes IsLineInRange with
|
||||
// HighlightLine end-to-end: the start line of a multi-line, single-item
|
||||
// selection must actually emit highlight ANSI codes. This is the contract
|
||||
// the rendering path in scrolllist.View() relies on.
|
||||
func TestIsLineInRange_StartLineSentinelHighlights(t *testing.T) {
|
||||
r := Range{StartItemIdx: 5, EndItemIdx: 5, StartLine: 0, EndLine: 2, StartCol: 0, EndCol: 10}
|
||||
inRange, sc, ec := IsLineInRange(r, 5, 0)
|
||||
if !inRange {
|
||||
t.Fatalf("item 5 line 0 should be in range")
|
||||
}
|
||||
highlighted := HighlightLine("first line of selection", sc, ec)
|
||||
if highlighted == "first line of selection" {
|
||||
t.Errorf("first line of multi-line selection was not highlighted (sc=%d ec=%d)", sc, ec)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiClickDetection verifies the click counting logic.
|
||||
func TestMultiClickDetection(t *testing.T) {
|
||||
s := NewState()
|
||||
|
||||
@@ -211,106 +211,11 @@ func DefaultTheme() Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
|
||||
// padding, and appropriate width. Used for grouping related content in a visually
|
||||
// distinct box.
|
||||
func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Border).
|
||||
Padding(1, 2).
|
||||
MarginBottom(1)
|
||||
}
|
||||
|
||||
// IsDarkBackground returns the cached terminal background detection result.
|
||||
func IsDarkBackground() bool {
|
||||
return isDarkBg
|
||||
}
|
||||
|
||||
// StyleHeader creates a lipgloss style for primary headers using the theme's
|
||||
// primary color with bold text for emphasis and hierarchy.
|
||||
func StyleHeader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
|
||||
// secondary color with bold text, providing visual hierarchy below primary headers.
|
||||
func StyleSubheader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
|
||||
// and italic formatting, suitable for supplementary or less important information.
|
||||
func StyleMuted(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
}
|
||||
|
||||
// StyleSuccess creates a lipgloss style for success messages using green colors
|
||||
// with bold text to indicate successful operations or positive outcomes.
|
||||
func StyleSuccess(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleError creates a lipgloss style for error messages using red colors
|
||||
// with bold text to ensure visibility of problems or failures.
|
||||
func StyleError(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
|
||||
// colors with bold text to draw attention to potential issues or cautions.
|
||||
func StyleWarning(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleInfo creates a lipgloss style for informational messages using blue colors
|
||||
// with bold text for general notifications and status updates.
|
||||
func StyleInfo(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// CreateSeparator generates a horizontal separator line with the specified width,
|
||||
// character, and color. Useful for visually dividing sections of content in the UI.
|
||||
func CreateSeparator(width int, char string, c color.Color) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Width(width).
|
||||
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
|
||||
}
|
||||
|
||||
// CreateProgressBar generates a visual progress bar with filled and empty segments
|
||||
// based on the percentage complete. The bar uses Unicode block characters for smooth
|
||||
// appearance and theme colors to indicate progress.
|
||||
func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
filled := int(float64(width) * percentage / 100)
|
||||
empty := width - filled
|
||||
|
||||
filledBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Render(lipgloss.PlaceHorizontal(filled, lipgloss.Left, "█"))
|
||||
|
||||
emptyBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Render(lipgloss.PlaceHorizontal(empty, lipgloss.Left, "░"))
|
||||
|
||||
return filledBar + emptyBar
|
||||
}
|
||||
|
||||
// CreateBadge generates a styled badge or label with inverted colors (text on
|
||||
// colored background) for highlighting important tags, statuses, or categories.
|
||||
func CreateBadge(text string, c color.Color) string {
|
||||
|
||||
@@ -6,13 +6,6 @@ import (
|
||||
heraldmd "github.com/indaco/herald-md"
|
||||
)
|
||||
|
||||
// BaseStyle returns a new, empty lipgloss style that can be customized with
|
||||
// additional styling methods. This serves as the foundation for building more
|
||||
// complex styled components.
|
||||
func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// markdownTypographyCache holds the last-created Typography instance for
|
||||
// herald-md rendering. It is cached to avoid re-initialization on every
|
||||
// streaming flush tick. The cache is invalidated by SetTheme when the
|
||||
|
||||
@@ -543,12 +543,6 @@ func ApplyThemeWithoutSave(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshThemeRegistry re-scans the themes directory. Call after the user
|
||||
// drops a new file into ~/.config/kit/themes/.
|
||||
func RefreshThemeRegistry() {
|
||||
initThemeRegistry()
|
||||
}
|
||||
|
||||
// RegisterThemeFromConfig adds a theme to the runtime registry from an
|
||||
// extension's ThemeColorConfig (string hex pairs). Replaces any existing
|
||||
// entry with the same name. The theme is immediately available via
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
type ToolApprovalInput struct {
|
||||
textarea textarea.Model
|
||||
toolName string
|
||||
toolArgs string
|
||||
width int
|
||||
selected bool // true when "yes" is highlighted and false when "no" is
|
||||
approved bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &ToolApprovalInput{
|
||||
textarea: ta,
|
||||
toolName: toolName,
|
||||
toolArgs: toolArgs,
|
||||
width: width,
|
||||
selected: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyPressMsg:
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
t.approved = true
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "n", "N":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "left":
|
||||
t.selected = true
|
||||
return t, nil
|
||||
case "right":
|
||||
t.selected = false
|
||||
return t, nil
|
||||
case "enter":
|
||||
t.approved = t.selected
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "esc", "ctrl+c":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) View() tea.View {
|
||||
if t.done {
|
||||
return tea.NewView("we are done")
|
||||
}
|
||||
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(t.width - 1) // full width minus left border
|
||||
|
||||
// Style for the currently selected/highlighted option
|
||||
selectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true).
|
||||
Underline(true)
|
||||
|
||||
// Style for the unselected/unhighlighted option
|
||||
unselectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted)
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render("Allow tool execution"))
|
||||
view.WriteString("\n")
|
||||
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
|
||||
view.WriteString(details)
|
||||
view.WriteString("Allow tool execution: ")
|
||||
|
||||
var yesText, noText string
|
||||
if t.selected {
|
||||
yesText = selectedStyle.Render("[y]es")
|
||||
noText = unselectedStyle.Render("[n]o")
|
||||
} else {
|
||||
yesText = unselectedStyle.Render("[y]es")
|
||||
noText = selectedStyle.Render("[n]o")
|
||||
}
|
||||
view.WriteString(yesText + "/" + noText + "\n")
|
||||
|
||||
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
|
||||
}
|
||||
@@ -79,8 +79,7 @@ func renderToolBody(toolName, toolArgs, toolResult string, width int) string {
|
||||
// Edit tool — side-by-side diff
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// renderEditBody renders a side-by-side diff from old_text/new_text in toolArgs.
|
||||
// Supports both single-edit mode and multi-edit mode (edits array).
|
||||
// renderEditBody renders a side-by-side diff from the edits array in toolArgs.
|
||||
func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(toolArgs), &args); err != nil {
|
||||
@@ -90,35 +89,28 @@ func renderEditBody(toolArgs, toolResult string, width int) string {
|
||||
// Try to extract the starting line number from the unified diff in the result
|
||||
startLine := extractDiffStartLine(toolResult)
|
||||
|
||||
// Check for multi-edit mode (edits array)
|
||||
if editsArr, ok := args["edits"].([]any); ok && len(editsArr) > 0 {
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
editsArr, ok := args["edits"].([]any)
|
||||
if !ok || len(editsArr) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var results []string
|
||||
for _, edit := range editsArr {
|
||||
if e, ok := edit.(map[string]any); ok {
|
||||
oldText, _ := e["old_text"].(string)
|
||||
newText, _ := e["new_text"].(string)
|
||||
if oldText != "" || newText != "" {
|
||||
diff := renderDiffBlock(oldText, newText, startLine, width)
|
||||
if diff != "" {
|
||||
results = append(results, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Single-edit mode (legacy)
|
||||
oldText, _ := args["old_text"].(string)
|
||||
newText, _ := args["new_text"].(string)
|
||||
if oldText == "" && newText == "" {
|
||||
return ""
|
||||
if len(results) > 0 {
|
||||
return strings.Join(results, "\n")
|
||||
}
|
||||
|
||||
return renderDiffBlock(oldText, newText, startLine, width)
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractDiffStartLine parses the first @@ hunk header from a unified diff
|
||||
|
||||
+92
-2
@@ -190,6 +190,41 @@ msg, err := host.GetMCPPrompt(ctx, "server-name", "prompt-name", map[string]stri
|
||||
})
|
||||
```
|
||||
|
||||
### MCP Tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`. Cooperating servers can respond to `tools/call` with a
|
||||
`taskId` immediately; Kit then polls `tasks/get` / `tasks/result` until the
|
||||
task reaches a terminal state, and best-effort `tasks/cancel`s on context
|
||||
cancellation. Servers that don't advertise the capability keep their previous
|
||||
synchronous behaviour.
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
// Per-server mode: auto (default), never, or always.
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskTimeout: 15 * time.Minute, // total wall-clock cap
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s/%s: %s", p.Server, p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
// Inspect / cancel in-flight tasks
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
t, _ := host.GetMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
if !t.Status.IsTerminal() {
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", t.TaskID)
|
||||
}
|
||||
```
|
||||
|
||||
The progress handler fires once when a task is accepted and again on every
|
||||
observed status transition; the final invocation always carries a terminal
|
||||
status (`MCPTaskStatusCompleted`, `MCPTaskStatusFailed`, or
|
||||
`MCPTaskStatusCancelled`). Don't block in the handler — dispatch long work on
|
||||
a goroutine.
|
||||
|
||||
### Session Management
|
||||
|
||||
Maintain conversation context:
|
||||
@@ -206,9 +241,46 @@ response, _ := host.Prompt(ctx, "What's my name?")
|
||||
host.ClearSession()
|
||||
```
|
||||
|
||||
### Runtime Skills and Context Files
|
||||
|
||||
For multi-tenant chatbots, web services, or any host that needs per-user or
|
||||
per-session instructions, the SDK lets you add, remove, and replace skills and
|
||||
project context files (e.g. `AGENTS.md`) **after** Kit construction. Every
|
||||
mutation recomposes the system prompt and applies it to the agent so the next
|
||||
turn picks up the new instructions — no restart required.
|
||||
|
||||
```go
|
||||
// Add a programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Swap per-user AGENTS.md content fetched from your database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state when the user logs off.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set in one shot.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Readers (`GetSkills`, `GetContextFiles`) return snapshots, and every mutator
|
||||
is safe to call concurrently from multiple goroutines.
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports types so you don't need direct internal imports:
|
||||
The SDK re-exports message/session/MCP types so you don't need direct internal imports. Agent-configuration types are Kit-owned (not aliases) and use only SDK types in their signatures, so consumers never need to import the underlying LLM-provider package.
|
||||
|
||||
```go
|
||||
// Message types
|
||||
@@ -216,13 +288,28 @@ kit.Message, kit.MessageRole, kit.ContentPart
|
||||
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
|
||||
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
|
||||
|
||||
// LLM types — concrete Kit-owned structs, no external library dependency
|
||||
// LLM types — Kit-owned `LLM*` aliases over the underlying provider types,
|
||||
// so consumers never import the provider package directly
|
||||
kit.LLMMessage // {Role LLMMessageRole, Content string}
|
||||
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
|
||||
kit.LLMResponse // {Content, FinishReason, Usage}
|
||||
kit.LLMFilePart // {Filename, Data []byte, MediaType}
|
||||
|
||||
// Agent configuration — concrete Kit-owned structs and function types.
|
||||
// All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct
|
||||
// these without importing any LLM-provider package.
|
||||
kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control
|
||||
kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool
|
||||
kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress)
|
||||
kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string)
|
||||
kit.ToolExecutionHandler // func(toolCallID, toolName, toolArgs string, isStarting bool)
|
||||
kit.ToolResultHandler // func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
|
||||
kit.ResponseHandler // func(content string)
|
||||
kit.StreamingResponseHandler // func(content string)
|
||||
kit.ToolCallContentHandler // func(content string)
|
||||
kit.SpinnerFunc // func(fn func() error) error
|
||||
|
||||
// MCP OAuth types
|
||||
kit.MCPServer // *server.MCPServer for in-process MCP transport
|
||||
kit.MCPServerConfig // Configuration for an MCP server (stdio, SSE, or in-process)
|
||||
@@ -262,6 +349,9 @@ msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
- `ClearSession()` - Clear conversation history
|
||||
- `GetSessionPath()` - Get session file path
|
||||
- `GetSessionID()` - Get session UUID
|
||||
- `AddSkill(*Skill)` / `LoadAndAddSkill(path)` / `RemoveSkill(name)` / `SetSkills([])` - Manage skills at runtime
|
||||
- `AddContextFile(*ContextFile)` / `AddContextFileContent(path, content)` / `LoadAndAddContextFile(path)` / `RemoveContextFile(path)` / `SetContextFiles([])` - Manage AGENTS.md-style context files at runtime
|
||||
- `RefreshSystemPrompt()` - Re-apply the composed system prompt to the agent
|
||||
- `Close()` - Clean up resources
|
||||
|
||||
### Options
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
)
|
||||
|
||||
// TestAgentConfigToInternal verifies that the SDK-side AgentConfig converts
|
||||
// faithfully to the internal agent.AgentConfig representation, preserving
|
||||
// every field consumed by the internal agent layer.
|
||||
//
|
||||
// Regression test for https://github.com/mark3labs/kit/issues/30.
|
||||
func TestAgentConfigToInternal(t *testing.T) {
|
||||
t.Run("nil receiver returns nil", func(t *testing.T) {
|
||||
var c *AgentConfig
|
||||
if got := c.toInternal(); got != nil {
|
||||
t.Errorf("nil.toInternal() = %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("scalar fields round-trip", func(t *testing.T) {
|
||||
c := &AgentConfig{
|
||||
SystemPrompt: "sys",
|
||||
MaxSteps: 7,
|
||||
StreamingEnabled: true,
|
||||
DisableCoreTools: true,
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got == nil {
|
||||
t.Fatal("toInternal() = nil")
|
||||
}
|
||||
if got.SystemPrompt != "sys" {
|
||||
t.Errorf("SystemPrompt = %q, want %q", got.SystemPrompt, "sys")
|
||||
}
|
||||
if got.MaxSteps != 7 {
|
||||
t.Errorf("MaxSteps = %d, want 7", got.MaxSteps)
|
||||
}
|
||||
if !got.StreamingEnabled {
|
||||
t.Error("StreamingEnabled = false, want true")
|
||||
}
|
||||
if !got.DisableCoreTools {
|
||||
t.Error("DisableCoreTools = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool slices propagate without conversion", func(t *testing.T) {
|
||||
// Tool is a type alias for the underlying LLM-tool type, so the
|
||||
// SDK []Tool and internal []fantasy.AgentTool slices share the
|
||||
// same backing array after conversion.
|
||||
tool := NewTool[struct{}]("noop", "noop", nil)
|
||||
c := &AgentConfig{
|
||||
CoreTools: []Tool{tool},
|
||||
ExtraTools: []Tool{tool, tool},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if len(got.CoreTools) != 1 {
|
||||
t.Errorf("CoreTools len = %d, want 1", len(got.CoreTools))
|
||||
}
|
||||
if len(got.ExtraTools) != 2 {
|
||||
t.Errorf("ExtraTools len = %d, want 2", len(got.ExtraTools))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool wrapper is invoked through internal config", func(t *testing.T) {
|
||||
called := false
|
||||
c := &AgentConfig{
|
||||
ToolWrapper: func(in []Tool) []Tool {
|
||||
called = true
|
||||
return in
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.ToolWrapper == nil {
|
||||
t.Fatal("internal ToolWrapper is nil")
|
||||
}
|
||||
_ = got.ToolWrapper(nil)
|
||||
if !called {
|
||||
t.Error("SDK ToolWrapper was not invoked through the internal config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnMCPServerLoaded propagates", func(t *testing.T) {
|
||||
var captured string
|
||||
wantErr := errors.New("boom")
|
||||
c := &AgentConfig{
|
||||
OnMCPServerLoaded: func(name string, _ int, _ error) {
|
||||
captured = name
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
got.OnMCPServerLoaded("svr", 3, wantErr)
|
||||
if captured != "svr" {
|
||||
t.Errorf("OnMCPServerLoaded captured = %q, want %q", captured, "svr")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DebugLogger propagates", func(t *testing.T) {
|
||||
dl := &fakeDebugLogger{enabled: true}
|
||||
c := &AgentConfig{DebugLogger: dl}
|
||||
got := c.toInternal()
|
||||
if got.DebugLogger == nil {
|
||||
t.Fatal("internal DebugLogger is nil")
|
||||
}
|
||||
if !got.DebugLogger.IsDebugEnabled() {
|
||||
t.Error("IsDebugEnabled = false, want true")
|
||||
}
|
||||
got.DebugLogger.LogDebug("hello")
|
||||
if len(dl.messages) != 1 || dl.messages[0] != "hello" {
|
||||
t.Errorf("messages = %v, want [hello]", dl.messages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) {
|
||||
c := &AgentConfig{
|
||||
MCPTaskConfig: MCPTaskConfig{
|
||||
PerServerMode: map[string]MCPTaskMode{
|
||||
"build-svr": MCPTaskModeAlways,
|
||||
},
|
||||
DefaultTTL: 30 * time.Second,
|
||||
PollInterval: 250 * time.Millisecond,
|
||||
MaxPollInterval: 2 * time.Second,
|
||||
Timeout: 5 * time.Minute,
|
||||
Progress: func(_ MCPTaskProgress) {},
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.MCPTaskConfig.DefaultTTL != 30*time.Second {
|
||||
t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL)
|
||||
}
|
||||
if got.MCPTaskConfig.PollInterval != 250*time.Millisecond {
|
||||
t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval)
|
||||
}
|
||||
if got.MCPTaskConfig.MaxPollInterval != 2*time.Second {
|
||||
t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval)
|
||||
}
|
||||
if got.MCPTaskConfig.Timeout != 5*time.Minute {
|
||||
t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout)
|
||||
}
|
||||
mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"]
|
||||
if !ok {
|
||||
t.Fatal("PerServerMode missing 'build-svr'")
|
||||
}
|
||||
if string(mode) != string(MCPTaskModeAlways) {
|
||||
t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways)
|
||||
}
|
||||
if got.MCPTaskConfig.Progress == nil {
|
||||
t.Fatal("internal Progress handler is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth and token store factories are wired", func(t *testing.T) {
|
||||
auth := &fakeAuthHandler{}
|
||||
tokenCalls := 0
|
||||
var tokenServer string
|
||||
factory := MCPTokenStoreFactory(func(server string) (MCPTokenStore, error) {
|
||||
tokenCalls++
|
||||
tokenServer = server
|
||||
return nil, nil
|
||||
})
|
||||
c := &AgentConfig{
|
||||
AuthHandler: auth,
|
||||
TokenStoreFactory: factory,
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.AuthHandler == nil {
|
||||
t.Fatal("internal AuthHandler is nil")
|
||||
}
|
||||
if got.TokenStoreFactory == nil {
|
||||
t.Fatal("internal TokenStoreFactory is nil")
|
||||
}
|
||||
_, _ = got.TokenStoreFactory("https://example.test")
|
||||
if tokenCalls != 1 {
|
||||
t.Errorf("token factory call count = %d, want 1", tokenCalls)
|
||||
}
|
||||
if tokenServer != "https://example.test" {
|
||||
t.Errorf("token factory server arg = %q", tokenServer)
|
||||
}
|
||||
if got.AuthHandler.RedirectURI() != "redirect" {
|
||||
t.Errorf("RedirectURI = %q, want %q", got.AuthHandler.RedirectURI(), "redirect")
|
||||
}
|
||||
})
|
||||
|
||||
// Compile-time check that the internal type is what we expect.
|
||||
//nolint:staticcheck // QF1011: explicit type asserts the conversion target.
|
||||
var _ *agent.AgentConfig = (&AgentConfig{}).toInternal()
|
||||
}
|
||||
|
||||
// fakeAuthHandler implements both kit.MCPAuthHandler and the structurally
|
||||
// identical tools.MCPAuthHandler used by the internal layer.
|
||||
type fakeAuthHandler struct{}
|
||||
|
||||
func (f *fakeAuthHandler) RedirectURI() string { return "redirect" }
|
||||
func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// fakeDebugLogger implements kit.DebugLogger for tests.
|
||||
type fakeDebugLogger struct {
|
||||
enabled bool
|
||||
messages []string
|
||||
}
|
||||
|
||||
func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) }
|
||||
func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled }
|
||||
@@ -0,0 +1,150 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Runtime context-file management (Issue #36)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// Project context files (AGENTS.md and friends) are normally auto-discovered
|
||||
// during Kit.New() and injected into the system prompt. SDK consumers building
|
||||
// multi-tenant chatbots often need to swap context per user/session at runtime
|
||||
// without restarting the agent. The methods below provide that surface.
|
||||
//
|
||||
// Every mutation recomposes the system prompt and applies it to the underlying
|
||||
// agent so the next turn sees the updated project context.
|
||||
|
||||
// AddContextFile registers a project context file (e.g. an AGENTS.md
|
||||
// equivalent) on this Kit instance. The file does not need to exist on
|
||||
// disk — Path is treated as an opaque identifier used both for de-duplication
|
||||
// and for the "Instructions from: <Path>" header injected into the system
|
||||
// prompt. If a context file with the same Path is already loaded the new
|
||||
// content replaces it.
|
||||
//
|
||||
// Returns an error when cf is nil or has an empty Path. AddContextFile is
|
||||
// safe to call from any goroutine.
|
||||
func (m *Kit) AddContextFile(cf *ContextFile) error {
|
||||
if cf == nil {
|
||||
return fmt.Errorf("AddContextFile: context file is nil")
|
||||
}
|
||||
if cf.Path == "" {
|
||||
return fmt.Errorf("AddContextFile: context file path is required")
|
||||
}
|
||||
|
||||
// Take a defensive copy so later mutations by the caller don't race with
|
||||
// the agent reading the composed prompt.
|
||||
stored := &ContextFile{
|
||||
Path: cf.Path,
|
||||
Content: strings.TrimSpace(cf.Content),
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
replaced := false
|
||||
for i, existing := range m.contextFiles {
|
||||
if existing.Path == stored.Path {
|
||||
m.contextFiles[i] = stored
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !replaced {
|
||||
m.contextFiles = append(m.contextFiles, stored)
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddContextFileContent is a convenience wrapper around [Kit.AddContextFile]
|
||||
// that builds the ContextFile from a path and inline content string. Use this
|
||||
// when the context originates from a database, API response, or any other
|
||||
// non-filesystem source.
|
||||
func (m *Kit) AddContextFileContent(path, content string) (*ContextFile, error) {
|
||||
cf := &ContextFile{Path: path, Content: content}
|
||||
if err := m.AddContextFile(cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cf, nil
|
||||
}
|
||||
|
||||
// LoadAndAddContextFile reads a file from disk and registers it as a project
|
||||
// context file via [Kit.AddContextFile]. The absolute path is stored on the
|
||||
// resulting ContextFile.
|
||||
func (m *Kit) LoadAndAddContextFile(path string) (*ContextFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadAndAddContextFile: %w", err)
|
||||
}
|
||||
abs, absErr := filepath.Abs(path)
|
||||
if absErr != nil {
|
||||
abs = path
|
||||
}
|
||||
cf := &ContextFile{
|
||||
Path: abs,
|
||||
Content: strings.TrimSpace(string(data)),
|
||||
}
|
||||
if err := m.AddContextFile(cf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cf, nil
|
||||
}
|
||||
|
||||
// RemoveContextFile removes the context file with the given path and
|
||||
// recomposes the system prompt. Returns true when a matching file was found
|
||||
// and removed, false otherwise.
|
||||
func (m *Kit) RemoveContextFile(path string) bool {
|
||||
m.runtimeMu.Lock()
|
||||
found := false
|
||||
for i, cf := range m.contextFiles {
|
||||
if cf.Path == path {
|
||||
m.contextFiles = append(m.contextFiles[:i], m.contextFiles[i+1:]...)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
m.applyComposedSystemPrompt()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetContextFiles replaces the active context-file set with the provided
|
||||
// slice. Pass nil or an empty slice to clear all context. The system prompt
|
||||
// is recomposed and applied. ContextFiles with empty Paths are rejected and
|
||||
// no mutation is performed.
|
||||
func (m *Kit) SetContextFiles(files []*ContextFile) error {
|
||||
// Validate first so a bad input doesn't partially mutate state.
|
||||
for i, cf := range files {
|
||||
if cf == nil {
|
||||
return fmt.Errorf("SetContextFiles: context file at index %d is nil", i)
|
||||
}
|
||||
if cf.Path == "" {
|
||||
return fmt.Errorf("SetContextFiles: context file at index %d has empty path", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Defensive copies so caller-side mutation cannot race with composition.
|
||||
copied := make([]*ContextFile, len(files))
|
||||
for i, cf := range files {
|
||||
copied[i] = &ContextFile{
|
||||
Path: cf.Path,
|
||||
Content: strings.TrimSpace(cf.Content),
|
||||
}
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
m.contextFiles = copied
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
+3
-3
@@ -148,9 +148,9 @@ func parseToolArgs(toolArgs string) map[string]any {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Finish reasons reported by the LLM provider on a completed turn. These
|
||||
// mirror fantasy.FinishReason string values so comparisons against
|
||||
// TurnEndEvent.StopReason / TurnResult.StopReason are stable across
|
||||
// providers.
|
||||
// mirror the underlying provider's finish reason string values so
|
||||
// comparisons against TurnEndEvent.StopReason / TurnResult.StopReason are
|
||||
// stable across providers.
|
||||
const (
|
||||
// FinishReasonStop: the model produced a natural stop (e.g. stop sequence
|
||||
// or end-of-turn signal).
|
||||
|
||||
+109
-20
@@ -8,55 +8,104 @@ import (
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
// ==== Extension Types ====
|
||||
//
|
||||
// Type aliases for internal extension types exposed through the public
|
||||
// ExtensionAPI interface. External SDK consumers can use these without
|
||||
// importing internal packages directly.
|
||||
|
||||
// ExtensionContext holds the runtime context passed to extensions, including
|
||||
// callbacks for printing, sending messages, and accessing session state.
|
||||
type ExtensionContext = extensions.Context
|
||||
|
||||
// ExtensionWidgetConfig describes a widget registered by an extension.
|
||||
type ExtensionWidgetConfig = extensions.WidgetConfig
|
||||
|
||||
// ExtensionWidgetPlacement indicates where a widget should be rendered
|
||||
// (e.g. above or below the conversation).
|
||||
type ExtensionWidgetPlacement = extensions.WidgetPlacement
|
||||
|
||||
// ExtensionHeaderFooterConfig describes a header or footer registered by an extension.
|
||||
type ExtensionHeaderFooterConfig = extensions.HeaderFooterConfig
|
||||
|
||||
// ExtensionEditorConfig configures editor behaviour overrides set by extensions.
|
||||
type ExtensionEditorConfig = extensions.EditorConfig
|
||||
|
||||
// ExtensionUIVisibility controls which UI elements are visible.
|
||||
type ExtensionUIVisibility = extensions.UIVisibility
|
||||
|
||||
// ExtensionToolRenderConfig describes custom tool output rendering registered by an extension.
|
||||
type ExtensionToolRenderConfig = extensions.ToolRenderConfig
|
||||
|
||||
// ExtensionMessageRendererConfig describes custom message rendering registered by an extension.
|
||||
type ExtensionMessageRendererConfig = extensions.MessageRendererConfig
|
||||
|
||||
// ExtensionSessionMessage represents a single message in the session history
|
||||
// as exposed to extensions.
|
||||
type ExtensionSessionMessage = extensions.SessionMessage
|
||||
|
||||
// ExtensionEntry represents a custom data entry stored by an extension
|
||||
// in the session tree.
|
||||
type ExtensionEntry = extensions.ExtensionEntry
|
||||
|
||||
// ExtensionStatusBarEntry describes a status bar entry registered by an extension.
|
||||
type ExtensionStatusBarEntry = extensions.StatusBarEntry
|
||||
|
||||
// ExtensionToolInfo describes a tool available to the agent, as seen by extensions.
|
||||
type ExtensionToolInfo = extensions.ToolInfo
|
||||
|
||||
// ExtensionCommandDef describes a slash command registered by an extension.
|
||||
type ExtensionCommandDef = extensions.CommandDef
|
||||
|
||||
// ExtensionAPI provides grouped access to all extension-related functionality.
|
||||
// This cleans up the main Kit API surface while keeping all extension capabilities available.
|
||||
type ExtensionAPI interface {
|
||||
// Context management
|
||||
SetContext(ctx extensions.Context)
|
||||
GetContext() extensions.Context
|
||||
SetContext(ctx ExtensionContext)
|
||||
GetContext() ExtensionContext
|
||||
UpdateContextModel(model string)
|
||||
|
||||
// Widgets
|
||||
SetWidget(config extensions.WidgetConfig)
|
||||
SetWidget(config ExtensionWidgetConfig)
|
||||
RemoveWidget(id string)
|
||||
GetWidgets(placement extensions.WidgetPlacement) []extensions.WidgetConfig
|
||||
GetWidgets(placement ExtensionWidgetPlacement) []ExtensionWidgetConfig
|
||||
|
||||
// Header/Footer
|
||||
SetHeader(config extensions.HeaderFooterConfig)
|
||||
SetHeader(config ExtensionHeaderFooterConfig)
|
||||
RemoveHeader()
|
||||
GetHeader() *extensions.HeaderFooterConfig
|
||||
SetFooter(config extensions.HeaderFooterConfig)
|
||||
GetHeader() *ExtensionHeaderFooterConfig
|
||||
SetFooter(config ExtensionHeaderFooterConfig)
|
||||
RemoveFooter()
|
||||
GetFooter() *extensions.HeaderFooterConfig
|
||||
GetFooter() *ExtensionHeaderFooterConfig
|
||||
|
||||
// Editor
|
||||
SetEditor(config extensions.EditorConfig)
|
||||
SetEditor(config ExtensionEditorConfig)
|
||||
ResetEditor()
|
||||
GetEditor() *extensions.EditorConfig
|
||||
GetEditor() *ExtensionEditorConfig
|
||||
|
||||
// UI Visibility
|
||||
SetUIVisibility(v extensions.UIVisibility)
|
||||
GetUIVisibility() *extensions.UIVisibility
|
||||
SetUIVisibility(v ExtensionUIVisibility)
|
||||
GetUIVisibility() *ExtensionUIVisibility
|
||||
|
||||
// Tool rendering
|
||||
GetToolRenderer(toolName string) *extensions.ToolRenderConfig
|
||||
GetMessageRenderer(name string) *extensions.MessageRendererConfig
|
||||
GetToolRenderer(toolName string) *ExtensionToolRenderConfig
|
||||
GetMessageRenderer(name string) *ExtensionMessageRendererConfig
|
||||
|
||||
// Session data
|
||||
GetSessionMessages() []extensions.SessionMessage
|
||||
GetSessionMessages() []ExtensionSessionMessage
|
||||
AppendEntry(extType, data string) (string, error)
|
||||
GetEntries(extType string) []extensions.ExtensionEntry
|
||||
GetEntries(extType string) []ExtensionEntry
|
||||
|
||||
// Status bar
|
||||
SetStatus(entry extensions.StatusBarEntry)
|
||||
SetStatus(entry ExtensionStatusBarEntry)
|
||||
RemoveStatus(key string)
|
||||
GetStatusEntries() []extensions.StatusBarEntry
|
||||
GetStatusEntries() []ExtensionStatusBarEntry
|
||||
|
||||
// Shortcuts
|
||||
GetShortcuts() map[string]func()
|
||||
|
||||
// Tools
|
||||
GetToolInfos() []extensions.ToolInfo
|
||||
GetToolInfos() []ExtensionToolInfo
|
||||
SetActiveTools(names []string)
|
||||
|
||||
// Options
|
||||
@@ -71,11 +120,27 @@ type ExtensionAPI interface {
|
||||
EmitBeforeSessionSwitch(switchReason string) (cancelled bool, reason string)
|
||||
|
||||
// Commands
|
||||
Commands() []extensions.CommandDef
|
||||
Commands() []ExtensionCommandDef
|
||||
|
||||
// Lifecycle
|
||||
Reload() error
|
||||
HasExtensions() bool
|
||||
|
||||
// Loaded returns metadata about the extensions currently loaded.
|
||||
Loaded() []ExtensionInfo
|
||||
}
|
||||
|
||||
// ExtensionInfo describes a single loaded extension for display purposes
|
||||
// (e.g. the startup banner or `kit extensions list`).
|
||||
type ExtensionInfo struct {
|
||||
// Path is the absolute path of the extension's .go file.
|
||||
Path string
|
||||
// ToolCount is the number of tools registered by the extension.
|
||||
ToolCount int
|
||||
// CommandCount is the number of slash commands registered.
|
||||
CommandCount int
|
||||
// HandlerCount is the total number of event handlers registered.
|
||||
HandlerCount int
|
||||
}
|
||||
|
||||
// extensionAPI implements ExtensionAPI by wrapping a Kit instance.
|
||||
@@ -456,3 +521,27 @@ func (e *extensionAPI) Reload() error {
|
||||
func (e *extensionAPI) HasExtensions() bool {
|
||||
return e.kit.extRunner != nil
|
||||
}
|
||||
|
||||
func (e *extensionAPI) Loaded() []ExtensionInfo {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
exts := e.kit.extRunner.Extensions()
|
||||
if len(exts) == 0 {
|
||||
return nil
|
||||
}
|
||||
infos := make([]ExtensionInfo, 0, len(exts))
|
||||
for _, ex := range exts {
|
||||
handlerCount := 0
|
||||
for _, hs := range ex.Handlers {
|
||||
handlerCount += len(hs)
|
||||
}
|
||||
infos = append(infos, ExtensionInfo{
|
||||
Path: ex.Path,
|
||||
ToolCount: len(ex.Tools),
|
||||
CommandCount: len(ex.Commands),
|
||||
HandlerCount: handlerCount,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
+143
-218
@@ -54,83 +54,51 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Subscribe to SDK events and forward to extension runner so extensions
|
||||
// see lifecycle events from the SDK's runTurn()/generate() path.
|
||||
|
||||
if runner.HasHandlers(extensions.AgentStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(TurnStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.AgentStartEvent{Prompt: ev.Prompt})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.AgentStart, func(ev TurnStartEvent) extensions.Event {
|
||||
return extensions.AgentStartEvent{Prompt: ev.Prompt}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if _, ok := e.(MessageStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageStartEvent{})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageStart, func(_ MessageStartEvent) extensions.Event {
|
||||
return extensions.MessageStartEvent{}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageUpdate) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(MessageUpdateEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageUpdateEvent{Chunk: ev.Chunk})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageUpdate, func(ev MessageUpdateEvent) extensions.Event {
|
||||
return extensions.MessageUpdateEvent{Chunk: ev.Chunk}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.MessageEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(MessageEndEvent); ok {
|
||||
_, _ = runner.Emit(extensions.MessageEndEvent{Content: ev.Content})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.MessageEnd, func(ev MessageEndEvent) extensions.Event {
|
||||
return extensions.MessageEndEvent{Content: ev.Content}
|
||||
})
|
||||
|
||||
// Tool output streaming events (observation only).
|
||||
if runner.HasHandlers(extensions.ToolOutput) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolOutputEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ToolOutput, func(ev ToolOutputEvent) extensions.Event {
|
||||
return extensions.ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
Chunk: ev.Chunk,
|
||||
IsStderr: ev.IsStderr,
|
||||
}
|
||||
})
|
||||
|
||||
// Tool call input streaming events — fire as the LLM generates tool arguments.
|
||||
if runner.HasHandlers(extensions.ToolCallInputStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
ToolKind: ev.ToolKind,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
if runner.HasHandlers(extensions.ToolCallInputDelta) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallDeltaEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputDeltaEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Delta: ev.Delta,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
if runner.HasHandlers(extensions.ToolCallInputEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ToolCallEndEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ToolCallInputEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputStart, func(ev ToolCallStartEvent) extensions.Event {
|
||||
return extensions.ToolCallInputStartEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName,
|
||||
ToolKind: ev.ToolKind,
|
||||
}
|
||||
})
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputDelta, func(ev ToolCallDeltaEvent) extensions.Event {
|
||||
return extensions.ToolCallInputDeltaEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
Delta: ev.Delta,
|
||||
}
|
||||
})
|
||||
bridgeObserve(m, runner, extensions.ToolCallInputEnd, func(ev ToolCallEndEvent) extensions.Event {
|
||||
return extensions.ToolCallInputEndEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.AgentEnd) {
|
||||
m.Subscribe(func(e Event) {
|
||||
@@ -278,54 +246,13 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
// Extension ContextPrepare → SDK ContextPrepare hook.
|
||||
if runner.HasHandlers(extensions.ContextPrepare) {
|
||||
m.OnContextPrepare(HookPriorityNormal, func(h ContextPrepareHook) *ContextPrepareResult {
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
// Extract plain text from each message for the extension API.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
extMsgs := llmToContextMessages(h.Messages)
|
||||
result, _ := runner.Emit(extensions.ContextPrepareEvent{Messages: extMsgs})
|
||||
r, ok := result.(extensions.ContextPrepareResult)
|
||||
if !ok || r.Messages == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]LLMMessage, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
// Reuse original message (preserves original role and content).
|
||||
rebuilt = append(rebuilt, h.Messages[cm.Index])
|
||||
} else {
|
||||
// New message injected by extension — construct from role + text.
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &ContextPrepareResult{Messages: rebuilt}
|
||||
return &ContextPrepareResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -359,99 +286,56 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
|
||||
// --- Step lifecycle observation events ---
|
||||
|
||||
if runner.HasHandlers(extensions.StepStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepStartEvent{StepNumber: ev.StepNumber})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.StepStart, func(ev StepStartEvent) extensions.Event {
|
||||
return extensions.StepStartEvent{StepNumber: ev.StepNumber}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.StepFinish) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(StepFinishEvent); ok {
|
||||
_, _ = runner.Emit(extensions.StepFinishEvent{
|
||||
StepNumber: ev.StepNumber,
|
||||
HasToolCalls: ev.HasToolCalls,
|
||||
FinishReason: ev.FinishReason,
|
||||
InputTokens: ev.Usage.InputTokens,
|
||||
OutputTokens: ev.Usage.OutputTokens,
|
||||
CacheReadTokens: ev.Usage.CacheReadTokens,
|
||||
CacheWriteTokens: ev.Usage.CacheCreationTokens,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.StepFinish, func(ev StepFinishEvent) extensions.Event {
|
||||
return extensions.StepFinishEvent{
|
||||
StepNumber: ev.StepNumber,
|
||||
HasToolCalls: ev.HasToolCalls,
|
||||
FinishReason: ev.FinishReason,
|
||||
InputTokens: ev.Usage.InputTokens,
|
||||
OutputTokens: ev.Usage.OutputTokens,
|
||||
CacheReadTokens: ev.Usage.CacheReadTokens,
|
||||
CacheWriteTokens: ev.Usage.CacheCreationTokens,
|
||||
}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.ReasoningStart) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ReasoningStartEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ReasoningStartEvent{ID: ev.ID})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.ReasoningStart, func(ev ReasoningStartEvent) extensions.Event {
|
||||
return extensions.ReasoningStartEvent{ID: ev.ID}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Warnings) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(WarningsEvent); ok {
|
||||
_, _ = runner.Emit(extensions.WarningsEvent{Warnings: ev.Warnings})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Warnings, func(ev WarningsEvent) extensions.Event {
|
||||
return extensions.WarningsEvent{Warnings: ev.Warnings}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Source) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(SourceEvent); ok {
|
||||
_, _ = runner.Emit(extensions.SourceEvent{
|
||||
SourceType: ev.SourceType,
|
||||
ID: ev.ID,
|
||||
URL: ev.URL,
|
||||
Title: ev.Title,
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Source, func(ev SourceEvent) extensions.Event {
|
||||
return extensions.SourceEvent{
|
||||
SourceType: ev.SourceType,
|
||||
ID: ev.ID,
|
||||
URL: ev.URL,
|
||||
Title: ev.Title,
|
||||
}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Error) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(ErrorEvent); ok {
|
||||
_, _ = runner.Emit(extensions.ErrorEvent{Error: ev.Error.Error()})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Error, func(ev ErrorEvent) extensions.Event {
|
||||
return extensions.ErrorEvent{Error: ev.Error.Error()}
|
||||
})
|
||||
|
||||
if runner.HasHandlers(extensions.Retry) {
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(RetryEvent); ok {
|
||||
_, _ = runner.Emit(extensions.RetryEvent{
|
||||
Attempt: ev.Attempt,
|
||||
Error: ev.Error.Error(),
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
bridgeObserve(m, runner, extensions.Retry, func(ev RetryEvent) extensions.Event {
|
||||
return extensions.RetryEvent{
|
||||
Attempt: ev.Attempt,
|
||||
Error: ev.Error.Error(),
|
||||
}
|
||||
})
|
||||
|
||||
// --- PrepareStep hook ---
|
||||
// Extension PrepareStep → SDK PrepareStep hook.
|
||||
// Same pattern as ContextPrepare: convert LLMMessage ↔ ContextMessage.
|
||||
if runner.HasHandlers(extensions.PrepareStep) {
|
||||
m.OnPrepareStep(HookPriorityNormal, func(h PrepareStepHook) *PrepareStepResult {
|
||||
// Convert LLM message slice to extension ContextMessage slice.
|
||||
extMsgs := make([]extensions.ContextMessage, len(h.Messages))
|
||||
for i, msg := range h.Messages {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
|
||||
extMsgs := llmToContextMessages(h.Messages)
|
||||
result, _ := runner.Emit(extensions.PrepareStepEvent{
|
||||
StepNumber: h.StepNumber,
|
||||
Messages: extMsgs,
|
||||
@@ -460,30 +344,71 @@ func (m *Kit) bridgeExtensions(runner *extensions.Runner) {
|
||||
if !ok || r.Messages == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild LLM message slice from extension result.
|
||||
rebuilt := make([]LLMMessage, 0, len(r.Messages))
|
||||
for _, cm := range r.Messages {
|
||||
if cm.Index >= 0 && cm.Index < len(h.Messages) {
|
||||
rebuilt = append(rebuilt, h.Messages[cm.Index])
|
||||
} else {
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return &PrepareStepResult{Messages: rebuilt}
|
||||
return &PrepareStepResult{Messages: contextMessagesToLLM(r.Messages, h.Messages)}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// bridgeObserve subscribes to SDK events of type In and forwards them to the
|
||||
// extension runner as the event returned by conv. The subscription is only
|
||||
// registered when the runner has handlers for the given event kind.
|
||||
func bridgeObserve[In Event](m *Kit, runner *extensions.Runner, kind extensions.EventType, conv func(In) extensions.Event) {
|
||||
if !runner.HasHandlers(kind) {
|
||||
return
|
||||
}
|
||||
m.Subscribe(func(e Event) {
|
||||
if ev, ok := e.(In); ok {
|
||||
_, _ = runner.Emit(conv(ev))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// llmToContextMessages converts a slice of LLM messages to extension
|
||||
// ContextMessage values, extracting plain text from each message.
|
||||
func llmToContextMessages(msgs []LLMMessage) []extensions.ContextMessage {
|
||||
extMsgs := make([]extensions.ContextMessage, len(msgs))
|
||||
for i, msg := range msgs {
|
||||
var sb strings.Builder
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(LLMTextPart); ok {
|
||||
sb.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
extMsgs[i] = extensions.ContextMessage{
|
||||
Index: i,
|
||||
Role: string(msg.Role),
|
||||
Content: sb.String(),
|
||||
}
|
||||
}
|
||||
return extMsgs
|
||||
}
|
||||
|
||||
// contextMessagesToLLM rebuilds an LLM message slice from extension
|
||||
// ContextMessages. Messages with a valid index reuse the original from
|
||||
// originals; new messages injected by extensions are constructed from
|
||||
// role + text.
|
||||
func contextMessagesToLLM(cms []extensions.ContextMessage, originals []LLMMessage) []LLMMessage {
|
||||
rebuilt := make([]LLMMessage, 0, len(cms))
|
||||
for _, cm := range cms {
|
||||
if cm.Index >= 0 && cm.Index < len(originals) {
|
||||
// Reuse original message (preserves original role and content).
|
||||
rebuilt = append(rebuilt, originals[cm.Index])
|
||||
} else {
|
||||
// New message injected by extension — construct from role + text.
|
||||
role := LLMRoleUser
|
||||
switch cm.Role {
|
||||
case "assistant":
|
||||
role = LLMRoleAssistant
|
||||
case "system":
|
||||
role = LLMRoleSystem
|
||||
case "tool":
|
||||
role = LLMRoleTool
|
||||
}
|
||||
rebuilt = append(rebuilt, LLMMessage{
|
||||
Role: role,
|
||||
Content: []LLMMessagePart{LLMTextPart{Text: cm.Content}},
|
||||
})
|
||||
}
|
||||
}
|
||||
return rebuilt
|
||||
}
|
||||
|
||||
+138
-11
@@ -58,6 +58,14 @@ type Kit struct {
|
||||
// When false, per-model system prompts from modelSettings/customModels
|
||||
// can replace the default prompt on model switch.
|
||||
hasCustomSystemPrompt bool
|
||||
// systemPromptSource holds the raw configured value (file path or text)
|
||||
// when hasCustomSystemPrompt is true; empty when the built-in default is in use.
|
||||
systemPromptSource string
|
||||
// basePrompt holds the resolved base system prompt text (post file-load,
|
||||
// pre runtime-context composition) captured during New. Used by
|
||||
// RefreshSystemPrompt to recompose after skills/context-file mutations.
|
||||
// Protected by runtimeMu.
|
||||
basePrompt string
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
@@ -87,6 +95,12 @@ type Kit struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// runtimeMu protects contextFiles and skills against concurrent runtime
|
||||
// mutations via AddSkill / RemoveSkill / AddContextFile etc. The fields
|
||||
// are read by composeSystemPrompt and several other accessors, so all
|
||||
// reads and writes after Kit construction must take this lock.
|
||||
runtimeMu sync.RWMutex
|
||||
|
||||
// steerCh is a buffered channel used to inject steering messages into
|
||||
// the running agent turn via the LLM library's PrepareStep. Created fresh for
|
||||
// each generate() call and set to nil when idle. Protected by steerMu.
|
||||
@@ -632,21 +646,43 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasCustomSystemPrompt reports whether the user explicitly configured a system
|
||||
// prompt via --system-prompt, a config file entry, or SDK Options.SystemPrompt.
|
||||
// When false, the built-in default (or a per-model override) is in use and can
|
||||
// be replaced transparently on model switch.
|
||||
func (m *Kit) HasCustomSystemPrompt() bool {
|
||||
return m.hasCustomSystemPrompt
|
||||
}
|
||||
|
||||
// GetSystemPromptSource returns the raw configured value — a file path or
|
||||
// inline text — when HasCustomSystemPrompt is true; returns an empty string
|
||||
// when the built-in default prompt is active.
|
||||
func (m *Kit) GetSystemPromptSource() string {
|
||||
return m.systemPromptSource
|
||||
}
|
||||
|
||||
// composeSystemPrompt takes a base system prompt and composes it with the
|
||||
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
|
||||
// This mirrors the composition done during Kit.New() initialization.
|
||||
// It acquires a read lock on runtimeMu while snapshotting contextFiles and
|
||||
// skills, so callers must not hold the write lock.
|
||||
func (m *Kit) composeSystemPrompt(basePrompt string) string {
|
||||
cwd, _ := os.Getwd()
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
m.runtimeMu.RLock()
|
||||
contextFiles := append([]*ContextFile(nil), m.contextFiles...)
|
||||
loadedSkills := append([]*skills.Skill(nil), m.skills...)
|
||||
m.runtimeMu.RUnlock()
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range m.contextFiles {
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
}
|
||||
|
||||
// Inject skills metadata.
|
||||
if len(m.skills) > 0 {
|
||||
pb.WithSkills(m.skills)
|
||||
if len(loadedSkills) > 0 {
|
||||
pb.WithSkills(loadedSkills)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
@@ -1035,6 +1071,41 @@ type Options struct {
|
||||
// real-time progress in the TUI.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// MCPTaskMode overrides the per-server [MCPTaskMode] for task-augmented
|
||||
// tools/call execution. Keys are MCP server names. Servers not present
|
||||
// in the map fall back to the TasksMode field of MCPServerConfig (or
|
||||
// MCPTaskModeAuto when that is empty). See the MCP Tasks spec for the
|
||||
// underlying semantics:
|
||||
// https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks
|
||||
MCPTaskMode map[string]MCPTaskMode
|
||||
|
||||
// MCPTaskTimeout is the maximum wall-clock duration to wait for a
|
||||
// task-augmented tool call to reach a terminal state. Independent of
|
||||
// any per-call context deadline; whichever fires first wins. Zero
|
||||
// means use the default (15 minutes).
|
||||
MCPTaskTimeout time.Duration
|
||||
|
||||
// MCPTaskTTL is the TTL hint sent in TaskParams for every
|
||||
// task-augmented tools/call. Zero omits the TTL and lets the server
|
||||
// pick its own retention policy.
|
||||
MCPTaskTTL time.Duration
|
||||
|
||||
// MCPTaskPollInterval is the fallback interval between tasks/get
|
||||
// requests when the server does not suggest one. Zero means use the
|
||||
// default (1 second).
|
||||
MCPTaskPollInterval time.Duration
|
||||
|
||||
// MCPTaskMaxPollInterval caps the polling interval (a server-supplied
|
||||
// pollInterval can otherwise grow without bound). Zero means use the
|
||||
// default (5 seconds).
|
||||
MCPTaskMaxPollInterval time.Duration
|
||||
|
||||
// MCPTaskProgress, if non-nil, is invoked once when a task is accepted
|
||||
// and on every status transition observed by the polling loop. The
|
||||
// final invocation always carries a terminal status. Implementations
|
||||
// must not block; long work should run on a goroutine.
|
||||
MCPTaskProgress MCPTaskProgressHandler
|
||||
|
||||
// CLI is optional CLI-specific configuration. SDK users leave this nil.
|
||||
CLI *CLIOptions
|
||||
|
||||
@@ -1144,6 +1215,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
maxSteps int
|
||||
streaming bool
|
||||
hasCustomSystemPrompt bool
|
||||
systemPromptSource string
|
||||
capturedBasePrompt string
|
||||
)
|
||||
|
||||
if err := func() error {
|
||||
@@ -1250,13 +1323,27 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// explicitly set system-prompt, use the per-model prompt as the
|
||||
// base instead of the global default.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
rawPromptInput := viper.GetString("system-prompt")
|
||||
|
||||
// Resolve a file path to its content so PromptBuilder receives the
|
||||
// actual prompt text rather than a literal path string. Without this,
|
||||
// when system-prompt is set to a file path in the config file or via
|
||||
// --system-prompt, the path itself becomes the effective system prompt
|
||||
// sent to the model (LoadSystemPrompt only ran later, after viper had
|
||||
// been overwritten with the augmented base text).
|
||||
basePrompt, _ := config.LoadSystemPrompt(rawPromptInput)
|
||||
if basePrompt == "" {
|
||||
basePrompt = rawPromptInput
|
||||
}
|
||||
|
||||
// Track whether the user explicitly configured a custom system
|
||||
// prompt. When they haven't (basePrompt is the built-in default
|
||||
// or empty), per-model system prompts can replace it on switch.
|
||||
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
|
||||
hasCustomSystemPrompt = userSetSystemPrompt
|
||||
if hasCustomSystemPrompt {
|
||||
systemPromptSource = rawPromptInput
|
||||
}
|
||||
|
||||
// Check for per-model system prompt override when no explicit
|
||||
// global system-prompt was configured by the user.
|
||||
@@ -1281,6 +1368,10 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Capture the resolved base prompt so RefreshSystemPrompt can
|
||||
// recompose later after runtime skill/context-file mutations.
|
||||
capturedBasePrompt = basePrompt
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
@@ -1387,6 +1478,14 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streaming,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
MCPTaskConfig: mcpTaskOptions{
|
||||
perServer: opts.MCPTaskMode,
|
||||
defaultTTL: opts.MCPTaskTTL,
|
||||
pollInterval: opts.MCPTaskPollInterval,
|
||||
maxPollInterval: opts.MCPTaskMaxPollInterval,
|
||||
timeout: opts.MCPTaskTimeout,
|
||||
progress: opts.MCPTaskProgress,
|
||||
}.toToolsConfig(),
|
||||
}
|
||||
|
||||
// Set up OAuth handler for remote MCP servers. The SDK does not create
|
||||
@@ -1413,7 +1512,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
setupOpts.SpinnerFunc = agent.SpinnerFunc(opts.CLI.SpinnerFunc)
|
||||
setupOpts.UseBufferedLogger = opts.CLI.UseBufferedLogger
|
||||
if opts.CLI.ProgressReaderFunc != nil {
|
||||
providerConfig.ProgressReaderFunc = opts.CLI.ProgressReaderFunc
|
||||
@@ -1457,6 +1556,8 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
opts: opts,
|
||||
mcpConfig: mcpConfig,
|
||||
hasCustomSystemPrompt: hasCustomSystemPrompt,
|
||||
systemPromptSource: systemPromptSource,
|
||||
basePrompt: capturedBasePrompt,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
@@ -1483,15 +1584,32 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// GetContextFiles returns the context files (e.g. AGENTS.md) loaded during
|
||||
// initialisation. Returns nil if no context files were found.
|
||||
// GetContextFiles returns the context files (e.g. AGENTS.md) currently active
|
||||
// on this Kit instance. The returned slice is a snapshot — mutating it does
|
||||
// not affect Kit state. Returns nil when no context files are loaded.
|
||||
func (m *Kit) GetContextFiles() []*ContextFile {
|
||||
return m.contextFiles
|
||||
m.runtimeMu.RLock()
|
||||
defer m.runtimeMu.RUnlock()
|
||||
if len(m.contextFiles) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*ContextFile, len(m.contextFiles))
|
||||
copy(out, m.contextFiles)
|
||||
return out
|
||||
}
|
||||
|
||||
// GetSkills returns the skills loaded during initialisation.
|
||||
// GetSkills returns the skills currently active on this Kit instance. The
|
||||
// returned slice is a snapshot — mutating it does not affect Kit state.
|
||||
// Returns nil when no skills are loaded.
|
||||
func (m *Kit) GetSkills() []*Skill {
|
||||
return m.skills
|
||||
m.runtimeMu.RLock()
|
||||
defer m.runtimeMu.RUnlock()
|
||||
if len(m.skills) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*Skill, len(m.skills))
|
||||
copy(out, m.skills)
|
||||
return out
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1536,12 +1654,14 @@ func (m *Kit) expandSkillCommand(prompt string) string {
|
||||
|
||||
// Find the skill by name.
|
||||
var skillPath string
|
||||
m.runtimeMu.RLock()
|
||||
for _, s := range m.skills {
|
||||
if s.Name == name {
|
||||
skillPath = s.Path
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.RUnlock()
|
||||
if skillPath == "" {
|
||||
return prompt
|
||||
}
|
||||
@@ -1799,6 +1919,13 @@ func (m *Kit) Subagent(ctx context.Context, cfg SubagentConfig) (*SubagentResult
|
||||
Streaming: true,
|
||||
MCPConfig: m.mcpConfig,
|
||||
}
|
||||
// Propagate the parent's MCP task configuration so a child subagent
|
||||
// invoking long-running MCP tools observes the same per-server modes,
|
||||
// timeouts, and progress callback as the parent. Without this, child
|
||||
// agents would silently fall back to MCPTaskModeAuto with default
|
||||
// polling and no progress feedback even when the parent had configured
|
||||
// custom values.
|
||||
inheritMCPTaskOptions(childOpts, m.opts)
|
||||
child, err := New(ctx, childOpts)
|
||||
if err != nil {
|
||||
return &SubagentResult{Elapsed: time.Since(start)}, fmt.Errorf("failed to create subagent: %w", err)
|
||||
@@ -2042,7 +2169,7 @@ func (m *Kit) generate(ctx context.Context, messages []fantasy.Message) (*agent.
|
||||
})
|
||||
},
|
||||
|
||||
// New callbacks for previously unwired Fantasy lifecycle events.
|
||||
// New callbacks for previously unwired agent lifecycle events.
|
||||
OnStepStart: func(stepNumber int) {
|
||||
m.events.emit(StepStartEvent{StepNumber: stepNumber})
|
||||
},
|
||||
|
||||
@@ -3,6 +3,7 @@ package kit_test
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@@ -306,3 +307,92 @@ func TestSessionManagement(t *testing.T) {
|
||||
// resetViper wipes viper's global state so a test case doesn't leak
|
||||
// viper.Set() calls into the next one. Used via defer in subtests.
|
||||
func resetViper() { viper.Reset() }
|
||||
|
||||
// TestNewSystemPromptFilePath is a regression test for issue #25.
|
||||
//
|
||||
// When Options.SystemPrompt (or the --system-prompt flag / config entry) is a
|
||||
// file path, Kit must resolve the path to its file contents *before* the
|
||||
// PromptBuilder composes the runtime context. Previously the path string
|
||||
// itself was used verbatim as the base prompt, so the LLM received the path —
|
||||
// not the prompt — as its system message.
|
||||
func TestNewSystemPromptFilePath(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
defer resetViper()
|
||||
|
||||
const promptContent = "You are a strict regression-test persona. Marker: KIT-25-OK"
|
||||
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "kit-system-prompt-*.md")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp prompt file: %v", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(promptContent); err != nil {
|
||||
t.Fatalf("failed to write temp prompt file: %v", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
t.Fatalf("failed to close temp prompt file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Kit with system-prompt file: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if !host.HasCustomSystemPrompt() {
|
||||
t.Error("HasCustomSystemPrompt() = false; want true when --system-prompt is set")
|
||||
}
|
||||
if got, want := host.GetSystemPromptSource(), tmpFile.Name(); got != want {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, want)
|
||||
}
|
||||
|
||||
// The composed system prompt is written back to viper after PromptBuilder
|
||||
// runs. It must contain the file's contents, not the file path.
|
||||
composed := viper.GetString("system-prompt")
|
||||
if !strings.Contains(composed, promptContent) {
|
||||
t.Errorf("composed system-prompt does not contain file contents\n composed = %q\n want substring = %q", composed, promptContent)
|
||||
}
|
||||
if strings.TrimSpace(composed) == tmpFile.Name() {
|
||||
t.Errorf("composed system-prompt is the file path verbatim (%q); LoadSystemPrompt was not applied before PromptBuilder", composed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSystemPromptInline confirms that inline system-prompt strings still
|
||||
// flow through unchanged after the file-path resolution change.
|
||||
func TestNewSystemPromptInline(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
defer resetViper()
|
||||
|
||||
const inline = "You are a concise inline-prompt persona."
|
||||
|
||||
ctx := context.Background()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
SystemPrompt: inline,
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Kit with inline system-prompt: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if !host.HasCustomSystemPrompt() {
|
||||
t.Error("HasCustomSystemPrompt() = false; want true for inline prompt")
|
||||
}
|
||||
if got := host.GetSystemPromptSource(); got != inline {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, inline)
|
||||
}
|
||||
if composed := viper.GetString("system-prompt"); !strings.Contains(composed, inline) {
|
||||
t.Errorf("composed system-prompt missing inline content; got %q", composed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,300 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// MCPTaskStatus represents the lifecycle state of a task-augmented MCP
|
||||
// tool call. See https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks
|
||||
// for the underlying spec.
|
||||
type MCPTaskStatus string
|
||||
|
||||
const (
|
||||
// MCPTaskStatusWorking indicates the task is currently being processed.
|
||||
MCPTaskStatusWorking MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusWorking)
|
||||
// MCPTaskStatusInputRequired indicates the server is waiting for client
|
||||
// input before it can proceed (rare; typically surfaced via elicitation).
|
||||
MCPTaskStatusInputRequired MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusInputRequired)
|
||||
// MCPTaskStatusCompleted indicates the task finished successfully.
|
||||
MCPTaskStatusCompleted MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCompleted)
|
||||
// MCPTaskStatusFailed indicates the task ended in error.
|
||||
MCPTaskStatusFailed MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusFailed)
|
||||
// MCPTaskStatusCancelled indicates the task was cancelled before completion.
|
||||
MCPTaskStatusCancelled MCPTaskStatus = MCPTaskStatus(mcp.TaskStatusCancelled)
|
||||
)
|
||||
|
||||
// IsTerminal reports whether the status represents a final state — that is,
|
||||
// the task will not change again. Terminal states are completed, failed,
|
||||
// and cancelled.
|
||||
func (s MCPTaskStatus) IsTerminal() bool {
|
||||
return mcp.TaskStatus(s).IsTerminal()
|
||||
}
|
||||
|
||||
// MCPTaskMode controls when Kit augments tools/call requests with MCP task
|
||||
// metadata for a specific server.
|
||||
type MCPTaskMode string
|
||||
|
||||
const (
|
||||
// MCPTaskModeAuto augments tools/call with task metadata only when the
|
||||
// server advertises tasks/toolCalls capability during initialize.
|
||||
// This is the default and is safe to leave unconfigured for any
|
||||
// existing MCP server.
|
||||
MCPTaskModeAuto MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAuto)
|
||||
// MCPTaskModeNever forces every tools/call to be issued synchronously
|
||||
// (no Task field), regardless of server capability.
|
||||
MCPTaskModeNever MCPTaskMode = MCPTaskMode(tools.MCPTaskModeNever)
|
||||
// MCPTaskModeAlways always opts into task augmentation, even when the
|
||||
// server didn't advertise the capability. The server may still respond
|
||||
// synchronously; this just expresses client intent unconditionally.
|
||||
MCPTaskModeAlways MCPTaskMode = MCPTaskMode(tools.MCPTaskModeAlways)
|
||||
)
|
||||
|
||||
// MCPTask is the SDK-level view of an MCP Task. Timestamps are best-effort
|
||||
// parsed from the server's ISO-8601 strings; they may be the zero time when
|
||||
// the server omitted them or used a non-RFC3339 format.
|
||||
type MCPTask struct {
|
||||
// Server is the configured MCP server name this task lives on.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the task.
|
||||
TaskID string
|
||||
// Status is the current task lifecycle state.
|
||||
Status MCPTaskStatus
|
||||
// StatusMessage is an optional human-readable description provided by
|
||||
// the server.
|
||||
StatusMessage string
|
||||
// CreatedAt is when the task was created on the server.
|
||||
CreatedAt time.Time
|
||||
// UpdatedAt is when the task was last updated on the server.
|
||||
UpdatedAt time.Time
|
||||
// TTL is how long the server intends to retain this task after creation.
|
||||
// Zero means the server did not advertise a TTL.
|
||||
TTL time.Duration
|
||||
// PollInterval is the suggested time between status checks. Zero means
|
||||
// the client should use its own default.
|
||||
PollInterval time.Duration
|
||||
}
|
||||
|
||||
// MCPTaskProgress is a single status update emitted while Kit is waiting
|
||||
// on a task-augmented tool call.
|
||||
type MCPTaskProgress struct {
|
||||
// Server is the configured MCP server name.
|
||||
Server string
|
||||
// TaskID is the server-assigned identifier for the in-flight task.
|
||||
TaskID string
|
||||
// Status is the most recent task status observed.
|
||||
Status MCPTaskStatus
|
||||
// Message is the optional human-readable status message from the server.
|
||||
Message string
|
||||
}
|
||||
|
||||
// MCPTaskProgressHandler is called once when a task is accepted and again
|
||||
// on every observed status transition. The final invocation always carries
|
||||
// a terminal status. Implementations must not block; long work should be
|
||||
// dispatched on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// MCPTaskConfig configures task-aware MCP tools/call execution. All fields
|
||||
// are optional; the zero value disables progress callbacks and applies
|
||||
// sensible polling defaults inside the engine.
|
||||
//
|
||||
// For most consumers, the flat [Options] fields (`MCPTaskMode`,
|
||||
// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`,
|
||||
// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point.
|
||||
// MCPTaskConfig is exposed for the low-level [AgentConfig] path.
|
||||
type MCPTaskConfig struct {
|
||||
// PerServerMode overrides the per-server task mode resolved from
|
||||
// [MCPServerConfig]. Keys are server names. Missing entries fall back
|
||||
// to the configured value.
|
||||
PerServerMode map[string]MCPTaskMode
|
||||
|
||||
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
|
||||
// tools/call. Zero means omit the TTL — let the server pick its own.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// PollInterval is the fallback interval between tasks/get requests
|
||||
// when the server does not suggest one. Zero defaults to 1 second.
|
||||
PollInterval time.Duration
|
||||
|
||||
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
|
||||
MaxPollInterval time.Duration
|
||||
|
||||
// Timeout is the maximum wall-clock duration to wait for a task to
|
||||
// reach a terminal state. Zero defaults to 15 minutes. Independent
|
||||
// of the per-call context deadline; whichever fires first wins.
|
||||
Timeout time.Duration
|
||||
|
||||
// Progress, if non-nil, receives every status transition observed by
|
||||
// the polling loop.
|
||||
Progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal
|
||||
// tools-package representation. Keeps the dependency arrow internal-only.
|
||||
func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig {
|
||||
cfg := tools.MCPTaskConfig{
|
||||
DefaultTTL: c.DefaultTTL,
|
||||
PollInterval: c.PollInterval,
|
||||
MaxPollInterval: c.MaxPollInterval,
|
||||
Timeout: c.Timeout,
|
||||
}
|
||||
if len(c.PerServerMode) > 0 {
|
||||
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode))
|
||||
for k, v := range c.PerServerMode {
|
||||
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
|
||||
}
|
||||
}
|
||||
if c.Progress != nil {
|
||||
h := c.Progress
|
||||
cfg.Progress = func(p tools.MCPTaskProgress) {
|
||||
h(MCPTaskProgress{
|
||||
Server: p.Server,
|
||||
TaskID: p.TaskID,
|
||||
Status: MCPTaskStatus(p.Status),
|
||||
Message: p.Message,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
|
||||
// Stored on Options as a single value so the public surface stays compact;
|
||||
// individual fields are exposed via WithMCP* builder functions.
|
||||
type mcpTaskOptions struct {
|
||||
perServer map[string]MCPTaskMode
|
||||
defaultTTL time.Duration
|
||||
pollInterval time.Duration
|
||||
maxPollInterval time.Duration
|
||||
timeout time.Duration
|
||||
progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
// toToolsConfig converts the SDK-level config to the internal tools-package
|
||||
// representation. Keeps the dependency arrow internal-only.
|
||||
func (o mcpTaskOptions) toToolsConfig() tools.MCPTaskConfig {
|
||||
cfg := tools.MCPTaskConfig{
|
||||
DefaultTTL: o.defaultTTL,
|
||||
PollInterval: o.pollInterval,
|
||||
MaxPollInterval: o.maxPollInterval,
|
||||
Timeout: o.timeout,
|
||||
}
|
||||
if len(o.perServer) > 0 {
|
||||
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(o.perServer))
|
||||
for k, v := range o.perServer {
|
||||
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
|
||||
}
|
||||
}
|
||||
if o.progress != nil {
|
||||
h := o.progress
|
||||
cfg.Progress = func(p tools.MCPTaskProgress) {
|
||||
h(MCPTaskProgress{
|
||||
Server: p.Server,
|
||||
TaskID: p.TaskID,
|
||||
Status: MCPTaskStatus(p.Status),
|
||||
Message: p.Message,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ListMCPTasks queries tasks/list on the named MCP server and returns the
|
||||
// active and recent tasks the server is willing to disclose. Returns an
|
||||
// error when the server isn't loaded, doesn't expose tasks/list, or the
|
||||
// underlying transport fails.
|
||||
func (m *Kit) ListMCPTasks(ctx context.Context, serverName string) ([]MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
infos, err := mgr.ListServerTasks(ctx, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]MCPTask, len(infos))
|
||||
for i, t := range infos {
|
||||
out[i] = mcpTaskFromInternal(t)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetMCPTask queries tasks/get for a single in-flight task on the named
|
||||
// server. The returned MCPTask reflects the server's current view of the
|
||||
// task.
|
||||
func (m *Kit) GetMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
info, err := mgr.GetServerTask(ctx, serverName, taskID)
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
return mcpTaskFromInternal(info), nil
|
||||
}
|
||||
|
||||
// CancelMCPTask issues tasks/cancel for an in-flight task on the named
|
||||
// server. Returns the post-cancel task state when the server responded
|
||||
// with one. Cancelling an already-terminal task is a no-op on most
|
||||
// servers.
|
||||
func (m *Kit) CancelMCPTask(ctx context.Context, serverName, taskID string) (MCPTask, error) {
|
||||
mgr, err := m.mcpToolManager()
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
info, err := mgr.CancelServerTask(ctx, serverName, taskID)
|
||||
if err != nil {
|
||||
return MCPTask{}, err
|
||||
}
|
||||
return mcpTaskFromInternal(info), nil
|
||||
}
|
||||
|
||||
// mcpToolManager returns the underlying MCP tool manager or an error when
|
||||
// no MCP servers are configured.
|
||||
func (m *Kit) mcpToolManager() (*tools.MCPToolManager, error) {
|
||||
if m == nil || m.agent == nil {
|
||||
return nil, fmt.Errorf("kit instance has no agent")
|
||||
}
|
||||
mgr := m.agent.GetMCPToolManager()
|
||||
if mgr == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// mcpTaskFromInternal adapts the internal tools.MCPTaskInfo to the
|
||||
// SDK-level MCPTask type. Keeps the public surface independent of
|
||||
// internal package types.
|
||||
func mcpTaskFromInternal(t tools.MCPTaskInfo) MCPTask {
|
||||
return MCPTask{
|
||||
Server: t.Server,
|
||||
TaskID: t.TaskID,
|
||||
Status: MCPTaskStatus(t.Status),
|
||||
StatusMessage: t.StatusMessage,
|
||||
CreatedAt: t.CreatedAt,
|
||||
UpdatedAt: t.UpdatedAt,
|
||||
TTL: t.TTL,
|
||||
PollInterval: t.PollInterval,
|
||||
}
|
||||
}
|
||||
|
||||
// inheritMCPTaskOptions copies every MCP task-related field from parent
|
||||
// onto child. Used by Kit.Subagent so child instances observe the same
|
||||
// per-server modes, timeouts, and progress callback as their parent.
|
||||
// A nil parent is a no-op so callers don't have to guard at the call site.
|
||||
func inheritMCPTaskOptions(child, parent *Options) {
|
||||
if child == nil || parent == nil {
|
||||
return
|
||||
}
|
||||
child.MCPTaskMode = parent.MCPTaskMode
|
||||
child.MCPTaskTimeout = parent.MCPTaskTimeout
|
||||
child.MCPTaskTTL = parent.MCPTaskTTL
|
||||
child.MCPTaskPollInterval = parent.MCPTaskPollInterval
|
||||
child.MCPTaskMaxPollInterval = parent.MCPTaskMaxPollInterval
|
||||
child.MCPTaskProgress = parent.MCPTaskProgress
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
func TestMCPTaskStatusIsTerminal(t *testing.T) {
|
||||
cases := []struct {
|
||||
s MCPTaskStatus
|
||||
want bool
|
||||
}{
|
||||
{MCPTaskStatusWorking, false},
|
||||
{MCPTaskStatusInputRequired, false},
|
||||
{MCPTaskStatusCompleted, true},
|
||||
{MCPTaskStatusFailed, true},
|
||||
{MCPTaskStatusCancelled, true},
|
||||
{MCPTaskStatus("unknown"), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := tc.s.IsTerminal(); got != tc.want {
|
||||
t.Errorf("MCPTaskStatus(%q).IsTerminal() = %v, want %v", tc.s, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTaskOptionsToToolsConfig(t *testing.T) {
|
||||
called := 0
|
||||
o := mcpTaskOptions{
|
||||
perServer: map[string]MCPTaskMode{
|
||||
"alpha": MCPTaskModeAlways,
|
||||
"beta": MCPTaskModeNever,
|
||||
},
|
||||
defaultTTL: 30 * time.Second,
|
||||
pollInterval: 250 * time.Millisecond,
|
||||
maxPollInterval: 2 * time.Second,
|
||||
timeout: 5 * time.Minute,
|
||||
progress: func(p MCPTaskProgress) { called++ },
|
||||
}
|
||||
cfg := o.toToolsConfig()
|
||||
|
||||
if cfg.DefaultTTL != 30*time.Second {
|
||||
t.Errorf("DefaultTTL = %v, want 30s", cfg.DefaultTTL)
|
||||
}
|
||||
if cfg.PollInterval != 250*time.Millisecond {
|
||||
t.Errorf("PollInterval = %v, want 250ms", cfg.PollInterval)
|
||||
}
|
||||
if cfg.MaxPollInterval != 2*time.Second {
|
||||
t.Errorf("MaxPollInterval = %v, want 2s", cfg.MaxPollInterval)
|
||||
}
|
||||
if cfg.Timeout != 5*time.Minute {
|
||||
t.Errorf("Timeout = %v, want 5m", cfg.Timeout)
|
||||
}
|
||||
if cfg.PerServerMode["alpha"] != tools.MCPTaskModeAlways {
|
||||
t.Errorf("PerServerMode[alpha] = %q, want always", cfg.PerServerMode["alpha"])
|
||||
}
|
||||
if cfg.PerServerMode["beta"] != tools.MCPTaskModeNever {
|
||||
t.Errorf("PerServerMode[beta] = %q, want never", cfg.PerServerMode["beta"])
|
||||
}
|
||||
|
||||
// Progress conversion: invoking the internal handler must call our
|
||||
// SDK-level callback with the converted struct.
|
||||
if cfg.Progress == nil {
|
||||
t.Fatal("Progress callback was lost in conversion")
|
||||
}
|
||||
cfg.Progress(tools.MCPTaskProgress{
|
||||
Server: "alpha",
|
||||
TaskID: "t1",
|
||||
Status: "working",
|
||||
})
|
||||
if called != 1 {
|
||||
t.Errorf("expected SDK progress handler to be invoked once, got %d", called)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPTaskFromInternal(t *testing.T) {
|
||||
in := tools.MCPTaskInfo{
|
||||
Server: "srv",
|
||||
TaskID: "t-1",
|
||||
Status: "working",
|
||||
StatusMessage: "phase 1",
|
||||
CreatedAt: time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC),
|
||||
UpdatedAt: time.Date(2026, 5, 4, 12, 0, 1, 0, time.UTC),
|
||||
TTL: 5 * time.Minute,
|
||||
PollInterval: 500 * time.Millisecond,
|
||||
}
|
||||
out := mcpTaskFromInternal(in)
|
||||
|
||||
if out.Server != "srv" || out.TaskID != "t-1" {
|
||||
t.Errorf("identity fields not copied: %+v", out)
|
||||
}
|
||||
if out.Status != MCPTaskStatusWorking {
|
||||
t.Errorf("Status = %q, want working", out.Status)
|
||||
}
|
||||
if out.StatusMessage != "phase 1" {
|
||||
t.Errorf("StatusMessage = %q, want phase 1", out.StatusMessage)
|
||||
}
|
||||
if out.TTL != 5*time.Minute || out.PollInterval != 500*time.Millisecond {
|
||||
t.Errorf("durations not copied: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKitMCPTasksWithoutAgentReturnsError(t *testing.T) {
|
||||
// A nil/zero Kit must not panic — task RPCs should surface a clear
|
||||
// error instead. Useful for SDK consumers that try task ops on a Kit
|
||||
// constructed without MCP servers.
|
||||
var k *Kit
|
||||
ctx := t.Context()
|
||||
if _, err := k.ListMCPTasks(ctx, "any"); err == nil {
|
||||
t.Error("ListMCPTasks on nil Kit should error")
|
||||
}
|
||||
if _, err := k.GetMCPTask(ctx, "any", "id"); err == nil {
|
||||
t.Error("GetMCPTask on nil Kit should error")
|
||||
}
|
||||
if _, err := k.CancelMCPTask(ctx, "any", "id"); err == nil {
|
||||
t.Error("CancelMCPTask on nil Kit should error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubagentPropagatesMCPTaskOptions(t *testing.T) {
|
||||
// Exercises the helper Kit.Subagent uses to copy MCP task options
|
||||
// onto child Options. Calling the real helper (rather than
|
||||
// duplicating its body in the test) means any new field added to
|
||||
// the propagation list is picked up automatically by the
|
||||
// equivalence assertion below.
|
||||
parent := &Options{
|
||||
MCPTaskMode: map[string]MCPTaskMode{
|
||||
"build": MCPTaskModeAlways,
|
||||
"chat": MCPTaskModeNever,
|
||||
},
|
||||
MCPTaskTimeout: 30 * time.Minute,
|
||||
MCPTaskTTL: 45 * time.Minute,
|
||||
MCPTaskPollInterval: 750 * time.Millisecond,
|
||||
MCPTaskMaxPollInterval: 4 * time.Second,
|
||||
MCPTaskProgress: func(MCPTaskProgress) {},
|
||||
}
|
||||
|
||||
child := &Options{}
|
||||
inheritMCPTaskOptions(child, parent)
|
||||
|
||||
if child.MCPTaskMode["build"] != MCPTaskModeAlways || child.MCPTaskMode["chat"] != MCPTaskModeNever {
|
||||
t.Errorf("MCPTaskMode not propagated: got %+v", child.MCPTaskMode)
|
||||
}
|
||||
if child.MCPTaskTimeout != 30*time.Minute {
|
||||
t.Errorf("MCPTaskTimeout = %v, want 30m", child.MCPTaskTimeout)
|
||||
}
|
||||
if child.MCPTaskTTL != 45*time.Minute {
|
||||
t.Errorf("MCPTaskTTL = %v, want 45m", child.MCPTaskTTL)
|
||||
}
|
||||
if child.MCPTaskPollInterval != 750*time.Millisecond {
|
||||
t.Errorf("MCPTaskPollInterval = %v, want 750ms", child.MCPTaskPollInterval)
|
||||
}
|
||||
if child.MCPTaskMaxPollInterval != 4*time.Second {
|
||||
t.Errorf("MCPTaskMaxPollInterval = %v, want 4s", child.MCPTaskMaxPollInterval)
|
||||
}
|
||||
if child.MCPTaskProgress == nil {
|
||||
t.Error("MCPTaskProgress not propagated")
|
||||
}
|
||||
|
||||
// Nil parent is a no-op rather than a panic.
|
||||
inheritMCPTaskOptions(&Options{}, nil)
|
||||
inheritMCPTaskOptions(nil, parent)
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
"github.com/mark3labs/kit/internal/skills"
|
||||
)
|
||||
|
||||
// TestAddSkill_AddsAndDeduplicates verifies that AddSkill registers new skills
|
||||
// and that re-adding a skill with the same Name replaces the existing entry
|
||||
// rather than appending a duplicate. agent is nil in these tests; the method
|
||||
// must still mutate the in-memory state and tolerate the absent agent.
|
||||
func TestAddSkill_AddsAndDeduplicates(t *testing.T) {
|
||||
k := &Kit{basePrompt: "base"}
|
||||
|
||||
if err := k.AddSkill(&skills.Skill{Name: "alpha", Content: "first"}); err != nil {
|
||||
t.Fatalf("AddSkill alpha: %v", err)
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{Name: "beta", Content: "second"}); err != nil {
|
||||
t.Fatalf("AddSkill beta: %v", err)
|
||||
}
|
||||
got := k.GetSkills()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 skills, got %d", len(got))
|
||||
}
|
||||
|
||||
// Re-adding alpha with new content must replace, not duplicate.
|
||||
if err := k.AddSkill(&skills.Skill{Name: "alpha", Content: "replaced"}); err != nil {
|
||||
t.Fatalf("AddSkill alpha replace: %v", err)
|
||||
}
|
||||
got = k.GetSkills()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 skills after replace, got %d", len(got))
|
||||
}
|
||||
for _, s := range got {
|
||||
if s.Name == "alpha" && s.Content != "replaced" {
|
||||
t.Errorf("alpha content = %q; want %q", s.Content, "replaced")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddSkill_Validation rejects nil skills and unnamed skills with errors
|
||||
// instead of corrupting state.
|
||||
func TestAddSkill_Validation(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddSkill(nil); err == nil {
|
||||
t.Error("expected error for nil skill")
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{Content: "x"}); err == nil {
|
||||
t.Error("expected error for unnamed skill")
|
||||
}
|
||||
if got := k.GetSkills(); got != nil {
|
||||
t.Errorf("skills list mutated after invalid AddSkill calls: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveSkill verifies removal and the false return for misses.
|
||||
func TestRemoveSkill(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddSkill(&skills.Skill{Name: "alpha"})
|
||||
_ = k.AddSkill(&skills.Skill{Name: "beta"})
|
||||
|
||||
if removed := k.RemoveSkill("missing"); removed {
|
||||
t.Error("RemoveSkill(missing) = true; want false")
|
||||
}
|
||||
if removed := k.RemoveSkill("alpha"); !removed {
|
||||
t.Error("RemoveSkill(alpha) = false; want true")
|
||||
}
|
||||
got := k.GetSkills()
|
||||
if len(got) != 1 || got[0].Name != "beta" {
|
||||
t.Errorf("remaining skills = %#v; want [beta]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetSkills replaces the entire set and validates input.
|
||||
func TestSetSkills(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddSkill(&skills.Skill{Name: "alpha"})
|
||||
|
||||
err := k.SetSkills([]*skills.Skill{
|
||||
{Name: "one"},
|
||||
{Name: "two"},
|
||||
{Name: "three"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetSkills: %v", err)
|
||||
}
|
||||
if got := k.GetSkills(); len(got) != 3 {
|
||||
t.Errorf("expected 3 skills, got %d", len(got))
|
||||
}
|
||||
|
||||
// Invalid entry rejects the whole batch.
|
||||
bad := []*skills.Skill{{Name: "ok"}, nil}
|
||||
if err := k.SetSkills(bad); err == nil {
|
||||
t.Error("expected error when batch contains nil")
|
||||
}
|
||||
// State unchanged after rejected batch.
|
||||
if got := k.GetSkills(); len(got) != 3 {
|
||||
t.Errorf("skills mutated by rejected SetSkills batch: len=%d", len(got))
|
||||
}
|
||||
|
||||
// Empty slice clears.
|
||||
if err := k.SetSkills(nil); err != nil {
|
||||
t.Fatalf("SetSkills(nil): %v", err)
|
||||
}
|
||||
if got := k.GetSkills(); got != nil {
|
||||
t.Errorf("expected nil skills after clear; got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndAddSkill round-trips a skill file from disk.
|
||||
func TestLoadAndAddSkill(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "demo.md")
|
||||
body := "---\nname: demo\ndescription: demo skill\n---\nhello world"
|
||||
if err := os.WriteFile(path, []byte(body), 0o644); err != nil {
|
||||
t.Fatalf("write skill file: %v", err)
|
||||
}
|
||||
|
||||
k := &Kit{}
|
||||
s, err := k.LoadAndAddSkill(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndAddSkill: %v", err)
|
||||
}
|
||||
if s.Name != "demo" {
|
||||
t.Errorf("loaded skill Name = %q; want demo", s.Name)
|
||||
}
|
||||
if got := k.GetSkills(); len(got) != 1 {
|
||||
t.Errorf("expected 1 skill registered, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFile_DeduplicatesByPath confirms identical paths replace
|
||||
// rather than duplicate.
|
||||
func TestAddContextFile_DeduplicatesByPath(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/a/AGENTS.md", Content: "v1"}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/b/AGENTS.md", Content: "vB"}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Path: "/a/AGENTS.md", Content: "v2"}); err != nil {
|
||||
t.Fatalf("AddContextFile replace: %v", err)
|
||||
}
|
||||
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("expected 2 context files, got %d", len(got))
|
||||
}
|
||||
for _, cf := range got {
|
||||
if cf.Path == "/a/AGENTS.md" && cf.Content != "v2" {
|
||||
t.Errorf("/a/AGENTS.md content = %q; want v2", cf.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFile_Validation rejects nil and unpathed entries.
|
||||
func TestAddContextFile_Validation(t *testing.T) {
|
||||
k := &Kit{}
|
||||
if err := k.AddContextFile(nil); err == nil {
|
||||
t.Error("expected error for nil context file")
|
||||
}
|
||||
if err := k.AddContextFile(&ContextFile{Content: "x"}); err == nil {
|
||||
t.Error("expected error for empty path")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveContextFile_Behavior verifies remove returns true on hit and
|
||||
// false on miss without mutating state on a miss.
|
||||
func TestRemoveContextFile_Behavior(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/a", Content: "x"})
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/b", Content: "y"})
|
||||
|
||||
if removed := k.RemoveContextFile("/missing"); removed {
|
||||
t.Error("RemoveContextFile(missing) = true; want false")
|
||||
}
|
||||
if removed := k.RemoveContextFile("/a"); !removed {
|
||||
t.Error("RemoveContextFile(/a) = false; want true")
|
||||
}
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 1 || got[0].Path != "/b" {
|
||||
t.Errorf("remaining = %#v; want [/b]", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetContextFiles replaces and validates batch input.
|
||||
func TestSetContextFiles(t *testing.T) {
|
||||
k := &Kit{}
|
||||
_ = k.AddContextFile(&ContextFile{Path: "/seed", Content: "old"})
|
||||
|
||||
err := k.SetContextFiles([]*ContextFile{
|
||||
{Path: "/x", Content: "x"},
|
||||
{Path: "/y", Content: "y"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetContextFiles: %v", err)
|
||||
}
|
||||
if got := k.GetContextFiles(); len(got) != 2 {
|
||||
t.Errorf("expected 2 context files, got %d", len(got))
|
||||
}
|
||||
|
||||
bad := []*ContextFile{{Path: "/ok"}, {Path: ""}}
|
||||
if err := k.SetContextFiles(bad); err == nil {
|
||||
t.Error("expected error for empty path in batch")
|
||||
}
|
||||
if got := k.GetContextFiles(); len(got) != 2 {
|
||||
t.Errorf("state mutated by rejected batch: len=%d", len(got))
|
||||
}
|
||||
|
||||
if err := k.SetContextFiles(nil); err != nil {
|
||||
t.Fatalf("SetContextFiles(nil): %v", err)
|
||||
}
|
||||
if got := k.GetContextFiles(); got != nil {
|
||||
t.Errorf("expected nil after clear; got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndAddContextFile reads from disk and registers the context file.
|
||||
func TestLoadAndAddContextFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "AGENTS.md")
|
||||
const content = "# Agent rules\nuse the new lint config"
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
|
||||
k := &Kit{}
|
||||
cf, err := k.LoadAndAddContextFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadAndAddContextFile: %v", err)
|
||||
}
|
||||
if !strings.HasSuffix(cf.Path, "AGENTS.md") {
|
||||
t.Errorf("Path = %q; want suffix AGENTS.md", cf.Path)
|
||||
}
|
||||
if !strings.Contains(cf.Content, "use the new lint config") {
|
||||
t.Errorf("Content missing expected body: %q", cf.Content)
|
||||
}
|
||||
got := k.GetContextFiles()
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 context file, got %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddContextFileContent registers an in-memory context blob.
|
||||
func TestAddContextFileContent(t *testing.T) {
|
||||
k := &Kit{}
|
||||
cf, err := k.AddContextFileContent("session://user-123/AGENTS.md", "always greet in French")
|
||||
if err != nil {
|
||||
t.Fatalf("AddContextFileContent: %v", err)
|
||||
}
|
||||
if cf.Path != "session://user-123/AGENTS.md" {
|
||||
t.Errorf("Path = %q", cf.Path)
|
||||
}
|
||||
if cf.Content != "always greet in French" {
|
||||
t.Errorf("Content = %q", cf.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// TestComposeSystemPrompt_IncludesSkillsAndContext verifies that runtime
|
||||
// mutations actually flow into the composed system prompt that the agent
|
||||
// would receive.
|
||||
func TestComposeSystemPrompt_IncludesSkillsAndContext(t *testing.T) {
|
||||
k := &Kit{basePrompt: "BASE-PROMPT-MARKER"}
|
||||
|
||||
if err := k.AddContextFile(&ContextFile{
|
||||
Path: "/proj/AGENTS.md",
|
||||
Content: "CTX-MARKER-OK",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddContextFile: %v", err)
|
||||
}
|
||||
if err := k.AddSkill(&skills.Skill{
|
||||
Name: "greeter",
|
||||
Description: "SKILL-DESC-MARKER",
|
||||
Content: "do greetings",
|
||||
Path: "/skills/greeter.md",
|
||||
}); err != nil {
|
||||
t.Fatalf("AddSkill: %v", err)
|
||||
}
|
||||
|
||||
composed := k.composeSystemPrompt(k.basePrompt)
|
||||
for _, want := range []string{
|
||||
"BASE-PROMPT-MARKER",
|
||||
"CTX-MARKER-OK",
|
||||
"/proj/AGENTS.md",
|
||||
"greeter",
|
||||
"SKILL-DESC-MARKER",
|
||||
} {
|
||||
if !strings.Contains(composed, want) {
|
||||
t.Errorf("composed prompt missing %q\n--- composed ---\n%s", want, composed)
|
||||
}
|
||||
}
|
||||
|
||||
// Removing the skill should remove its marker from the next composition.
|
||||
k.RemoveSkill("greeter")
|
||||
composed = k.composeSystemPrompt(k.basePrompt)
|
||||
if strings.Contains(composed, "SKILL-DESC-MARKER") {
|
||||
t.Errorf("composed prompt still contains removed skill description:\n%s", composed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRuntimeMutations_AreThreadSafe stresses the mutation API from multiple
|
||||
// goroutines to surface data races under `go test -race`.
|
||||
func TestRuntimeMutations_AreThreadSafe(t *testing.T) {
|
||||
// Use a non-nil agent so applyComposedSystemPrompt actually invokes
|
||||
// agent.SetSystemPrompt (a no-op agent is fine — we only need the
|
||||
// systemPrompt mutation + fantasy rebuild path to run concurrently so
|
||||
// -race can observe any unsynchronized writes).
|
||||
k := &Kit{basePrompt: "base", agent: &agent.Agent{}}
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 8
|
||||
const iterations = 50
|
||||
|
||||
for g := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_ = k.AddSkill(&skills.Skill{
|
||||
Name: "skill",
|
||||
Content: "content",
|
||||
})
|
||||
_ = k.AddContextFile(&ContextFile{
|
||||
Path: "/shared/AGENTS.md",
|
||||
Content: "shared",
|
||||
})
|
||||
_ = k.GetSkills()
|
||||
_ = k.GetContextFiles()
|
||||
_ = k.composeSystemPrompt("base")
|
||||
k.RemoveSkill("skill")
|
||||
k.RemoveContextFile("/shared/AGENTS.md")
|
||||
}
|
||||
}(g)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
+138
-1
@@ -139,13 +139,150 @@ func (m *Kit) ClearSkillCache() {
|
||||
}
|
||||
|
||||
// ReloadSkills re-discovers skills from disk, replacing the current set.
|
||||
// This is called by file watchers when skill files change.
|
||||
// This is called by file watchers when skill files change. The system prompt
|
||||
// is recomposed and applied to the running agent so subsequent turns see the
|
||||
// new skill set.
|
||||
func (m *Kit) ReloadSkills() error {
|
||||
newSkills, err := loadSkills(m.opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reloading skills: %w", err)
|
||||
}
|
||||
m.runtimeMu.Lock()
|
||||
m.skills = newSkills
|
||||
m.runtimeMu.Unlock()
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Runtime skill management (Issue #36)
|
||||
// ---------------------------------------------------------------------------
|
||||
//
|
||||
// The methods below let SDK consumers (chatbot hosts, multi-tenant agents)
|
||||
// mutate the active skill set after Kit construction. Each mutation recomposes
|
||||
// the system prompt and applies it to the underlying agent so the LLM sees
|
||||
// the new skill metadata on its next turn.
|
||||
|
||||
// AddSkill registers a single skill on this Kit instance. The skill object
|
||||
// can be built programmatically (no file on disk required) — only Name and
|
||||
// Content are mandatory. If a skill with the same Name is already loaded the
|
||||
// new skill replaces it. Returns an error when skill is nil or has an empty
|
||||
// name.
|
||||
//
|
||||
// After mutation the system prompt is recomposed and applied to the running
|
||||
// agent so the next turn sees the updated skill metadata. AddSkill is safe to
|
||||
// call from any goroutine.
|
||||
func (m *Kit) AddSkill(skill *Skill) error {
|
||||
if skill == nil {
|
||||
return fmt.Errorf("AddSkill: skill is nil")
|
||||
}
|
||||
if skill.Name == "" {
|
||||
return fmt.Errorf("AddSkill: skill name is required")
|
||||
}
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
replaced := false
|
||||
for i, s := range m.skills {
|
||||
if s.Name == skill.Name {
|
||||
m.skills[i] = skill
|
||||
replaced = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !replaced {
|
||||
m.skills = append(m.skills, skill)
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadAndAddSkill loads a skill from a filesystem path (single .md/.txt file)
|
||||
// and adds it via [Kit.AddSkill]. Returns the loaded skill on success.
|
||||
func (m *Kit) LoadAndAddSkill(path string) (*Skill, error) {
|
||||
s, err := skills.LoadSkill(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("LoadAndAddSkill: %w", err)
|
||||
}
|
||||
if err := m.AddSkill(s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RemoveSkill removes the named skill from this Kit instance and recomposes
|
||||
// the system prompt. Returns true when a skill with that name was found and
|
||||
// removed, false otherwise.
|
||||
func (m *Kit) RemoveSkill(name string) bool {
|
||||
m.runtimeMu.Lock()
|
||||
found := false
|
||||
for i, s := range m.skills {
|
||||
if s.Name == name {
|
||||
m.skills = append(m.skills[:i], m.skills[i+1:]...)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return true
|
||||
}
|
||||
|
||||
// SetSkills replaces the active skill set with the provided slice. Pass nil
|
||||
// or an empty slice to remove all skills. The system prompt is recomposed and
|
||||
// applied. Skills with empty names are rejected and no mutation is performed.
|
||||
func (m *Kit) SetSkills(skillList []*Skill) error {
|
||||
// Validate first so a bad input doesn't partially mutate state.
|
||||
for i, s := range skillList {
|
||||
if s == nil {
|
||||
return fmt.Errorf("SetSkills: skill at index %d is nil", i)
|
||||
}
|
||||
if s.Name == "" {
|
||||
return fmt.Errorf("SetSkills: skill at index %d has empty name", i)
|
||||
}
|
||||
}
|
||||
|
||||
copied := make([]*Skill, len(skillList))
|
||||
copy(copied, skillList)
|
||||
|
||||
m.runtimeMu.Lock()
|
||||
m.skills = copied
|
||||
m.runtimeMu.Unlock()
|
||||
|
||||
m.ClearSkillCache()
|
||||
m.applyComposedSystemPrompt()
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyComposedSystemPrompt recomposes the system prompt from the captured
|
||||
// base prompt + current contextFiles + current skills + date/cwd, and pushes
|
||||
// the result onto the underlying agent. No-op when the agent is unset (i.e.
|
||||
// during construction).
|
||||
func (m *Kit) applyComposedSystemPrompt() {
|
||||
if m.agent == nil {
|
||||
return
|
||||
}
|
||||
m.runtimeMu.RLock()
|
||||
base := m.basePrompt
|
||||
m.runtimeMu.RUnlock()
|
||||
composed := m.composeSystemPrompt(base)
|
||||
m.agent.SetSystemPrompt(composed)
|
||||
}
|
||||
|
||||
// RefreshSystemPrompt manually recomposes the system prompt from the current
|
||||
// skills and context files and applies it to the agent. Call this after a
|
||||
// batch of low-level mutations or to force a re-render of the date/cwd
|
||||
// section. Most callers don't need to invoke this directly because
|
||||
// AddSkill, RemoveSkill, SetSkills, AddContextFile, RemoveContextFile, and
|
||||
// SetContextFiles all refresh automatically.
|
||||
func (m *Kit) RefreshSystemPrompt() {
|
||||
m.applyComposedSystemPrompt()
|
||||
}
|
||||
|
||||
+145
-18
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
@@ -75,25 +76,151 @@ type Config = config.Config
|
||||
// local (stdio) and remote (StreamableHTTP/SSE) server types.
|
||||
type MCPServerConfig = config.MCPServerConfig
|
||||
|
||||
// ==== Agent Types (internal/agent/) ====
|
||||
// ==== Agent Types ====
|
||||
|
||||
// AgentConfig holds configuration options for creating a new Agent.
|
||||
type AgentConfig = agent.AgentConfig
|
||||
// DebugLogger is an SDK-owned interface for low-level debug logging from
|
||||
// the engine and MCP tool plumbing. Implementations must be safe for
|
||||
// concurrent use.
|
||||
//
|
||||
// Most consumers do not need to provide one; pass [Options.Debug] = true
|
||||
// to use the default logger. DebugLogger is exposed for the low-level
|
||||
// [AgentConfig] path and for embedders that want to route debug output
|
||||
// into their own logging system.
|
||||
type DebugLogger interface {
|
||||
// LogDebug records a single debug message. Implementations may drop,
|
||||
// buffer, or render the message however they choose.
|
||||
LogDebug(message string)
|
||||
// IsDebugEnabled reports whether debug logging is active. Callers may
|
||||
// check this before doing expensive formatting work.
|
||||
IsDebugEnabled() bool
|
||||
}
|
||||
|
||||
type (
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
ToolCallHandler = agent.ToolCallHandler
|
||||
// ToolExecutionHandler is a function type for handling tool execution start/end events.
|
||||
ToolExecutionHandler = agent.ToolExecutionHandler
|
||||
// ToolResultHandler is a function type for handling tool results.
|
||||
ToolResultHandler = agent.ToolResultHandler
|
||||
// ResponseHandler is a function type for handling LLM responses.
|
||||
ResponseHandler = agent.ResponseHandler
|
||||
// StreamingResponseHandler is a function type for handling streaming LLM responses.
|
||||
StreamingResponseHandler = agent.StreamingResponseHandler
|
||||
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
|
||||
ToolCallContentHandler = agent.ToolCallContentHandler
|
||||
)
|
||||
// AgentConfig holds configuration options for constructing an agent at the
|
||||
// SDK boundary. All fields use SDK-owned types, so consumers can populate
|
||||
// this struct without importing any underlying LLM-provider package.
|
||||
//
|
||||
// For most use cases, prefer the high-level [New] entry point with
|
||||
// [Options]. AgentConfig is exposed for advanced consumers that need
|
||||
// direct access to the lower-level agent configuration shape.
|
||||
type AgentConfig struct {
|
||||
// ModelConfig holds the LLM provider configuration. A nil value means
|
||||
// that the default provider/model resolution will be used.
|
||||
ModelConfig *ProviderConfig
|
||||
|
||||
// MCPConfig describes any MCP servers whose tools should be loaded
|
||||
// alongside core tools.
|
||||
MCPConfig *Config
|
||||
|
||||
// SystemPrompt is the system prompt sent to the LLM.
|
||||
SystemPrompt string
|
||||
|
||||
// MaxSteps caps the number of LLM iterations per turn. A value of
|
||||
// zero means no cap is applied at this layer.
|
||||
MaxSteps int
|
||||
|
||||
// StreamingEnabled controls whether the agent streams responses.
|
||||
StreamingEnabled bool
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When nil, remote MCP servers requiring OAuth will fail to connect.
|
||||
AuthHandler MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory MCPTokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, [AllTools]
|
||||
// is used. Provide a custom tool set (e.g. [CodingTools] or tools
|
||||
// built with a custom WorkDir) to scope agent capabilities.
|
||||
CoreTools []Tool
|
||||
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// Combined with empty CoreTools this yields a chat-only agent with
|
||||
// no built-in tools.
|
||||
DisableCoreTools bool
|
||||
|
||||
// ExtraTools are additional tools loaded alongside core and MCP tools.
|
||||
ExtraTools []Tool
|
||||
|
||||
// ToolWrapper, if non-nil, wraps the combined tool list before it is
|
||||
// handed to the LLM. Used to intercept tool calls or results.
|
||||
ToolWrapper func([]Tool) []Tool
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is invoked once for each MCP server
|
||||
// when its tools have finished loading (or failed). Called from a
|
||||
// background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// DebugLogger receives low-level debug output from the engine and the
|
||||
// MCP tool plumbing. Nil means no debug output is emitted at this
|
||||
// layer (regardless of [Options.Debug], which feeds the higher-level
|
||||
// [New] entry point). Pass an implementation here when wiring a custom
|
||||
// logger through the lower-level AgentConfig path.
|
||||
DebugLogger DebugLogger
|
||||
|
||||
// MCPTaskConfig configures task-aware MCP tools/call execution — mode
|
||||
// overrides, polling intervals, timeouts, and the progress handler.
|
||||
// The zero value preserves historical synchronous-only behaviour for
|
||||
// any server that didn't advertise task support during initialize.
|
||||
MCPTaskConfig MCPTaskConfig
|
||||
}
|
||||
|
||||
// toInternal converts an AgentConfig to its internal representation.
|
||||
// Slice and function fields convert without allocation because [Tool]
|
||||
// is a type alias for the underlying LLM-tool type.
|
||||
func (c *AgentConfig) toInternal() *agent.AgentConfig {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
out := &agent.AgentConfig{
|
||||
ModelConfig: c.ModelConfig,
|
||||
MCPConfig: c.MCPConfig,
|
||||
SystemPrompt: c.SystemPrompt,
|
||||
MaxSteps: c.MaxSteps,
|
||||
StreamingEnabled: c.StreamingEnabled,
|
||||
CoreTools: c.CoreTools,
|
||||
DisableCoreTools: c.DisableCoreTools,
|
||||
ExtraTools: c.ExtraTools,
|
||||
ToolWrapper: c.ToolWrapper,
|
||||
OnMCPServerLoaded: c.OnMCPServerLoaded,
|
||||
}
|
||||
if c.AuthHandler != nil {
|
||||
out.AuthHandler = c.AuthHandler
|
||||
}
|
||||
if c.TokenStoreFactory != nil {
|
||||
out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory)
|
||||
}
|
||||
if c.DebugLogger != nil {
|
||||
out.DebugLogger = c.DebugLogger
|
||||
}
|
||||
out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig()
|
||||
return out
|
||||
}
|
||||
|
||||
// ToolCallHandler is invoked when the LLM produces a tool call. It receives
|
||||
// the call ID, tool name, and the JSON-encoded input arguments.
|
||||
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
|
||||
|
||||
// ToolExecutionHandler is invoked at the start and end of tool execution.
|
||||
// The isStarting flag distinguishes the two phases.
|
||||
type ToolExecutionHandler func(toolCallID, toolName, toolArgs string, isStarting bool)
|
||||
|
||||
// ToolResultHandler is invoked after a tool finishes executing. The metadata
|
||||
// parameter carries optional structured data (e.g. file-diff info) from the
|
||||
// tool execution, JSON-encoded; it may be empty.
|
||||
type ToolResultHandler func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
|
||||
|
||||
// ResponseHandler is invoked with the final assistant text for each turn.
|
||||
type ResponseHandler func(content string)
|
||||
|
||||
// StreamingResponseHandler is invoked with each streamed text delta as it
|
||||
// arrives from the LLM.
|
||||
type StreamingResponseHandler func(content string)
|
||||
|
||||
// ToolCallContentHandler is invoked with any assistant text that accompanies
|
||||
// a tool call within the same step.
|
||||
type ToolCallContentHandler func(content string)
|
||||
|
||||
// ==== Provider & Model Types (internal/models/) ====
|
||||
|
||||
@@ -126,7 +253,7 @@ type ModelsRegistry = models.ModelsRegistry
|
||||
|
||||
// SpinnerFunc wraps a function in a loading spinner animation. Used for
|
||||
// Ollama model loading. Signature: func(fn func() error) error.
|
||||
type SpinnerFunc = agent.SpinnerFunc
|
||||
type SpinnerFunc func(fn func() error) error
|
||||
|
||||
// ==== LLM Types ====
|
||||
//
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
@@ -263,6 +264,101 @@ func TestConvertFromLLMMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentConfigNoFantasyImport verifies AgentConfig can be populated with
|
||||
// every field — including CoreTools, ExtraTools, and ToolWrapper — using
|
||||
// only SDK-owned types. This test deliberately does not import
|
||||
// "charm.land/fantasy"; the package compiling at all is the proof that the
|
||||
// SDK no longer leaks the dependency name through AgentConfig.
|
||||
//
|
||||
// Regression test for https://github.com/mark3labs/kit/issues/30.
|
||||
func TestAgentConfigNoFantasyImport(t *testing.T) {
|
||||
myTool := kit.NewTool[struct{}]("noop", "does nothing", func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("ok"), nil
|
||||
})
|
||||
|
||||
wrapperCalled := false
|
||||
cfg := kit.AgentConfig{
|
||||
SystemPrompt: "you are a tester",
|
||||
MaxSteps: 5,
|
||||
StreamingEnabled: true,
|
||||
CoreTools: []kit.Tool{myTool},
|
||||
ExtraTools: []kit.Tool{myTool},
|
||||
DisableCoreTools: false,
|
||||
ToolWrapper: func(in []kit.Tool) []kit.Tool {
|
||||
wrapperCalled = true
|
||||
return in
|
||||
},
|
||||
OnMCPServerLoaded: func(_ string, _ int, _ error) {},
|
||||
}
|
||||
|
||||
if cfg.SystemPrompt != "you are a tester" {
|
||||
t.Errorf("SystemPrompt = %q, want %q", cfg.SystemPrompt, "you are a tester")
|
||||
}
|
||||
if cfg.MaxSteps != 5 {
|
||||
t.Errorf("MaxSteps = %d, want 5", cfg.MaxSteps)
|
||||
}
|
||||
if !cfg.StreamingEnabled {
|
||||
t.Error("StreamingEnabled = false, want true")
|
||||
}
|
||||
if len(cfg.CoreTools) != 1 {
|
||||
t.Errorf("CoreTools len = %d, want 1", len(cfg.CoreTools))
|
||||
}
|
||||
if len(cfg.ExtraTools) != 1 {
|
||||
t.Errorf("ExtraTools len = %d, want 1", len(cfg.ExtraTools))
|
||||
}
|
||||
|
||||
// Exercise the wrapper to confirm the func type is usable.
|
||||
out := cfg.ToolWrapper(cfg.CoreTools)
|
||||
if !wrapperCalled {
|
||||
t.Error("ToolWrapper was not invoked")
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Errorf("wrapped tool list len = %d, want 1", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentConfigToolWrapperSignature documents that AgentConfig.ToolWrapper
|
||||
// uses kit.Tool (not the underlying provider type) in its signature.
|
||||
func TestAgentConfigToolWrapperSignature(t *testing.T) {
|
||||
//nolint:staticcheck // QF1011: explicit type asserts the SDK-side func signature.
|
||||
var _ func([]kit.Tool) []kit.Tool = func(in []kit.Tool) []kit.Tool { return in }
|
||||
cfg := kit.AgentConfig{
|
||||
ToolWrapper: func(in []kit.Tool) []kit.Tool { return in },
|
||||
}
|
||||
if cfg.ToolWrapper == nil {
|
||||
t.Fatal("ToolWrapper assignment failed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpinnerFuncSignature verifies SpinnerFunc has the documented signature
|
||||
// and can be constructed without importing any provider package.
|
||||
func TestSpinnerFuncSignature(t *testing.T) {
|
||||
called := false
|
||||
var sp kit.SpinnerFunc = func(fn func() error) error {
|
||||
called = true
|
||||
return fn()
|
||||
}
|
||||
err := sp(func() error { return nil })
|
||||
if err != nil {
|
||||
t.Errorf("SpinnerFunc returned err: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("SpinnerFunc did not invoke fn")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerTypesSignatures verifies the SDK-owned handler function types
|
||||
// can be assigned from plain function literals using only standard library
|
||||
// types in their signatures (no provider-package import required).
|
||||
func TestHandlerTypesSignatures(t *testing.T) {
|
||||
var _ kit.ToolCallHandler = func(_, _, _ string) {}
|
||||
var _ kit.ToolExecutionHandler = func(_, _, _ string, _ bool) {}
|
||||
var _ kit.ToolResultHandler = func(_, _, _, _, _ string, _ bool) {}
|
||||
var _ kit.ResponseHandler = func(_ string) {}
|
||||
var _ kit.StreamingResponseHandler = func(_ string) {}
|
||||
var _ kit.ToolCallContentHandler = func(_ string) {}
|
||||
}
|
||||
|
||||
// containsStr is a tiny helper to avoid importing strings in test.
|
||||
func containsStr(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && indexStr(s, substr) >= 0)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# Specs
|
||||
|
||||
| Spec | Status | Description |
|
||||
|------|--------|-------------|
|
||||
| [unified-bubbletea-architecture](unified-bubbletea-architecture.md) | Draft | Replace micro-program pattern with single Bubble Tea program + thick app layer |
|
||||
@@ -88,6 +88,11 @@ mcpServers:
|
||||
type: remote
|
||||
url: "https://pubmed.mcp.example.com"
|
||||
noOAuth: true # skip OAuth for public servers
|
||||
|
||||
builds:
|
||||
type: remote
|
||||
url: "https://builds.mcp.example.com"
|
||||
tasksMode: always # always run tools/call as async tasks (Phase 1 MVP)
|
||||
```
|
||||
|
||||
### MCP server fields
|
||||
@@ -101,9 +106,34 @@ mcpServers:
|
||||
| `allowedTools` | list | Whitelist of tool names to expose |
|
||||
| `excludedTools` | list | Blacklist of tool names to hide |
|
||||
| `noOAuth` | bool | Skip OAuth for this server (for public servers that don't require auth) |
|
||||
| `tasksMode` | string | When to augment `tools/call` with MCP task metadata: `auto` (default — only when the server advertises task support), `never`, or `always`. See [MCP tasks](#mcp-tasks-long-running-tools). |
|
||||
|
||||
A legacy format with `transport`, `args`, `env`, and `headers` fields is also supported.
|
||||
|
||||
### MCP tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize` so servers can respond to `tools/call` with a
|
||||
`CreateTaskResult` (a task ID + `working` status) instead of blocking until
|
||||
the operation finishes. Kit then polls `tasks/get` / `tasks/result` until the
|
||||
task reaches a terminal state, and best-effort `tasks/cancel`s on context
|
||||
cancellation.
|
||||
|
||||
This avoids HTTP/SSE proxy timeouts on long builds, deploys, and batch jobs,
|
||||
and lets the user/agent abort cleanly with Ctrl-C.
|
||||
|
||||
**Per-server `tasksMode`:**
|
||||
|
||||
| Value | Behaviour |
|
||||
|-------|-----------|
|
||||
| `auto` (default) | Augment `tools/call` with task metadata only when the server advertised `tasks/toolCalls` capability. Servers that don't advertise it run synchronously, exactly as before. |
|
||||
| `never` | Always issue `tools/call` synchronously, regardless of server capability. |
|
||||
| `always` | Always opt into task augmentation, even when the server didn't advertise the capability. The server may still respond synchronously — this just expresses client intent unconditionally. |
|
||||
|
||||
Defaults are safe: any existing MCP server keeps its previous behaviour
|
||||
bit-for-bit. SDK consumers can also override the mode programmatically and
|
||||
plug in a progress callback — see [SDK options](/sdk/options#mcp-tasks).
|
||||
|
||||
## Custom models
|
||||
|
||||
Define custom models in your `.kit.yml` for use with the `custom` provider. This is useful for self-hosted models or API endpoints not in the built-in database:
|
||||
|
||||
@@ -160,6 +160,14 @@ when embedding Kit as a library.
|
||||
| `SkillsDir` | `string` | — | Override default skills directory |
|
||||
| `NoSkills` | `bool` | `false` | Disable skill loading entirely |
|
||||
|
||||
These fields only control the **initial** skill and context-file set picked
|
||||
up by `New()`. To add, remove, or replace skills and `AGENTS.md`-style
|
||||
context files at runtime (e.g. per user or per session), use the
|
||||
`AddSkill` / `LoadAndAddSkill` / `RemoveSkill` / `SetSkills` and
|
||||
`AddContextFile` / `AddContextFileContent` / `RemoveContextFile` /
|
||||
`SetContextFiles` methods on `*kit.Kit`. See
|
||||
[Runtime skills and context files](/sdk/overview#runtime-skills-and-context-files).
|
||||
|
||||
### Compaction & MCP
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
@@ -169,6 +177,12 @@ when embedding Kit as a library.
|
||||
| `MCPAuthHandler` | `MCPAuthHandler` | — | OAuth handler for remote MCP servers. `nil` disables OAuth (servers returning 401 fail with the authorization-required error). See [MCP OAuth](#mcp-oauth-authorization) below. |
|
||||
| `MCPTokenStoreFactory` | `func` | — | Custom OAuth token storage for MCP servers (default: JSON file in `$XDG_CONFIG_HOME/.kit/mcp_tokens.json`). |
|
||||
| `InProcessMCPServers` | `map[string]*MCPServer` | — | In-process mcp-go servers (no subprocess) |
|
||||
| `MCPTaskMode` | `map[string]MCPTaskMode` | — | Per-server override for task-augmented `tools/call`. Keys are server names; missing entries fall back to the `tasksMode` field of the matching `MCPServerConfig`. See [MCP Tasks](#mcp-tasks). |
|
||||
| `MCPTaskTimeout` | `time.Duration` | `15m` | Maximum wall-clock to wait for a task to reach a terminal state. Independent of any per-call context deadline. |
|
||||
| `MCPTaskTTL` | `time.Duration` | — | TTL hint sent in `TaskParams` for every task-augmented call. Zero omits the field and lets the server pick. |
|
||||
| `MCPTaskPollInterval` | `time.Duration` | `1s` | Fallback interval between `tasks/get` requests when the server does not suggest one. |
|
||||
| `MCPTaskMaxPollInterval` | `time.Duration` | `5s` | Cap on the polling interval (a server-supplied `pollInterval` can otherwise grow without bound). |
|
||||
| `MCPTaskProgress` | `MCPTaskProgressHandler` | — | Optional callback invoked once when a task is accepted and on every observed status transition. The final invocation always carries a terminal status. |
|
||||
|
||||
## MCP OAuth Authorization
|
||||
|
||||
@@ -248,6 +262,79 @@ authorization URL and hang until the 2-minute callback timeout fires. Always
|
||||
set `OnAuthURL`, or use a higher-level wrapper like `CLIMCPAuthHandler`.
|
||||
:::
|
||||
|
||||
## MCP Tasks
|
||||
|
||||
The [MCP Tasks utility](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
turns a synchronous `tools/call` into a pollable async job: the server
|
||||
returns a `taskId` with status `working` immediately, and the client polls
|
||||
`tasks/get` / `tasks/result` until the task reaches a terminal state.
|
||||
|
||||
Kit advertises task support during `initialize` and, by default, augments
|
||||
`tools/call` with task metadata only when the server advertises
|
||||
`tasks/toolCalls` capability — so any existing MCP server keeps its previous
|
||||
synchronous behaviour bit-for-bit. Long-running tools (builds, deployments,
|
||||
batch jobs, sub-agent runs) get HTTP/SSE timeout-resistance and clean
|
||||
cancellation "for free" once both sides opt in.
|
||||
|
||||
### Per-server mode
|
||||
|
||||
```go
|
||||
import "time"
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways, // force task-augmented calls
|
||||
"chat-server": kit.MCPTaskModeNever, // force synchronous calls
|
||||
// any server not in the map honours its `tasksMode` config field
|
||||
// (default "auto")
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
| Mode | Behaviour |
|
||||
|---|---|
|
||||
| `MCPTaskModeAuto` (default) | Augment `tools/call` with `TaskParams` only when the server advertised `tasks/toolCalls`. |
|
||||
| `MCPTaskModeNever` | Always issue `tools/call` synchronously, ignoring server capability. |
|
||||
| `MCPTaskModeAlways` | Always opt in, even when the server didn't advertise the capability. The server may still respond synchronously. |
|
||||
|
||||
### Progress callbacks
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskTimeout: 15 * time.Minute, // total wall-clock cap
|
||||
MCPTaskTTL: 30 * time.Minute, // server retention hint
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s/%s: %s %s", p.Server, p.TaskID, p.Status, p.Message)
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
The handler fires once when a task is accepted and again on every observed
|
||||
status transition. The final call always carries a terminal status
|
||||
(`MCPTaskStatusCompleted`, `MCPTaskStatusFailed`, or `MCPTaskStatusCancelled`).
|
||||
Do not block in the handler — dispatch long work on a goroutine.
|
||||
|
||||
### Inspecting and cancelling tasks
|
||||
|
||||
```go
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
for _, t := range tasks {
|
||||
fmt.Printf("%s: %s (%s)\n", t.TaskID, t.Status, t.StatusMessage)
|
||||
}
|
||||
|
||||
t, _ := host.GetMCPTask(ctx, "build-server", taskID)
|
||||
if !t.Status.IsTerminal() {
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", taskID)
|
||||
}
|
||||
```
|
||||
|
||||
`Kit.ListMCPTasks`, `Kit.GetMCPTask`, and `Kit.CancelMCPTask` work against any
|
||||
loaded MCP server that advertises the corresponding capability.
|
||||
`MCPTaskStatus.IsTerminal()` is the canonical check for completion.
|
||||
|
||||
Context cancellation also works end-to-end: cancelling the `ctx` passed to a
|
||||
tool execution triggers a best-effort `tasks/cancel` before the call returns.
|
||||
|
||||
## Precedence
|
||||
|
||||
For any given generation or provider field, the effective value is resolved
|
||||
|
||||
@@ -201,6 +201,66 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
n, _ := host.AddInProcessMCPServer(ctx, "docs", mcpSrv)
|
||||
```
|
||||
|
||||
## Runtime skills and context files
|
||||
|
||||
Kit auto-discovers skills and `AGENTS.md`-style context files during `New()`,
|
||||
but multi-tenant hosts (chatbots, web services, per-user agents) often need
|
||||
to swap these **after** construction. The runtime mutators below recompose
|
||||
the system prompt and apply it to the agent so the next turn picks up the
|
||||
updated instructions — no restart, no file shuffling.
|
||||
|
||||
```go
|
||||
// Add a programmatic skill — no file on disk required.
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Project context (AGENTS.md equivalents): inline content from a DB...
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
// ...or load from disk.
|
||||
host.LoadAndAddContextFile("/etc/agents/tenant-acme.md")
|
||||
|
||||
// Remove individually when a session ends.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set in one call.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
|
||||
// Inspect current state (snapshot copies — safe to mutate).
|
||||
skills := host.GetSkills()
|
||||
ctxFiles := host.GetContextFiles()
|
||||
```
|
||||
|
||||
Key points:
|
||||
|
||||
- **Auto-refresh.** Every `Add*` / `Remove*` / `Set*` call recomposes the system
|
||||
prompt against the captured base prompt (preserving per-model overrides and
|
||||
`--system-prompt` resolution) and pushes the result onto the agent. Call
|
||||
`host.RefreshSystemPrompt()` only if you mutate state through a different
|
||||
path and need to force a re-render.
|
||||
- **Dedup keys.** Skills dedupe by `Name`; context files dedupe by `Path`.
|
||||
Re-adding the same key replaces the entry instead of appending a duplicate.
|
||||
- **Path is opaque.** `ContextFile.Path` does not have to point at a real file
|
||||
— it's only used for dedup and for the `Instructions from: <Path>` header
|
||||
injected into the prompt. URIs like `session://user-123/AGENTS.md` work fine.
|
||||
- **Thread safety.** All readers and mutators are safe to call concurrently
|
||||
from multiple goroutines; the underlying state is guarded by an internal
|
||||
`RWMutex`.
|
||||
- **Init-time options still apply.** `Options.Skills`, `Options.SkillsDir`,
|
||||
`Options.NoSkills`, and `Options.NoContextFiles` continue to control the
|
||||
startup set; the runtime API mutates from whatever state `New()` produced.
|
||||
See [SDK options](/sdk/options#skills--configuration).
|
||||
|
||||
## MCP prompts and resources
|
||||
|
||||
Query prompts and resources exposed by connected MCP servers:
|
||||
@@ -215,6 +275,33 @@ resources := host.ListMCPResources()
|
||||
content, _ := host.ReadMCPResource(ctx, "server", "file:///path")
|
||||
```
|
||||
|
||||
## MCP tasks (long-running tools)
|
||||
|
||||
Kit advertises [MCP task support](https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks)
|
||||
during `initialize`, so cooperating servers can return a `taskId` immediately
|
||||
and let Kit poll `tasks/get` / `tasks/result` until the operation completes.
|
||||
This avoids HTTP/SSE proxy timeouts on long tools and gives you clean
|
||||
cancellation via context.
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTaskMode: map[string]kit.MCPTaskMode{
|
||||
"build-server": kit.MCPTaskModeAlways,
|
||||
},
|
||||
MCPTaskProgress: func(p kit.MCPTaskProgress) {
|
||||
log.Printf("%s: %s", p.TaskID, p.Status)
|
||||
},
|
||||
})
|
||||
|
||||
// Inspect / cancel in-flight tasks
|
||||
tasks, _ := host.ListMCPTasks(ctx, "build-server")
|
||||
_, _ = host.CancelMCPTask(ctx, "build-server", tasks[0].TaskID)
|
||||
```
|
||||
|
||||
Defaults to `MCPTaskModeAuto` per server, so any existing MCP server keeps
|
||||
its previous synchronous behaviour. See [SDK options → MCP Tasks](/sdk/options#mcp-tasks)
|
||||
for the full surface.
|
||||
|
||||
## Context and compaction
|
||||
|
||||
Monitor and manage context usage:
|
||||
|
||||
@@ -1968,35 +1968,42 @@ a:hover { text-decoration: underline; }
|
||||
|
||||
function renderEditBody(input, result) {
|
||||
let html = '';
|
||||
const oldText = input.old_text || '';
|
||||
const newText = input.new_text || '';
|
||||
const edits = input.edits || [];
|
||||
|
||||
html += '<div class="tool-input">';
|
||||
html += '<div class="tool-input-label">Changes</div>';
|
||||
html += '<div class="diff-display">';
|
||||
|
||||
// Render old text lines as removed
|
||||
if (oldText) {
|
||||
const oldLines = oldText.split('\n');
|
||||
for (const line of oldLines) {
|
||||
html += `<div class="diff-line removed">- ${escapeHtml(line)}</div>`;
|
||||
for (const edit of edits) {
|
||||
const oldText = edit.old_text || '';
|
||||
const newText = edit.new_text || '';
|
||||
|
||||
html += '<div class="diff-display">';
|
||||
|
||||
// Render old text lines as removed
|
||||
if (oldText) {
|
||||
const oldLines = oldText.split('\n');
|
||||
for (const line of oldLines) {
|
||||
html += `<div class="diff-line removed">- ${escapeHtml(line)}</div>`;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Separator
|
||||
if (oldText && newText) {
|
||||
html += '<div class="diff-separator">───</div>';
|
||||
}
|
||||
|
||||
// Render new text lines as added
|
||||
if (newText) {
|
||||
const newLines = newText.split('\n');
|
||||
for (const line of newLines) {
|
||||
html += `<div class="diff-line added">+ ${escapeHtml(line)}</div>`;
|
||||
// Separator
|
||||
if (oldText && newText) {
|
||||
html += '<div class="diff-separator">───</div>';
|
||||
}
|
||||
|
||||
// Render new text lines as added
|
||||
if (newText) {
|
||||
const newLines = newText.split('\n');
|
||||
for (const line of newLines) {
|
||||
html += `<div class="diff-line added">+ ${escapeHtml(line)}</div>`;
|
||||
}
|
||||
}
|
||||
|
||||
html += '</div>';
|
||||
}
|
||||
|
||||
html += '</div></div>';
|
||||
html += '</div>';
|
||||
|
||||
if (result && result.is_error) {
|
||||
html += renderToolOutput(result.content, true);
|
||||
|
||||
Reference in New Issue
Block a user