mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 394a4676a1 | |||
| 30f2bc243d | |||
| 922e246098 | |||
| 32b6376515 | |||
| cf194ff89a | |||
| 03006425fa | |||
| a322dfc59a | |||
| b1387d837e | |||
| f561f4cfd9 | |||
| 64caed57d4 | |||
| 975c30a773 | |||
| 35b9360d64 | |||
| 1b8373e133 | |||
| 1a5e4ce7c5 | |||
| 8823977612 | |||
| 24e2ea111c | |||
| 31ea80ec4f | |||
| 99f2680c2e | |||
| da7e05eb87 | |||
| a95714a22d | |||
| c4a2b0f1a3 | |||
| 2016570e2d | |||
| d557f4b870 | |||
| 65054fe3db | |||
| 97d2246375 | |||
| 1e12505741 | |||
| 6755597c9b | |||
| 45689cb30d | |||
| 78570d4188 | |||
| 7cf38b37ee |
@@ -1,268 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
diagnosticsTimeout = 20 * time.Second
|
||||
maxOutputBytes = 12_000
|
||||
)
|
||||
|
||||
type toolPathInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type lintResult struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
absPath, ok := resolveGoFilePath(e.Input, ctx.CWD)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
borderColor = "#f38ba8" // red
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: strings.Join(msgLines, "\n"),
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isEditOrWrite(toolName string) bool {
|
||||
return strings.EqualFold(toolName, "edit") || strings.EqualFold(toolName, "write")
|
||||
}
|
||||
|
||||
func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
var args toolPathInput
|
||||
if err := json.Unmarshal([]byte(inputJSON), &args); err != nil || args.Path == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
absPath := args.Path
|
||||
if !filepath.IsAbs(absPath) {
|
||||
absPath = filepath.Join(cwd, absPath)
|
||||
}
|
||||
|
||||
if strings.ToLower(filepath.Ext(absPath)) != ".go" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "gopls", "check", absPath)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes), Err: fmt.Errorf("failed to run gopls check: %w", err)}
|
||||
}
|
||||
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes)}
|
||||
}
|
||||
|
||||
func runGolangCILint(cwd, target string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []string{
|
||||
"run",
|
||||
target,
|
||||
"--show-stats=false",
|
||||
"--output.text.path", "stdout",
|
||||
"--output.text.colors=false",
|
||||
"--output.text.print-issued-lines=false",
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, "golangci-lint", args...)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
trimmed := truncate(string(out), maxOutputBytes)
|
||||
if err == nil {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
exitErr, ok := err.(*exec.ExitError)
|
||||
if ok && exitErr.ExitCode() == 1 {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
return lintResult{Output: trimmed, Err: fmt.Errorf("failed to run golangci-lint: %w", err)}
|
||||
}
|
||||
|
||||
func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
var lines []string
|
||||
if res.Err != nil {
|
||||
lines = append(lines, "ERROR: "+res.Err.Error())
|
||||
}
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return emptyFallback
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "\n... output truncated ..."
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
lintCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return goplsCount, lintCount
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
---
|
||||
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
|
||||
---
|
||||
|
||||
Perform a comprehensive **read-only** audit of this repository and report
|
||||
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
|
||||
hints from the user: $@
|
||||
|
||||
## Scope
|
||||
|
||||
If the user supplied focus hints above (a package path, a subsystem name, a
|
||||
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
|
||||
audit the whole repo, prioritising the highest-traffic packages first
|
||||
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Map the repo first**:
|
||||
- `ls` / `find` the top-level layout and list every Go package
|
||||
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
|
||||
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
|
||||
extension surface)
|
||||
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
|
||||
(e.g. "no dependency name leakage", "UI never imports extensions
|
||||
directly") — these define what counts as a violation
|
||||
|
||||
2. **Hunt for dead code**:
|
||||
- Run `go vet ./...` and capture warnings
|
||||
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
|
||||
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
|
||||
with zero non-test references inside the module are suspects
|
||||
- Check for unreferenced files, `// TODO: remove` markers, commented-out
|
||||
blocks, and `_ = x` discard patterns
|
||||
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
|
||||
them and include their output verbatim
|
||||
- **Do not delete anything** — list candidates with file:line and a
|
||||
confidence level (high / medium / low)
|
||||
|
||||
3. **Find unnecessary duplication**:
|
||||
- Look for near-identical function bodies, struct shapes, or switch
|
||||
statements across packages — `grep` for repeated function signatures
|
||||
and copy-pasted string literals / error messages is a fast first pass
|
||||
- Distinguish *coincidental* duplication (two things that happen to look
|
||||
alike but evolve independently) from *unnecessary* duplication (same
|
||||
intent, drifting in lockstep) — only flag the latter
|
||||
- For each cluster, propose where the extracted helper should live
|
||||
(which package, which file) and whether it crosses a boundary
|
||||
|
||||
4. **Check concerns / boundary violations**:
|
||||
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
|
||||
in exported signatures, and for dependency-name leakage in exported
|
||||
names / godoc (e.g. library jargon appearing in `LLM*` types)
|
||||
- **UI ↔ extensions**: grep `internal/ui/` for any import of
|
||||
`internal/extensions/` — per AGENTS.md the UI must not import
|
||||
extensions directly; converters in `cmd/root.go` should bridge them
|
||||
- **cmd vs internal**: business logic living in `cmd/` that should be
|
||||
in `internal/` (and vice versa)
|
||||
- **Cyclic risk**: packages that import each other transitively or that
|
||||
reach across sibling boundaries unexpectedly
|
||||
- For each violation, cite the offending import / signature with
|
||||
file:line
|
||||
|
||||
5. **Spot refactor opportunities**:
|
||||
- Long functions (>80 lines) doing multiple unrelated things
|
||||
- Deeply nested conditionals that flatten well with early returns
|
||||
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
|
||||
that could become helpers — but only where the wrapping context is
|
||||
genuinely uniform
|
||||
- Structs with too many fields that hint at split responsibilities
|
||||
- Exported APIs that would be cleaner with options structs / functional
|
||||
options
|
||||
- Tests that share setup boilerplate ripe for a helper
|
||||
- Flag each with: location, current shape (1-2 lines), proposed shape
|
||||
(1-2 lines), and estimated risk (low / medium / high)
|
||||
|
||||
6. **Cross-check against project rules**:
|
||||
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
|
||||
findings contradicts the documented gotchas (Yaegi interface ban,
|
||||
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
|
||||
"refactor" would reintroduce a known pitfall, drop it from the report
|
||||
and note why
|
||||
|
||||
7. **Write the report** as your final message (do not write it to disk)
|
||||
structured as:
|
||||
|
||||
```
|
||||
# Code Audit Report
|
||||
|
||||
## Summary
|
||||
- N dead-code candidates
|
||||
- N duplication clusters
|
||||
- N boundary violations
|
||||
- N refactor opportunities
|
||||
|
||||
## Dead Code
|
||||
### High confidence
|
||||
- path/to/file.go:LINE — symbol — reason
|
||||
|
||||
### Medium confidence
|
||||
...
|
||||
|
||||
## Duplication
|
||||
### Cluster: <short name>
|
||||
- Sites: file:line, file:line, …
|
||||
- Suggested home: package/path
|
||||
- Notes: …
|
||||
|
||||
## Boundary Violations
|
||||
- Rule: <which rule from AGENTS.md / project convention>
|
||||
- Offender: file:line
|
||||
- Fix sketch: …
|
||||
|
||||
## Refactor Opportunities
|
||||
- Location: file:line
|
||||
- Current: …
|
||||
- Proposed: …
|
||||
- Risk: low/medium/high
|
||||
- Why it's worth it: …
|
||||
|
||||
## Suggested Next Steps
|
||||
1. …
|
||||
2. …
|
||||
```
|
||||
|
||||
8. **End the report with an explicit reminder** that no files were modified,
|
||||
and recommend the user pick the highest-leverage items to act on
|
||||
manually (or via a follow-up `/fix-issue` style prompt) rather than
|
||||
running a sweeping refactor.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
|
||||
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
|
||||
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
|
||||
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
|
||||
straight to it
|
||||
- **Be honest about confidence**: false positives in a code audit are
|
||||
expensive — prefer "medium confidence, worth a look" over confidently
|
||||
wrong claims
|
||||
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
|
||||
anything that's purely stylistic unless it directly causes one of the
|
||||
four issue categories above
|
||||
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
|
||||
`vendor/`) and obvious third-party copies
|
||||
- **Don't propose architectural rewrites** — stay within the existing
|
||||
shape of the repo and recommend incremental, reviewable changes
|
||||
@@ -0,0 +1,473 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// extensionContextDeps groups the runtime dependencies needed to wire up
|
||||
// an extensions.Context for the interactive TUI mode.
|
||||
type extensionContextDeps struct {
|
||||
ctx context.Context
|
||||
cwd string
|
||||
modelName string
|
||||
interactive bool
|
||||
kitInstance *kit.Kit
|
||||
appInstance *app.App
|
||||
usageTracker *ui.UsageTracker
|
||||
}
|
||||
|
||||
// buildInteractiveExtensionContext returns an extensions.Context with every
|
||||
// field except Print / PrintInfo / PrintError populated. Callers must set
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// This consolidates two near-identical 400-line literal expressions that
|
||||
// previously appeared inline in runNormalMode.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
ctx := deps.ctx
|
||||
|
||||
return extensions.Context{
|
||||
CWD: deps.cwd,
|
||||
Model: deps.modelName,
|
||||
Interactive: deps.interactive,
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
},
|
||||
SendMessage: func(text string) { appInstance.Run(text) },
|
||||
CancelAndSend: func(text string) { appInstance.InterruptAndSend(text) },
|
||||
Abort: func() { appInstance.Abort() },
|
||||
IsIdle: func() bool { return !appInstance.IsBusy() },
|
||||
Compact: func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
},
|
||||
SendMultimodalMessage: func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
},
|
||||
GetSessionUsage: func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
},
|
||||
Exit: func() { appInstance.QuitFromExtension() },
|
||||
SetWidget: func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveWidget: func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetHeader: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveHeader: func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
SetFooter: func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveFooter: func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
PromptSelect: func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
},
|
||||
PromptConfirm: func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
},
|
||||
PromptInput: func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
},
|
||||
SetUIVisibility: func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
SetEditor: func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send()
|
||||
// deadlocks if called synchronously from inside BubbleTea's
|
||||
// Update() handler. All call sites use go-routines uniformly.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
ResetEditor: func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
SetEditorText: func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
},
|
||||
SetStatus: func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
RemoveStatus: func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
},
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
SetModel: func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
if usageTracker != nil {
|
||||
newProvider, newModel, _ := models.ParseModelString(modelString)
|
||||
if newProvider != "unknown" && newModel != "unknown" && newProvider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(newProvider, newModel); modelInfo != nil {
|
||||
// Check OAuth status for Anthropic models
|
||||
isOAuth := false
|
||||
if newProvider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(viper.GetString("provider-api-key"))
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
usageTracker.UpdateModelInfo(modelInfo, newProvider, isOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SuspendTUI: func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
},
|
||||
RenderMessage: func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
},
|
||||
ReloadExtensions: func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
RegisterTheme: func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
},
|
||||
SetTheme: func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
},
|
||||
ListThemes: func() []string {
|
||||
return ui.ListThemes()
|
||||
},
|
||||
ShowOverlay: func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return extbridge.SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
InjectSkillAsContext: func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
},
|
||||
InjectRawSkillAsContext: func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
+107
-804
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@ import (
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -26,7 +26,7 @@ func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -84,7 +84,7 @@ func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -134,7 +134,7 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -186,7 +186,7 @@ func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -152,38 +153,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
return extbridge.SpawnSubagent(context.Background(), kitInstance, config)
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
@@ -269,40 +239,3 @@ func (s *acpSession) clearCancel() {
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,12 +9,19 @@ import (
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpExecutor is the subset of *tools.MCPToolManager that the adapter
|
||||
// actually uses. Extracted as an interface so the adapter is unit-testable
|
||||
// without constructing a full manager + connection pool.
|
||||
type mcpExecutor interface {
|
||||
ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error)
|
||||
}
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
manager *tools.MCPToolManager
|
||||
exec mcpExecutor
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
@@ -29,10 +36,26 @@ func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
//
|
||||
// MCP-side failures (JSON-RPC protocol errors, transport failures, schema
|
||||
// validation rejections from the server) are surfaced to the model as soft
|
||||
// tool errors rather than escalated to a critical agent error. This matches
|
||||
// the contract that native Kit tools follow via kit.ErrorResult(...) and
|
||||
// lets the model self-correct (e.g. retry with a fixed argument shape) or
|
||||
// give up gracefully rather than aborting the turn mid-run.
|
||||
//
|
||||
// Context cancellation is the one exception: if the caller cancelled the
|
||||
// context the turn was aborted intentionally, so we propagate the ctx error
|
||||
// to let the agent loop unwind cleanly.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
result, err := t.exec.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return fantasy.ToolResponse{}, ctxErr
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("MCP tool %q failed: %s", t.tool.Name, err.Error()),
|
||||
), nil
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
@@ -57,8 +80,8 @@ func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManage
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
manager: manager,
|
||||
tool: t,
|
||||
exec: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// stubExecutor lets each test script the (result, err) pair returned by
|
||||
// ExecuteTool. The adapter holds an mcpExecutor interface, so this is the
|
||||
// only seam the tests need.
|
||||
type stubExecutor struct {
|
||||
result *tools.MCPToolResult
|
||||
err error
|
||||
// called records the last invocation for assertion.
|
||||
called bool
|
||||
name string
|
||||
input string
|
||||
}
|
||||
|
||||
func (s *stubExecutor) ExecuteTool(_ context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error) {
|
||||
s.called = true
|
||||
s.name = prefixedName
|
||||
s.input = inputJSON
|
||||
return s.result, s.err
|
||||
}
|
||||
|
||||
func newMCPAgentTool(exec mcpExecutor, name string) *mcpAgentTool {
|
||||
return &mcpAgentTool{
|
||||
tool: tools.MCPTool{Name: name},
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager-side Go errors (JSON-RPC protocol errors, transport failures,
|
||||
// schema validation rejections from the MCP server) must be surfaced to
|
||||
// the model as soft tool errors so the agent loop can keep going. Aborting
|
||||
// the turn would discard all prior tool results — see issue #N.
|
||||
func TestMCPAgentTool_RPCErrorBecomesSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
err: errors.New("MCP error -32602: Invalid params: missing field \"task\""),
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "pubmed__search",
|
||||
Input: `{"query":"foo"}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft), got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "pubmed__search") {
|
||||
t.Errorf("expected tool name in error content, got %q", resp.Content)
|
||||
}
|
||||
if !strings.Contains(resp.Content, "-32602") {
|
||||
t.Errorf("expected underlying error text in content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Context cancellation is the one error that must remain critical: it
|
||||
// means the caller intentionally aborted, and the agent loop needs to
|
||||
// unwind cleanly rather than burning more steps.
|
||||
func TestMCPAgentTool_CtxCancelStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
// Real managers typically return ctx.Err() (or a wrapper) when the
|
||||
// context is cancelled mid-call.
|
||||
err: context.Canceled,
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Deadline-exceeded behaves the same as cancellation: ctx.Err() is
|
||||
// non-nil, so the adapter must propagate the critical error rather than
|
||||
// converting the executor's error into a soft response.
|
||||
func TestMCPAgentTool_CtxDeadlineStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{err: context.DeadlineExceeded}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side soft errors (CallToolResult{ isError: true }) must continue
|
||||
// to flow through as soft errors — this was the existing behavior and
|
||||
// must not regress.
|
||||
func TestMCPAgentTool_ServerIsErrorRemainsSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: true,
|
||||
Content: "search service is rate limited; try again in 30s",
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if resp.Content != "search service is rate limited; try again in 30s" {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Happy path: ordinary successful tool result is passed through unchanged.
|
||||
func TestMCPAgentTool_SuccessIsPassthrough(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: false,
|
||||
Content: `{"hits":3}`,
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected IsError=false")
|
||||
}
|
||||
if resp.Content != `{"hits":3}` {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
+132
-33
@@ -70,14 +70,24 @@ type App struct {
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
// widgetUpdatePending is set to true while a WidgetUpdateEvent burst is
|
||||
// being coalesced. The leading edge fires immediately; subsequent calls
|
||||
// within the debounce window set widgetUpdateTrailing so a final event
|
||||
// is delivered with the latest runner state at the end of the window.
|
||||
// Without the trailing send, a rapid SetWidget→RemoveWidget pair (e.g.
|
||||
// SubagentEnd pushing a final frame then removing the widget) would let
|
||||
// the second call get silently dropped, leaving the TUI's layout stuck
|
||||
// on the pre-removal widget height — visible as empty rows below the
|
||||
// status bar after the widget disappears.
|
||||
widgetUpdatePending atomic.Bool
|
||||
widgetUpdateTrailing atomic.Bool
|
||||
|
||||
// steerDrainFn is the test seam used by releaseBusyAfterCompact to pull
|
||||
// any steer messages that arrived during compaction. In production it is
|
||||
// nil and the helper falls back to a.opts.Kit.DrainSteer(); tests that
|
||||
// need to exercise the steer-drain path without standing up a full
|
||||
// *kit.Kit can set this field directly to inject fake items.
|
||||
steerDrainFn func() []queueItem
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
@@ -356,6 +366,10 @@ func (a *App) AddContextMessage(text string) {
|
||||
// tea.Program. customInstructions is optional text appended to the summary
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
//
|
||||
// Any prompts queued via Run/RunWithFiles or steering messages injected via
|
||||
// Steer/SteerWithFiles while compaction is running are flushed automatically
|
||||
// once compaction completes (see releaseBusyAfterCompact).
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Lock()
|
||||
@@ -377,11 +391,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -420,6 +430,9 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
//
|
||||
// Like CompactConversation, any prompts/steer messages received during
|
||||
// compaction are flushed automatically once compaction finishes.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
@@ -440,11 +453,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -489,6 +498,81 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseBusyAfterCompact is the deferred tail that runs at the end of every
|
||||
// compaction goroutine (success, error, or panic-after-recover paths). It
|
||||
// flips a.busy back to false, but before doing so it checks whether any
|
||||
// prompts piled up while compaction was running:
|
||||
//
|
||||
// - Run/RunWithFiles append to a.queue when a.busy is set.
|
||||
// - Steer/SteerWithFiles deposit messages into the SDK steer channel via
|
||||
// Kit.InjectSteerWithFiles when a.busy is set.
|
||||
//
|
||||
// Without this hand-off the queue would sit idle until the user submits
|
||||
// another prompt — see issue #27. If we find anything pending we keep busy
|
||||
// set, splice the steer messages to the front of the queue, and start a
|
||||
// fresh drainQueue goroutine to deliver them as a single batched turn.
|
||||
func (a *App) releaseBusyAfterCompact() {
|
||||
// Pull steer messages outside the app mutex; DrainSteer takes its own
|
||||
// internal lock and we don't want to nest the two. The test seam
|
||||
// (a.steerDrainFn) takes precedence so unit tests can inject fake
|
||||
// steer items without a real *kit.Kit.
|
||||
var steerItems []queueItem
|
||||
switch {
|
||||
case a.steerDrainFn != nil:
|
||||
steerItems = a.steerDrainFn()
|
||||
case a.opts.Kit != nil:
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
steerItems = make([]queueItem, len(leftover))
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
// If the app was closed while compaction was running, drop everything
|
||||
// and just clear busy. Run/Steer would have rejected new items already
|
||||
// after Close(), but this guards against in-flight items that slipped
|
||||
// in just before closed was set.
|
||||
if a.closed {
|
||||
a.queue = a.queue[:0]
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Combine steer-channel items (front) with the in-memory queue (back).
|
||||
// Steer messages are placed first so they retain their "act now"
|
||||
// semantics relative to ordinary queued prompts that arrived later.
|
||||
pending := append(steerItems, a.queue...)
|
||||
a.queue = a.queue[:0]
|
||||
|
||||
if len(pending) == 0 {
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hand off to drainQueue: it will pick up the first item directly and
|
||||
// scoop the rest from a.queue on its first iteration.
|
||||
first := pending[0]
|
||||
if len(pending) > 1 {
|
||||
a.queue = append(a.queue, pending[1:]...)
|
||||
}
|
||||
// Stay busy across the goroutine swap.
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Notify the UI that steer-channel messages were consumed so the
|
||||
// steering badge can clear; ordinary queued prompts will be reflected
|
||||
// by the QueueUpdatedEvent that drainQueue emits as it picks them up.
|
||||
if len(steerItems) > 0 {
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
|
||||
go a.drainQueue(first)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1076,32 +1160,47 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
// Coalescing (leading + trailing edge): the first call in an idle period
|
||||
// fires immediately for responsiveness. Subsequent calls within a ~16 ms
|
||||
// debounce window are batched into a single trailing event delivered at
|
||||
// the end of the window. The trailing send is essential for correctness:
|
||||
// extensions routinely make tight SetWidget→RemoveWidget pairs (e.g. on
|
||||
// SubagentEnd) and silently dropping the second call would leave the TUI's
|
||||
// layout stuck on stale widget dimensions until some other event happens
|
||||
// to trigger a re-render.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
// A leading-edge event is already in flight — mark that the runner
|
||||
// state has changed again so the trailing send below picks it up.
|
||||
a.widgetUpdateTrailing.Store(true)
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
if prog == nil {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
return
|
||||
}
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
// If any extra calls came in during the debounce window, deliver
|
||||
// one trailing event so the TUI sees the latest widget state. We
|
||||
// swap-and-test instead of plain-load so concurrent calls after
|
||||
// the trailing send still race correctly with the pending reset.
|
||||
if a.widgetUpdateTrailing.Swap(false) {
|
||||
a.mu.Lock()
|
||||
p := a.program
|
||||
a.mu.Unlock()
|
||||
if p != nil {
|
||||
p.Send(WidgetUpdateEvent{})
|
||||
}
|
||||
}
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
|
||||
@@ -763,3 +763,209 @@ func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// releaseBusyAfterCompact (issue #27)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestReleaseBusyAfterCompact_flushesQueuedMessages is a regression test for
|
||||
// issue #27: messages queued via Run() while /compact is running used to sit
|
||||
// in a.queue indefinitely until the user typed another prompt. After the fix
|
||||
// the deferred releaseBusyAfterCompact tail picks up any pending items and
|
||||
// dispatches drainQueue automatically.
|
||||
//
|
||||
// We simulate the compaction completion path directly (bypassing the SDK)
|
||||
// by toggling busy=true, populating the queue exactly as Run() would have
|
||||
// during compaction, and then invoking releaseBusyAfterCompact.
|
||||
func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("compacted then drained"), nil
|
||||
},
|
||||
)
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate the state at the start of the compaction tail: busy is set
|
||||
// and a couple of prompts have piled up in the queue while we were
|
||||
// summarising. (Run() would have appended them and returned a queue
|
||||
// length > 0 to the caller.)
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued during compact #1"},
|
||||
queueItem{Prompt: "queued during compact #2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
// Invoke the deferred tail directly. It should kick off drainQueue.
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// drainQueue runs in a goroutine. Wait for the app to come back to idle.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after releaseBusyAfterCompact: queue not drained")
|
||||
}
|
||||
|
||||
// Wait for any in-flight goroutine to finish before reading state.
|
||||
app.wg.Wait()
|
||||
|
||||
if got := app.QueueLength(); got != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d", got)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatalf("expected stub PromptFunc to fire at least once after compact, got %d calls", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_idleWhenQueueEmpty verifies that with no
|
||||
// pending messages the helper just clears busy and does NOT spawn a
|
||||
// drainQueue goroutine (no spurious agent turn).
|
||||
func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false after releaseBusyAfterCompact with empty queue")
|
||||
}
|
||||
|
||||
// Give any rogue goroutine a moment to (incorrectly) call PromptFunc.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls when queue empty, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue exercises the SDK
|
||||
// steer-drain branch of releaseBusyAfterCompact (issue #27 follow-up).
|
||||
//
|
||||
// Production wires a.opts.Kit.DrainSteer() to pull messages that arrived via
|
||||
// Steer/SteerWithFiles during compaction, but Options.Kit is *kit.Kit (a
|
||||
// concrete struct) so unit tests cannot stand up a real instance without a
|
||||
// full LLM backend. The test uses the unexported steerDrainFn seam to inject
|
||||
// fake steer items, then asserts that:
|
||||
//
|
||||
// - Steer items are dispatched ahead of any prompts that piled up in
|
||||
// a.queue (steer retains "act now" priority over ordinary queued
|
||||
// prompts), and
|
||||
// - the helper still hands off to drainQueue so the steer item actually
|
||||
// fires (the previous behaviour left them stranded — see #27).
|
||||
func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
|
||||
var pmu sync.Mutex
|
||||
var firstPrompt string
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("steer dispatched"), nil
|
||||
},
|
||||
)
|
||||
// Wrap PromptFunc so we can capture the prompt text the stub receives
|
||||
// (newStubWithFuncs's fns ignore prompt; we need it to verify ordering).
|
||||
capturingPrompt := func(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
pmu.Lock()
|
||||
if firstPrompt == "" {
|
||||
firstPrompt = prompt
|
||||
}
|
||||
pmu.Unlock()
|
||||
return stub.fn(ctx, prompt)
|
||||
}
|
||||
app := New(Options{PromptFunc: capturingPrompt}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Inject fake steer items via the test seam. In production the same
|
||||
// items would have been delivered through Kit.InjectSteerWithFiles
|
||||
// during /compact and pulled by DrainSteer here.
|
||||
app.steerDrainFn = func() []queueItem {
|
||||
return []queueItem{
|
||||
{Prompt: "steer-1"},
|
||||
{Prompt: "steer-2"},
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate the state at the end of compaction: busy is set and a couple
|
||||
// of regular Run() prompts have piled up after the steer messages.
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued-1"},
|
||||
queueItem{Prompt: "queued-2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// Wait for the dispatched batch to complete.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after steer-spliced releaseBusyAfterCompact")
|
||||
}
|
||||
app.wg.Wait()
|
||||
|
||||
// drainQueue picks up `first` directly and batches the rest. With
|
||||
// PromptFunc set, executeBatch invokes us with items[0] only — that
|
||||
// item must be the first steer message, proving steer items were
|
||||
// spliced ahead of the previously queued prompts.
|
||||
pmu.Lock()
|
||||
got := firstPrompt
|
||||
pmu.Unlock()
|
||||
if got != "steer-1" {
|
||||
t.Fatalf("expected first dispatched prompt to be steer item %q (steer items must come before queued prompts), got %q",
|
||||
"steer-1", got)
|
||||
}
|
||||
|
||||
// Queue should be fully drained and PromptFunc must have actually fired.
|
||||
if n := app.QueueLength(); n != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d entries", n)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatal("expected stub PromptFunc to fire at least once after splice")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_dropsQueueWhenClosed verifies that if the app
|
||||
// was closed during compaction the helper discards any pending items rather
|
||||
// than spawning drainQueue against a torn-down App.
|
||||
func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
|
||||
app.closed = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
qLen := len(app.queue)
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false even when closed")
|
||||
}
|
||||
if qLen != 0 {
|
||||
t.Fatalf("expected queue cleared on closed app, got %d entries", qLen)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
// Package extbridge wires the public Kit SDK to the internal extensions
|
||||
// package. It exists so that cmd/ and internal/acpserver/ don't both
|
||||
// reimplement the same SDK→extension event/subagent conversions.
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
|
||||
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
|
||||
// (Type=="") for events that don't map to anything useful — callers should
|
||||
// drop those.
|
||||
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
|
||||
// the result/events back into the extension-facing types. The returned
|
||||
// handle is always nil — the SDK path runs synchronously and does not
|
||||
// expose a separate process handle. Callers that need non-blocking
|
||||
// behaviour should run this in their own goroutine.
|
||||
//
|
||||
// This function consolidates the previously-duplicated wiring in
|
||||
// cmd/root.go (interactive + runtime contexts) and
|
||||
// internal/acpserver/session.go.
|
||||
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: cfg.Prompt,
|
||||
Model: cfg.Model,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := SDKEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
cfg.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := k.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
}
|
||||
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
|
||||
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
|
||||
// saveManifest round-trip against a temp directory, ensuring an absent
|
||||
// manifest loads as empty and a saved manifest reads back identically.
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
err = installer.saveManifest(manifest, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
t.Fatalf("saveManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was written to expected path
|
||||
if _, err := os.Stat(manifestPath); err != nil {
|
||||
t.Fatalf("manifest file not created: %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
loaded, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
|
||||
// followed by removeFromManifest leaves the manifest in its original
|
||||
// (empty) state, using a temp-directory installer scope.
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
|
||||
t.Fatalf("addToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
|
||||
t.Fatalf("removeFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
manifest, err = installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindInManifest writes a manifest file directly to the path
|
||||
// resolved by the package-level manifestPathForScope helper and then
|
||||
// confirms FindInManifest locates the entry by identity (and returns
|
||||
// nil for a non-existent identity).
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
t.Setenv("XDG_DATA_HOME", tempDir)
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
// Write a manifest entry directly via the package-level path resolver
|
||||
// so FindInManifest (which uses manifestPathForScope) can read it back.
|
||||
manifestPath := manifestPathForScope(ScopeGlobal)
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
|
||||
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
|
||||
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
|
||||
@@ -2,22 +2,15 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package models
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
+39
-35
@@ -36,15 +36,17 @@ type Diagnostic struct {
|
||||
}
|
||||
|
||||
// LoadAll discovers and loads all prompt templates from standard locations
|
||||
// and any extra paths. Templates are loaded in order of precedence (lowest
|
||||
// to highest), with later templates overriding earlier ones of the same name.
|
||||
// and any extra paths. Templates are loaded in order of precedence (highest
|
||||
// to lowest); the first source to define a given name wins, later definitions
|
||||
// of the same name are dropped with a diagnostic.
|
||||
//
|
||||
// Discovery paths searched in order:
|
||||
// 1. Default templates (if IncludeDefaults)
|
||||
// 2. ~/.kit/prompts/ (global user templates)
|
||||
// 3. .kit/prompts/ (project-local templates)
|
||||
// 4. ConfigPaths (from configuration)
|
||||
// 5. ExtraPaths (explicit paths, highest precedence)
|
||||
// 2. ~/.kit/prompts/ (legacy global)
|
||||
// 3. $XDG_CONFIG_HOME/kit/prompts/ (XDG global, default ~/.config/kit/prompts/)
|
||||
// 4. <cwd>/.kit/prompts/ (project-local templates)
|
||||
// 5. ConfigPaths (from configuration)
|
||||
// 6. ExtraPaths (explicit paths, lowest precedence)
|
||||
func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
if opts.Cwd == "" {
|
||||
opts.Cwd, _ = os.Getwd()
|
||||
@@ -88,13 +90,21 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
addTemplates(defaults, "default")
|
||||
}
|
||||
|
||||
// 2. Global user templates: ~/.kit/prompts/
|
||||
globalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(globalDir); err == nil {
|
||||
// 2. Legacy global user templates: ~/.kit/prompts/
|
||||
legacyGlobalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(legacyGlobalDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
|
||||
// 3. Project-local templates: .kit/prompts/
|
||||
// 3. XDG global user templates: $XDG_CONFIG_HOME/kit/prompts/
|
||||
// Default: ~/.config/kit/prompts/. Aligns with extensions and skills.
|
||||
if xdgDir := GlobalDir(); xdgDir != "" && xdgDir != legacyGlobalDir {
|
||||
if templates, err := LoadFromDir(xdgDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Project-local templates: .kit/prompts/
|
||||
localDir := filepath.Join(opts.Cwd, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(localDir); err == nil {
|
||||
addTemplates(templates, "local")
|
||||
@@ -179,31 +189,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
|
||||
// It returns the deduplicated list and diagnostics for any collisions.
|
||||
// This is a standalone function for when you need to deduplicate an existing list.
|
||||
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
|
||||
seen := make(map[string]*PromptTemplate)
|
||||
var result []*PromptTemplate
|
||||
var diagnostics []Diagnostic
|
||||
|
||||
for _, tpl := range templates {
|
||||
if existing, ok := seen[tpl.Name]; ok {
|
||||
diagnostics = append(diagnostics, Diagnostic{
|
||||
Name: tpl.Name,
|
||||
KeptPath: existing.FilePath,
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: "duplicate template name (first-match-wins)",
|
||||
})
|
||||
} else {
|
||||
seen[tpl.Name] = tpl
|
||||
result = append(result, tpl)
|
||||
}
|
||||
}
|
||||
|
||||
return result, diagnostics
|
||||
}
|
||||
|
||||
// loadDefaultTemplates returns the built-in default templates.
|
||||
// These are embedded templates that ship with Kit.
|
||||
func loadDefaultTemplates() []*PromptTemplate {
|
||||
@@ -211,3 +196,22 @@ func loadDefaultTemplates() []*PromptTemplate {
|
||||
// For now, return an empty slice - users can define their own templates
|
||||
return nil
|
||||
}
|
||||
|
||||
// GlobalDir returns the XDG-aligned global prompts directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/prompts/. Returns an empty
|
||||
// string if the user's home directory cannot be resolved.
|
||||
//
|
||||
// This is the canonical location for user-wide prompt templates and aligns
|
||||
// with the discovery paths used for extensions ($XDG_CONFIG_HOME/kit/extensions/)
|
||||
// and skills ($XDG_CONFIG_HOME/kit/skills/).
|
||||
func GlobalDir() string {
|
||||
base := os.Getenv("XDG_CONFIG_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(base, "kit", "prompts")
|
||||
}
|
||||
|
||||
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
|
||||
// Kept messages must appear BEFORE post-compaction messages so the LLM
|
||||
// sees the conversation in chronological order. Otherwise the latest
|
||||
// post-compaction user message would be followed by an older kept user
|
||||
// message, breaking user/assistant alternation and causing the model to
|
||||
// respond as if the post-compaction turn never happened.
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
// Verify order: summary, M3 (kept), M4 (new)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
|
||||
@@ -755,9 +755,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
// If there is a compaction, inject the summary first, then the
|
||||
// preserved "kept" messages (chronologically before the compaction),
|
||||
// then the post-compaction messages (chronologically after).
|
||||
//
|
||||
// Order matters: the kept messages must come BEFORE the post-compaction
|
||||
// branch so the LLM sees the conversation in chronological order. If the
|
||||
// kept messages were appended last, the latest user message in the
|
||||
// current branch would be followed by an older kept user message,
|
||||
// breaking the strict user/assistant alternation that providers expect
|
||||
// and causing the model to respond as if the previous turn never
|
||||
// happened.
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -768,49 +776,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
})
|
||||
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
// Step 1: collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not on the current branch (the compaction entry is a
|
||||
// new root with no parent), so we iterate tm.entries in append order
|
||||
// and stop when we reach the compaction entry itself.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
@@ -825,13 +794,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
// Stop when we reach the compaction entry itself; messages
|
||||
// after it are collected from the branch walk below.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -860,6 +828,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: collect entries on the current branch after the compaction
|
||||
// entry (these are post-compaction messages). The compaction entry
|
||||
// itself is skipped — its summary was already injected above.
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Summary already injected above.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
@@ -1030,44 +1034,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
// If there's a compaction, we collect IDs in the same order as
|
||||
// BuildContext: [summary placeholder, kept messages, post-compaction
|
||||
// messages]. This ordering must stay in sync with BuildContext so a
|
||||
// cut-point index can be mapped back to the correct entry ID.
|
||||
if lastCompaction != nil {
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
|
||||
// Iterate tm.entries in append order and stop at the compaction
|
||||
// entry to avoid double-counting post-compaction messages.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
@@ -1076,7 +1058,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
@@ -1100,6 +1081,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: IDs of entries after the compaction entry on the current
|
||||
// branch (post-compaction messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
|
||||
@@ -28,15 +28,6 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoBorder returns a renderingOption that disables all borders on the
|
||||
// block, rendering content with only padding.
|
||||
func WithNoBorder() renderingOption {
|
||||
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackground returns a renderingOption that sets the background color
|
||||
// for the entire block. The color parameter accepts any color.Color value,
|
||||
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
|
||||
func WithBackground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.background = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
}
|
||||
}
|
||||
|
||||
// renderContentBlock renders content with configurable styling options
|
||||
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
|
||||
renderer := &blockRenderer{
|
||||
|
||||
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
|
||||
@@ -161,6 +161,12 @@ var SlashCommands = []SlashCommand{
|
||||
Category: "Navigation",
|
||||
Aliases: []string{"/r"},
|
||||
},
|
||||
{
|
||||
Name: "/copy",
|
||||
Description: "Copy the last message to the system clipboard",
|
||||
Category: "System",
|
||||
Aliases: []string{"/cp"},
|
||||
},
|
||||
{
|
||||
Name: "/export",
|
||||
Description: "Export session (JSONL by default, or /export path.jsonl)",
|
||||
@@ -199,18 +205,6 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
names = append(names, cmd.Name)
|
||||
names = append(names, cmd.Aliases...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ExtensionCommand is a slash command registered by an extension. Unlike
|
||||
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
|
||||
// extension commands carry their own Execute callback.
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
}
|
||||
|
||||
// Format the message to include all the debug info in a structured way
|
||||
var formattedMessage string
|
||||
|
||||
// Check if this is a multi-line debug output (like connection info)
|
||||
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
|
||||
// Extract the tag and content
|
||||
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
|
||||
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Add appropriate emoji based on the message content
|
||||
if strings.Contains(content, "Creating new connection") {
|
||||
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
|
||||
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
|
||||
} else if strings.Contains(content, "Reusing") {
|
||||
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
|
||||
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
|
||||
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
|
||||
} else if strings.Contains(content, "closed") {
|
||||
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Failed to close") {
|
||||
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
|
||||
} else {
|
||||
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
|
||||
// Use the CLI's debug message rendering
|
||||
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
@@ -25,17 +25,6 @@ type TextMessageItem struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
@@ -316,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
+225
-15
@@ -129,8 +129,18 @@ type AppController interface {
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
// [Skills] section. Built by the CLI layer from the SDK's []*kit.Skill.
|
||||
type SkillItem struct {
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Source string // "project" or "user" (global).
|
||||
Description string // Short summary used in autocomplete and help.
|
||||
}
|
||||
|
||||
// ExtensionItem holds display metadata about a loaded extension for the
|
||||
// startup [Extensions] section. Built by the CLI layer from the SDK's
|
||||
// []kit.ExtensionInfo.
|
||||
type ExtensionItem struct {
|
||||
Name string // Extension display name (filename without .go extension).
|
||||
Path string // Absolute path to the extension's .go file.
|
||||
Source string // "project" or "user" (global).
|
||||
}
|
||||
|
||||
@@ -363,6 +373,16 @@ type AppModelOptions struct {
|
||||
// watcher detects changes. May be nil if skill hot-reload is not needed.
|
||||
GetSkillItems func() []SkillItem
|
||||
|
||||
// ExtensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section. Each entry shows the filename of an extension that was
|
||||
// discovered and loaded (global, project-local, or explicit).
|
||||
ExtensionItems []ExtensionItem
|
||||
|
||||
// GetExtensionItems, if non-nil, returns the current extension items.
|
||||
// Called on extension hot-reload to refresh the list. May be nil if no
|
||||
// extensions are loaded.
|
||||
GetExtensionItems func() []ExtensionItem
|
||||
|
||||
// MCPToolCount is the number of tools loaded from external MCP servers.
|
||||
MCPToolCount int
|
||||
|
||||
@@ -607,6 +627,14 @@ type AppModel struct {
|
||||
// skill list after content hot-reload. May be nil.
|
||||
getSkillItems func() []SkillItem
|
||||
|
||||
// extensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section (filenames only).
|
||||
extensionItems []ExtensionItem
|
||||
|
||||
// getExtensionItems returns the current extension items. Used to refresh
|
||||
// the list after extension hot-reload. May be nil.
|
||||
getExtensionItems func() []ExtensionItem
|
||||
|
||||
// mcpToolCount and extensionToolCount track tool counts by source for
|
||||
// the startup info display.
|
||||
mcpToolCount int
|
||||
@@ -860,6 +888,8 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
m.contextPaths = opts.ContextPaths
|
||||
m.skillItems = opts.SkillItems
|
||||
m.getSkillItems = opts.GetSkillItems
|
||||
m.extensionItems = opts.ExtensionItems
|
||||
m.getExtensionItems = opts.GetExtensionItems
|
||||
m.mcpToolCount = opts.MCPToolCount
|
||||
m.extensionToolCount = opts.ExtensionToolCount
|
||||
m.startupExtensionMessages = opts.StartupExtensionMessages
|
||||
@@ -912,6 +942,20 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
}
|
||||
}
|
||||
|
||||
// Merge skills into autocomplete as /skill:<name> commands. Skills accept
|
||||
// optional trailing args, so HasArgs is true — Enter populates the input
|
||||
// with "/skill:name " rather than auto-submitting.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.SkillItems) > 0 {
|
||||
for _, s := range opts.SkillItems {
|
||||
ic.commands = append(ic.commands, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Merge MCP prompts into autocomplete as /<server>:<prompt> commands.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.MCPPrompts) > 0 {
|
||||
for _, p := range opts.MCPPrompts {
|
||||
@@ -1014,8 +1058,21 @@ func (m *AppModel) AddStartupMessageToScrollList() {
|
||||
pairs = append(pairs, [2]string{"Skills", strings.Join(names, ", ")})
|
||||
}
|
||||
|
||||
// Extension tool count (only shown when > 0).
|
||||
if m.extensionToolCount > 0 {
|
||||
// Extensions — listed by filename. Each extension shows its basename
|
||||
// without the .go suffix, matching the [Skills] section's style.
|
||||
if len(m.extensionItems) > 0 {
|
||||
names := make([]string, len(m.extensionItems))
|
||||
for i, ei := range m.extensionItems {
|
||||
names[i] = ei.Name
|
||||
}
|
||||
value := strings.Join(names, ", ")
|
||||
if m.extensionToolCount > 0 {
|
||||
value += fmt.Sprintf(" (%d tools)", m.extensionToolCount)
|
||||
}
|
||||
pairs = append(pairs, [2]string{"Extensions", value})
|
||||
} else if m.extensionToolCount > 0 {
|
||||
// Fallback: tool count only (extensions registered tools but the CLI
|
||||
// did not provide ExtensionItems for some reason).
|
||||
pairs = append(pairs, [2]string{"Extensions", fmt.Sprintf("%d tools", m.extensionToolCount)})
|
||||
}
|
||||
|
||||
@@ -1251,7 +1308,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.scrollList.autoScroll = false
|
||||
case tea.MouseWheelDown:
|
||||
m.scrollList.ScrollBy(scrollLines)
|
||||
if m.scrollList.AtBottom() {
|
||||
// Only re-enable auto-scroll when the user is not actively
|
||||
// selecting text. Otherwise a wheel-down during a drag-select
|
||||
// would re-arm GotoBottom on the next stream chunk, shifting
|
||||
// the highlighted row out from under the cursor.
|
||||
if m.scrollList.AtBottom() && !m.scrollList.IsMouseDown() {
|
||||
m.scrollList.autoScroll = true
|
||||
}
|
||||
}
|
||||
@@ -1259,9 +1320,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// ── Mouse click selection (crush-style character-level) ──────────────────
|
||||
case tea.MouseClickMsg:
|
||||
if msg.Button == tea.MouseLeft {
|
||||
// Calculate viewport-relative coordinates.
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
// Compute the scrollback origin from the current frame's layout
|
||||
// rather than the stale cached value from the previous View().
|
||||
// scrollbackYOffset/scrollList.height are only refreshed inside
|
||||
// View() and lag behind any state change that resized the header
|
||||
// (extension widgets, warning rows, etc.) since the last render.
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
// Clear any previous selection on a new click.
|
||||
// HandleMouseDown will set up new selection state.
|
||||
if m.scrollList.HandleMouseDown(msg.X, viewY) {
|
||||
@@ -1272,8 +1338,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Mouse motion/drag for character-level selection ──────────────────────
|
||||
case tea.MouseMotionMsg:
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
m.scrollList.HandleMouseDrag(msg.X, viewY)
|
||||
}
|
||||
|
||||
@@ -1603,10 +1670,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Cancel timer expired ─────────────────────────────────────────────────
|
||||
case uicore.CancelTimerExpiredMsg:
|
||||
if m.canceling {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.canceling = false
|
||||
|
||||
// ── Ctrl+C reset timer expired ────────────────────────────────────────────
|
||||
case uicore.CtrlCResetMsg:
|
||||
if m.ctrlCPressedOnce {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.ctrlCPressedOnce = false
|
||||
|
||||
// ── Input submitted ──────────────────────────────────────────────────────
|
||||
@@ -2328,6 +2401,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
} else {
|
||||
m.refreshExtensionItems()
|
||||
m.printSystemMessage("Extensions reloaded.")
|
||||
}
|
||||
|
||||
@@ -3095,6 +3169,8 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
|
||||
return m.handleResumeCommand()
|
||||
case "/export":
|
||||
return m.handleExportCommand(args)
|
||||
case "/copy":
|
||||
return m.handleCopyCommand()
|
||||
case "/share":
|
||||
return m.handleShareCommand()
|
||||
case "/import":
|
||||
@@ -3395,13 +3471,56 @@ func (m *AppModel) refreshPromptTemplates() {
|
||||
}
|
||||
}
|
||||
|
||||
// refreshSkillItems reloads skill items from the provider callback.
|
||||
// Called on ContentReloadEvent.
|
||||
// refreshSkillItems reloads skill items from the provider callback and
|
||||
// updates the autocomplete entries. Called on ContentReloadEvent.
|
||||
func (m *AppModel) refreshSkillItems() {
|
||||
if m.getSkillItems == nil {
|
||||
return
|
||||
}
|
||||
m.skillItems = m.getSkillItems()
|
||||
newItems := m.getSkillItems()
|
||||
m.skillItems = newItems
|
||||
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
// Remove old Skills commands and add fresh ones.
|
||||
var kept []commands.SlashCommand
|
||||
for _, sc := range ic.commands {
|
||||
if sc.Category != "Skills" {
|
||||
kept = append(kept, sc)
|
||||
}
|
||||
}
|
||||
for _, s := range newItems {
|
||||
kept = append(kept, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
ic.commands = kept
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExtensionItems reloads extension items from the provider callback
|
||||
// so the [Extensions] startup section reflects the current set after a
|
||||
// hot-reload. Called from the extReloadResultMsg handler.
|
||||
func (m *AppModel) refreshExtensionItems() {
|
||||
if m.getExtensionItems == nil {
|
||||
return
|
||||
}
|
||||
m.extensionItems = m.getExtensionItems()
|
||||
}
|
||||
|
||||
// formatSkillDescription returns the autocomplete description for a skill,
|
||||
// prefixed with [project] or [user] so users can tell colliding names apart.
|
||||
func formatSkillDescription(s SkillItem) string {
|
||||
prefix := "[user]"
|
||||
if s.Source == "project" {
|
||||
prefix = "[project]"
|
||||
}
|
||||
if s.Description == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + " " + s.Description
|
||||
}
|
||||
|
||||
// refreshMCPPrompts reloads MCP prompts from the provider callback and
|
||||
@@ -3476,6 +3595,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**System:**\n" +
|
||||
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
|
||||
"- `/clear`: Clear message history\n" +
|
||||
"- `/copy`: Copy the last message to the system clipboard\n" +
|
||||
"- `/export [path]`: Export session as JSONL\n" +
|
||||
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
|
||||
"- `/reset-usage`: Reset usage statistics\n" +
|
||||
@@ -3712,7 +3832,12 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
}
|
||||
// Auto-scroll to bottom if enabled (iteratr pattern)
|
||||
// Don't call SetItems() - the slice reference hasn't changed
|
||||
if m.scrollList != nil {
|
||||
//
|
||||
// CRITICAL: never scroll the viewport while the user is actively
|
||||
// selecting text (mouse button held). Doing so shifts the
|
||||
// highlighted content out from under the cursor and produces the
|
||||
// off-by-N-row drift users see when copy-selecting during streaming.
|
||||
if m.scrollList != nil && !m.scrollList.IsMouseDown() {
|
||||
if m.scrollList.autoScroll {
|
||||
m.scrollList.GotoBottom()
|
||||
} else if m.scrollList.AtBottom() {
|
||||
@@ -3740,6 +3865,36 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
m.refreshContent()
|
||||
}
|
||||
|
||||
// currentScrollbackBounds returns the live (yOffset, viewportHeight) for the
|
||||
// scrollback region, computed from the current state — not from the cached
|
||||
// values populated inside View().
|
||||
//
|
||||
// scrollbackYOffset and scrollList.height are refreshed once per render, so
|
||||
// any state change that resizes the header (extension widget toggles,
|
||||
// warning rows, queued messages, etc.) leaves the cached values one frame
|
||||
// stale. Mouse click handlers in Update() can then place the cursor on the
|
||||
// wrong line, producing the off-by-N-row drift seen during copy-selection.
|
||||
//
|
||||
// This recomputes the header height by rendering it (cheap — the renderer
|
||||
// returns "" when no extension header is set) and recomputes the viewport
|
||||
// height the same way distributeHeight() does, so both inputs to the
|
||||
// y → (item, line) mapping are always current.
|
||||
func (m *AppModel) currentScrollbackBounds() (yOffset, viewportHeight int) {
|
||||
// Force a fresh layout if anything in Update() marked the state dirty;
|
||||
// otherwise scrollList.height still reflects the previous frame.
|
||||
if m.layoutDirty {
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = false
|
||||
}
|
||||
if headerView := m.renderHeaderFooter(m.getHeader); headerView != "" {
|
||||
yOffset = lipgloss.Height(headerView)
|
||||
}
|
||||
if m.scrollList != nil {
|
||||
viewportHeight = m.scrollList.height
|
||||
}
|
||||
return yOffset, viewportHeight
|
||||
}
|
||||
|
||||
// distributeHeight recalculates child component heights after a window resize,
|
||||
// queue change, widget update, or state transition, and propagates the computed
|
||||
// stream height to the StreamComponent.
|
||||
@@ -3812,7 +3967,20 @@ func (m *AppModel) distributeHeight() {
|
||||
headerFooterLines += lipgloss.Height(footerView)
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines, 0)
|
||||
// Account for transient warning rows that View() injects between the
|
||||
// scrollback and the separator. These flags are toggled by ESC/Ctrl+C
|
||||
// handlers; without subtracting them here the joined view exceeds
|
||||
// m.height by one line per active warning and the bottom of the screen
|
||||
// gets silently clipped — which in turn invalidates scrollbackYOffset.
|
||||
var warningLines int
|
||||
if m.canceling {
|
||||
warningLines++
|
||||
}
|
||||
if m.ctrlCPressedOnce {
|
||||
warningLines++
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines-warningLines, 0)
|
||||
|
||||
// In alt screen mode, give the calculated height to ScrollList instead of stream.
|
||||
// The stream component still exists but is embedded as the last item in scrollList.
|
||||
@@ -4236,6 +4404,48 @@ func (m *AppModel) handleNameCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCopyCommand copies the last user or assistant message to the system
|
||||
// clipboard. Skips transient system messages (e.g. /help output) so the user
|
||||
// gets the actual last conversational message.
|
||||
func (m *AppModel) handleCopyCommand() tea.Cmd {
|
||||
if len(m.messages) == 0 {
|
||||
m.printSystemMessage("No messages to copy.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
text string
|
||||
role string
|
||||
)
|
||||
for i := len(m.messages) - 1; i >= 0; i-- {
|
||||
switch msg := m.messages[i].(type) {
|
||||
case *TextMessageItem:
|
||||
if msg.role == "user" || msg.role == "assistant" {
|
||||
text = msg.content
|
||||
role = msg.role
|
||||
}
|
||||
case *StreamingMessageItem:
|
||||
if msg.role == "assistant" || msg.role == "reasoning" {
|
||||
text = msg.content.String()
|
||||
role = msg.role
|
||||
}
|
||||
}
|
||||
if text != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(text) == "" {
|
||||
m.printSystemMessage("No copyable message found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Copied last %s message to clipboard (%d chars).", role, len(text),
|
||||
))
|
||||
return clipboard.CopyToClipboard(text)
|
||||
}
|
||||
|
||||
// handleExportCommand exports the current session to a file.
|
||||
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
|
||||
//
|
||||
|
||||
@@ -19,28 +19,7 @@ import (
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
// Any @file tokens in the content are highlighted with the theme accent color.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
// Highlight @file tokens with accent color so file references are
|
||||
// visually distinct from surrounding prompt text.
|
||||
content = HighlightFileTokens(content, theme)
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
// UserBlock-related rendering helpers and herald typography.
|
||||
|
||||
// HighlightFileTokens wraps @file tokens in the given text with the theme
|
||||
// accent color so they stand out visually in rendered user messages.
|
||||
@@ -154,44 +133,6 @@ func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) strin
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ToolBlock renders a tool execution result with header and body.
|
||||
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
return styleMarginBottom(theme, fullContent)
|
||||
}
|
||||
|
||||
// styleMarginBottom applies a 1-line margin bottom using the theme.
|
||||
func styleMarginBottom(theme style.Theme, content string) string {
|
||||
return style.GetCachedStyles().MarginBottom1.Render(content)
|
||||
|
||||
@@ -4,30 +4,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// testTypography creates a herald Typography for tests.
|
||||
func testTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertLabel(herald.AlertTip, ""),
|
||||
herald.WithAlertIcon(herald.AlertTip, ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHighlightFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
|
||||
@@ -88,24 +67,25 @@ func TestHighlightFileTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserBlockHighlightsFileTokens(t *testing.T) {
|
||||
// TestHighlightFileTokensInjectsANSI verifies that HighlightFileTokens
|
||||
// preserves the original @file references in the output and wraps each
|
||||
// token with ANSI escape codes for the theme accent color.
|
||||
func TestHighlightFileTokensInjectsANSI(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
ty := testTypography(theme)
|
||||
|
||||
// A user message with @file tokens should contain ANSI escapes around the token.
|
||||
content := "refactor @main.go and @utils.go"
|
||||
result := UserBlock(content, 80, ty, theme)
|
||||
result := HighlightFileTokens(content, theme)
|
||||
|
||||
// The rendered output should contain both file references.
|
||||
// The output should still contain both file references.
|
||||
if !strings.Contains(result, "@main.go") {
|
||||
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @main.go, got:\n%s", result)
|
||||
}
|
||||
if !strings.Contains(result, "@utils.go") {
|
||||
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @utils.go, got:\n%s", result)
|
||||
}
|
||||
|
||||
// Verify ANSI codes are present (the tokens are styled).
|
||||
if !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
|
||||
t.Errorf("HighlightFileTokens output should contain ANSI escape codes for styled @file tokens")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +60,13 @@ func NewScrollList(width, height int) *ScrollList {
|
||||
}
|
||||
|
||||
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
|
||||
// the viewport will scroll to the bottom to show the latest content.
|
||||
// the viewport will scroll to the bottom to show the latest content — EXCEPT
|
||||
// when the user is actively selecting text (mouse button held), in which case
|
||||
// the scroll position is locked so the highlighted content stays under the
|
||||
// cursor. The pending bottom-scroll is deferred to MouseUp.
|
||||
func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
s.items = items
|
||||
if s.autoScroll {
|
||||
if s.autoScroll && !s.sel.MouseDown {
|
||||
s.GotoBottom()
|
||||
}
|
||||
}
|
||||
@@ -157,6 +160,10 @@ func (s *ScrollList) HandleMouseDown(x, y int) bool {
|
||||
// HandleMouseDrag handles mouse motion while button is held.
|
||||
// Updates the selection endpoint for character-level precision.
|
||||
// Returns true if selection was updated.
|
||||
//
|
||||
// Defensively disables auto-scroll on every drag update — even if the
|
||||
// MouseDown handler missed (e.g. click landed in viewport padding), any
|
||||
// active drag means the user is selecting and the viewport must not jump.
|
||||
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
if !s.sel.MouseDown {
|
||||
return false
|
||||
@@ -171,6 +178,9 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Hard-lock the viewport while dragging.
|
||||
s.autoScroll = false
|
||||
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
@@ -178,6 +188,13 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsMouseDown reports whether the user currently has the mouse button held
|
||||
// (i.e. a selection drag is in progress). Used by the parent model to avoid
|
||||
// re-enabling auto-scroll during streaming while the user is selecting.
|
||||
func (s *ScrollList) IsMouseDown() bool {
|
||||
return s.sel.MouseDown
|
||||
}
|
||||
|
||||
// HandleMouseUp handles mouse button release.
|
||||
// Returns true if there was an active selection.
|
||||
func (s *ScrollList) HandleMouseUp() bool {
|
||||
@@ -521,6 +538,21 @@ func (s *ScrollList) View() string {
|
||||
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
|
||||
item := s.items[idx]
|
||||
content := item.Render(s.width)
|
||||
|
||||
// Items that render to an empty string contribute zero height to
|
||||
// the viewport. This MUST match renderedHeight()'s semantics —
|
||||
// otherwise getItemAndLineAtY (which uses renderedHeight) treats
|
||||
// the item as 0 lines while View() emits one blank line via
|
||||
// strings.Split("", "\n") = [""], producing a 1-row downward
|
||||
// drift in mouse hit-testing per empty item between offsetIdx
|
||||
// and the cursor (most visibly streaming-reasoning items before
|
||||
// any reasoning has streamed, which extension widgets surface by
|
||||
// shrinking the scrollback).
|
||||
if content == "" {
|
||||
s.heightCache[item.ID()] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
// Refresh height cache from the actual render (authoritative).
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fakeItem is a deterministic MessageItem for ScrollList tests.
|
||||
type fakeItem struct {
|
||||
id string
|
||||
lines int
|
||||
}
|
||||
|
||||
func (f *fakeItem) ID() string { return f.id }
|
||||
func (f *fakeItem) Render(_ int) string {
|
||||
if f.lines <= 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, f.lines)
|
||||
for i := range parts {
|
||||
parts[i] = fmt.Sprintf("%s-line-%d", f.id, i)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
func (f *fakeItem) Height() int { return f.lines }
|
||||
|
||||
// makeItems builds n fake items of `lines` height each.
|
||||
func makeItems(n, lines int) []MessageItem {
|
||||
out := make([]MessageItem, n)
|
||||
for i := range out {
|
||||
out[i] = &fakeItem{id: fmt.Sprintf("item-%d", i), lines: lines}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TestScrollList_MouseDownPreventsAutoScroll verifies the core fix for the
|
||||
// copy-selection drift bug: while the user has the mouse button held
|
||||
// (drag-selecting), incoming content updates must NOT shift the viewport,
|
||||
// because doing so moves the highlighted content out from under the cursor.
|
||||
func TestScrollList_MouseDownPreventsAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2)) // 40 lines of content into a 10-line viewport
|
||||
// Capture the auto-scrolled-to-bottom position.
|
||||
startOffsetIdx := sl.offsetIdx
|
||||
startOffsetLine := sl.offsetLine
|
||||
|
||||
// User clicks somewhere in the visible area, starting a drag-select.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown should accept a click inside the viewport")
|
||||
}
|
||||
if !sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be true after HandleMouseDown")
|
||||
}
|
||||
|
||||
// New content arrives. With autoScroll still true, SetItems would
|
||||
// normally call GotoBottom() and shift the viewport. The fix should
|
||||
// suppress that while MouseDown is held.
|
||||
sl.SetItems(makeItems(30, 2)) // 60 lines now
|
||||
if sl.offsetIdx != startOffsetIdx || sl.offsetLine != startOffsetLine {
|
||||
t.Errorf("viewport scrolled during active drag: was (%d,%d), now (%d,%d)",
|
||||
startOffsetIdx, startOffsetLine, sl.offsetIdx, sl.offsetLine)
|
||||
}
|
||||
|
||||
// User releases the mouse — drag is over.
|
||||
sl.HandleMouseUp()
|
||||
if sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be false after HandleMouseUp")
|
||||
}
|
||||
|
||||
// After release, a fresh content update should resume auto-scrolling
|
||||
// (move the offset to track the new bottom).
|
||||
afterReleaseIdx := sl.offsetIdx
|
||||
afterReleaseLine := sl.offsetLine
|
||||
sl.SetItems(makeItems(50, 2))
|
||||
if sl.offsetIdx == afterReleaseIdx && sl.offsetLine == afterReleaseLine {
|
||||
t.Errorf("autoscroll did not resume after MouseUp: offset stuck at (%d,%d)",
|
||||
afterReleaseIdx, afterReleaseLine)
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_DragDisablesAutoScroll verifies that any successful
|
||||
// HandleMouseDrag call clears autoScroll, even when HandleMouseDown didn't
|
||||
// observe it (e.g. a stale wheel-down event set it back to true mid-stream).
|
||||
func TestScrollList_DragDisablesAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2))
|
||||
|
||||
// Begin a selection.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
// Simulate an external code path that re-enabled autoScroll while
|
||||
// MouseDown is still held (the precise condition that caused drift).
|
||||
sl.autoScroll = true
|
||||
|
||||
// Drag motion should hard-lock the viewport again.
|
||||
if !sl.HandleMouseDrag(10, 4) {
|
||||
t.Fatalf("HandleMouseDrag failed")
|
||||
}
|
||||
if sl.autoScroll {
|
||||
t.Errorf("HandleMouseDrag must clear autoScroll to prevent mid-drag jumps")
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_SetItemsRespectsMouseDown is the most direct regression
|
||||
// test: even with autoScroll enabled and new content appended at the
|
||||
// bottom, SetItems must not move the viewport while a mouse drag is in
|
||||
// progress. This is what caused the "highlighting shifts by 1+ rows
|
||||
// during streaming" symptom reported by the user.
|
||||
func TestScrollList_SetItemsRespectsMouseDown(t *testing.T) {
|
||||
sl := NewScrollList(80, 5)
|
||||
sl.SetItems(makeItems(10, 2)) // 20 lines into a 5-line viewport
|
||||
// At bottom.
|
||||
preIdx, preLine := sl.offsetIdx, sl.offsetLine
|
||||
|
||||
// Hold mouse down (no actual drag needed).
|
||||
if !sl.HandleMouseDown(0, 0) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
|
||||
// Append several more items as if streaming. With the bug, each
|
||||
// SetItems would call GotoBottom and shift the offset.
|
||||
for n := 11; n <= 15; n++ {
|
||||
sl.SetItems(makeItems(n, 2))
|
||||
if sl.offsetIdx != preIdx || sl.offsetLine != preLine {
|
||||
t.Fatalf("viewport drifted during streaming with mouse held: "+
|
||||
"start=(%d,%d) now=(%d,%d) after adding item %d",
|
||||
preIdx, preLine, sl.offsetIdx, sl.offsetLine, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_EmptyItemsDoNotShiftMouseMapping is the regression test
|
||||
// for the second drift bug: items that render to "" must contribute the
|
||||
// same number of rows in View() (zero) as in renderedHeight(), or mouse
|
||||
// hit-testing drifts by one row per empty item between offsetIdx and the
|
||||
// cursor. This was surfaced by extension widgets (e.g. subagent-monitor)
|
||||
// that shrink the scrollback so empty streaming-reasoning items end up
|
||||
// in the visible window.
|
||||
//
|
||||
// Setup: 1 normal item + 1 empty item + 1 normal item. Click on the line
|
||||
// where the third item begins. With the bug, getItemAndLineAtY skips the
|
||||
// empty item (renderedHeight=0) and reports lineIdx pointing one row
|
||||
// past where View() actually painted that line.
|
||||
func TestScrollList_EmptyItemsDoNotShiftMouseMapping(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems([]MessageItem{
|
||||
&fakeItem{id: "a", lines: 2}, // viewY 0–1
|
||||
&fakeItem{id: "empty", lines: 0}, // renders "" — contributes 0 rows
|
||||
&fakeItem{id: "b", lines: 2}, // viewY 2–3
|
||||
})
|
||||
|
||||
// Render the viewport once so the cache reflects what View() actually
|
||||
// emits (this is the path that previously diverged from renderedHeight
|
||||
// for empty items).
|
||||
rendered := sl.View()
|
||||
lines := strings.Split(rendered, "\n")
|
||||
|
||||
// Sanity: View() must emit exactly height lines.
|
||||
if len(lines) != 10 {
|
||||
t.Fatalf("View() returned %d lines, want 10", len(lines))
|
||||
}
|
||||
// Item b's first line should appear at viewY=2, NOT viewY=3.
|
||||
if !strings.Contains(lines[2], "b-line-0") {
|
||||
t.Errorf("viewY=2 should render b-line-0 (empty item contributes 0 rows), got %q", lines[2])
|
||||
}
|
||||
|
||||
// Now the actual hit-test contract: clicking on viewY=2 must map to
|
||||
// item b line 0 — the same coordinate View() rendered there.
|
||||
idx, line := sl.getItemAndLineAtY(2)
|
||||
if idx != 2 || line != 0 {
|
||||
t.Errorf("getItemAndLineAtY(2) = (%d,%d), want (2,0)", idx, line)
|
||||
}
|
||||
|
||||
// And clicking on the second line of b (viewY=3) must map to b line 1.
|
||||
idx, line = sl.getItemAndLineAtY(3)
|
||||
if idx != 2 || line != 1 {
|
||||
t.Errorf("getItemAndLineAtY(3) = (%d,%d), want (2,1)", idx, line)
|
||||
}
|
||||
}
|
||||
@@ -230,8 +230,10 @@ func FindWordBoundaries(line string, col int) (startCol, endCol int) {
|
||||
|
||||
// HighlightLine applies reverse-video highlighting to a portion of a rendered
|
||||
// line (which may contain ANSI escape codes). startCol/endCol are in display
|
||||
// columns. If startCol == -1, the entire line is highlighted. If startCol ==
|
||||
// endCol, returns the line unchanged.
|
||||
// columns. If startCol == -1, the entire line is highlighted. If endCol ==
|
||||
// -1, the highlight runs from startCol to the end of the line (the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection).
|
||||
// If startCol == endCol, returns the line unchanged.
|
||||
//
|
||||
// Uses ultraviolet ScreenBuffer for cell-level ANSI manipulation.
|
||||
func HighlightLine(line string, startCol, endCol int) string {
|
||||
@@ -250,6 +252,16 @@ func HighlightLine(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (returned by IsLineInRange
|
||||
// for the first line of a multi-line selection). Without this branch,
|
||||
// the start line of a multi-line drag would never be highlighted —
|
||||
// the user perceives this as the selection being shifted one row down
|
||||
// from the cursor, especially when extension widgets shrink the
|
||||
// scrollback and make the start line land on a tall styled block.
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return line
|
||||
}
|
||||
@@ -296,6 +308,11 @@ func ExtractText(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (see HighlightLine).
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -357,6 +357,54 @@ func TestHighlightLine_NoSelection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlightLine_EndOfLineSentinel verifies that endCol=-1 is interpreted
|
||||
// as "highlight from startCol to end of line", matching the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection.
|
||||
//
|
||||
// Regression: without this contract, the start line of any multi-line drag
|
||||
// would silently fall through HighlightLine's startCol >= endCol guard and
|
||||
// render unstyled, making the selection appear to begin one row below the
|
||||
// cursor — the exact "tracking gets shifted" symptom users reported when
|
||||
// extension widgets shrank the scrollback enough that the click landed on a
|
||||
// styled tool-result block.
|
||||
func TestHighlightLine_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
result := HighlightLine(line, 0, -1)
|
||||
if result == line {
|
||||
t.Errorf("endCol=-1 should highlight from startCol to end of line; got unchanged input")
|
||||
}
|
||||
if len(result) <= len(line) {
|
||||
t.Errorf("highlighted result should be longer than plain input (ANSI codes added); got len=%d want > %d", len(result), len(line))
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractText_EndOfLineSentinel mirrors TestHighlightLine_EndOfLineSentinel
|
||||
// for the extraction path used by the clipboard copy.
|
||||
func TestExtractText_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
got := ExtractText(line, 7, -1)
|
||||
want := "World!"
|
||||
if got != want {
|
||||
t.Errorf("ExtractText(line, 7, -1) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLineInRange_StartLineSentinelHighlights composes IsLineInRange with
|
||||
// HighlightLine end-to-end: the start line of a multi-line, single-item
|
||||
// selection must actually emit highlight ANSI codes. This is the contract
|
||||
// the rendering path in scrolllist.View() relies on.
|
||||
func TestIsLineInRange_StartLineSentinelHighlights(t *testing.T) {
|
||||
r := Range{StartItemIdx: 5, EndItemIdx: 5, StartLine: 0, EndLine: 2, StartCol: 0, EndCol: 10}
|
||||
inRange, sc, ec := IsLineInRange(r, 5, 0)
|
||||
if !inRange {
|
||||
t.Fatalf("item 5 line 0 should be in range")
|
||||
}
|
||||
highlighted := HighlightLine("first line of selection", sc, ec)
|
||||
if highlighted == "first line of selection" {
|
||||
t.Errorf("first line of multi-line selection was not highlighted (sc=%d ec=%d)", sc, ec)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiClickDetection verifies the click counting logic.
|
||||
func TestMultiClickDetection(t *testing.T) {
|
||||
s := NewState()
|
||||
|
||||
@@ -211,106 +211,11 @@ func DefaultTheme() Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
|
||||
// padding, and appropriate width. Used for grouping related content in a visually
|
||||
// distinct box.
|
||||
func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Border).
|
||||
Padding(1, 2).
|
||||
MarginBottom(1)
|
||||
}
|
||||
|
||||
// IsDarkBackground returns the cached terminal background detection result.
|
||||
func IsDarkBackground() bool {
|
||||
return isDarkBg
|
||||
}
|
||||
|
||||
// StyleHeader creates a lipgloss style for primary headers using the theme's
|
||||
// primary color with bold text for emphasis and hierarchy.
|
||||
func StyleHeader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
|
||||
// secondary color with bold text, providing visual hierarchy below primary headers.
|
||||
func StyleSubheader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
|
||||
// and italic formatting, suitable for supplementary or less important information.
|
||||
func StyleMuted(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
}
|
||||
|
||||
// StyleSuccess creates a lipgloss style for success messages using green colors
|
||||
// with bold text to indicate successful operations or positive outcomes.
|
||||
func StyleSuccess(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleError creates a lipgloss style for error messages using red colors
|
||||
// with bold text to ensure visibility of problems or failures.
|
||||
func StyleError(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
|
||||
// colors with bold text to draw attention to potential issues or cautions.
|
||||
func StyleWarning(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleInfo creates a lipgloss style for informational messages using blue colors
|
||||
// with bold text for general notifications and status updates.
|
||||
func StyleInfo(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// CreateSeparator generates a horizontal separator line with the specified width,
|
||||
// character, and color. Useful for visually dividing sections of content in the UI.
|
||||
func CreateSeparator(width int, char string, c color.Color) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Width(width).
|
||||
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
|
||||
}
|
||||
|
||||
// CreateProgressBar generates a visual progress bar with filled and empty segments
|
||||
// based on the percentage complete. The bar uses Unicode block characters for smooth
|
||||
// appearance and theme colors to indicate progress.
|
||||
func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
filled := int(float64(width) * percentage / 100)
|
||||
empty := width - filled
|
||||
|
||||
filledBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Render(lipgloss.PlaceHorizontal(filled, lipgloss.Left, "█"))
|
||||
|
||||
emptyBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Render(lipgloss.PlaceHorizontal(empty, lipgloss.Left, "░"))
|
||||
|
||||
return filledBar + emptyBar
|
||||
}
|
||||
|
||||
// CreateBadge generates a styled badge or label with inverted colors (text on
|
||||
// colored background) for highlighting important tags, statuses, or categories.
|
||||
func CreateBadge(text string, c color.Color) string {
|
||||
|
||||
@@ -6,13 +6,6 @@ import (
|
||||
heraldmd "github.com/indaco/herald-md"
|
||||
)
|
||||
|
||||
// BaseStyle returns a new, empty lipgloss style that can be customized with
|
||||
// additional styling methods. This serves as the foundation for building more
|
||||
// complex styled components.
|
||||
func BaseStyle() lipgloss.Style {
|
||||
return lipgloss.NewStyle()
|
||||
}
|
||||
|
||||
// markdownTypographyCache holds the last-created Typography instance for
|
||||
// herald-md rendering. It is cached to avoid re-initialization on every
|
||||
// streaming flush tick. The cache is invalidated by SetTheme when the
|
||||
|
||||
@@ -543,12 +543,6 @@ func ApplyThemeWithoutSave(name string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshThemeRegistry re-scans the themes directory. Call after the user
|
||||
// drops a new file into ~/.config/kit/themes/.
|
||||
func RefreshThemeRegistry() {
|
||||
initThemeRegistry()
|
||||
}
|
||||
|
||||
// RegisterThemeFromConfig adds a theme to the runtime registry from an
|
||||
// extension's ThemeColorConfig (string hex pairs). Replaces any existing
|
||||
// entry with the same name. The theme is immediately available via
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"charm.land/bubbles/v2/textarea"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
)
|
||||
|
||||
type ToolApprovalInput struct {
|
||||
textarea textarea.Model
|
||||
toolName string
|
||||
toolArgs string
|
||||
width int
|
||||
selected bool // true when "yes" is highlighted and false when "no" is
|
||||
approved bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func NewToolApprovalInput(toolName, toolArgs string, width int) *ToolApprovalInput {
|
||||
ta := textarea.New()
|
||||
ta.Placeholder = ""
|
||||
ta.ShowLineNumbers = false
|
||||
ta.CharLimit = 0
|
||||
ta.SetWidth(width - 8) // Account for container padding, border and internal padding
|
||||
ta.SetHeight(4) // Default to 3 lines like huh
|
||||
ta.Focus()
|
||||
|
||||
// Style the textarea using theme colors.
|
||||
theme := GetTheme()
|
||||
styles := ta.Styles()
|
||||
styles.Focused.Base = lipgloss.NewStyle()
|
||||
styles.Focused.Placeholder = lipgloss.NewStyle().Foreground(theme.VeryMuted)
|
||||
styles.Focused.Text = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
styles.Focused.Prompt = lipgloss.NewStyle()
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &ToolApprovalInput{
|
||||
textarea: ta,
|
||||
toolName: toolName,
|
||||
toolArgs: toolArgs,
|
||||
width: width,
|
||||
selected: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Init() tea.Cmd {
|
||||
return textarea.Blink
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyPressMsg:
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
t.approved = true
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "n", "N":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "left":
|
||||
t.selected = true
|
||||
return t, nil
|
||||
case "right":
|
||||
t.selected = false
|
||||
return t, nil
|
||||
case "enter":
|
||||
t.approved = t.selected
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
case "esc", "ctrl+c":
|
||||
t.approved = false
|
||||
t.done = true
|
||||
return t, tea.Quit
|
||||
}
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *ToolApprovalInput) View() tea.View {
|
||||
if t.done {
|
||||
return tea.NewView("we are done")
|
||||
}
|
||||
|
||||
containerStyle := lipgloss.NewStyle()
|
||||
|
||||
theme := GetTheme()
|
||||
|
||||
// PaddingLeft(3) aligns with message content: border(1) + paddingLeft(2).
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Text).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(3)
|
||||
|
||||
// Input box with huh-like styling
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.ThickBorder()).
|
||||
BorderLeft(true).
|
||||
BorderRight(false).
|
||||
BorderTop(false).
|
||||
BorderBottom(false).
|
||||
BorderForeground(theme.Primary).
|
||||
PaddingLeft(2). // match message block paddingLeft
|
||||
Width(t.width - 1) // full width minus left border
|
||||
|
||||
// Style for the currently selected/highlighted option
|
||||
selectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true).
|
||||
Underline(true)
|
||||
|
||||
// Style for the unselected/unhighlighted option
|
||||
unselectedStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.VeryMuted)
|
||||
|
||||
// Build the view
|
||||
var view strings.Builder
|
||||
view.WriteString(titleStyle.Render("Allow tool execution"))
|
||||
view.WriteString("\n")
|
||||
details := fmt.Sprintf("Tool: %s\nArguments: %s\n\n", t.toolName, t.toolArgs)
|
||||
view.WriteString(details)
|
||||
view.WriteString("Allow tool execution: ")
|
||||
|
||||
var yesText, noText string
|
||||
if t.selected {
|
||||
yesText = selectedStyle.Render("[y]es")
|
||||
noText = unselectedStyle.Render("[n]o")
|
||||
} else {
|
||||
yesText = unselectedStyle.Render("[y]es")
|
||||
noText = selectedStyle.Render("[n]o")
|
||||
}
|
||||
view.WriteString(yesText + "/" + noText + "\n")
|
||||
|
||||
return tea.NewView(containerStyle.Render(inputBoxStyle.Render(view.String())))
|
||||
}
|
||||
+17
-2
@@ -243,7 +243,7 @@ host.ClearSession()
|
||||
|
||||
## Re-exported Types
|
||||
|
||||
The SDK re-exports types so you don't need direct internal imports:
|
||||
The SDK re-exports message/session/MCP types so you don't need direct internal imports. Agent-configuration types are Kit-owned (not aliases) and use only SDK types in their signatures, so consumers never need to import the underlying LLM-provider package.
|
||||
|
||||
```go
|
||||
// Message types
|
||||
@@ -251,13 +251,28 @@ kit.Message, kit.MessageRole, kit.ContentPart
|
||||
kit.TextContent, kit.ReasoningContent, kit.ToolCall, kit.ToolResult, kit.Finish
|
||||
kit.RoleUser, kit.RoleAssistant, kit.RoleTool, kit.RoleSystem
|
||||
|
||||
// LLM types — concrete Kit-owned structs, no external library dependency
|
||||
// LLM types — Kit-owned `LLM*` aliases over the underlying provider types,
|
||||
// so consumers never import the provider package directly
|
||||
kit.LLMMessage // {Role LLMMessageRole, Content string}
|
||||
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ...}
|
||||
kit.LLMResponse // {Content, FinishReason, Usage}
|
||||
kit.LLMFilePart // {Filename, Data []byte, MediaType}
|
||||
|
||||
// Agent configuration — concrete Kit-owned structs and function types.
|
||||
// All fields use SDK types (e.g. `[]kit.Tool`), so consumers can construct
|
||||
// these without importing any LLM-provider package.
|
||||
kit.AgentConfig // Lower-level agent config — prefer Options unless you need direct control
|
||||
kit.DebugLogger // Interface: LogDebug(string) / IsDebugEnabled() bool
|
||||
kit.MCPTaskConfig // Task-aware MCP tools/call config (modes, polling, progress)
|
||||
kit.ToolCallHandler // func(toolCallID, toolName, toolArgs string)
|
||||
kit.ToolExecutionHandler // func(toolCallID, toolName, toolArgs string, isStarting bool)
|
||||
kit.ToolResultHandler // func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
|
||||
kit.ResponseHandler // func(content string)
|
||||
kit.StreamingResponseHandler // func(content string)
|
||||
kit.ToolCallContentHandler // func(content string)
|
||||
kit.SpinnerFunc // func(fn func() error) error
|
||||
|
||||
// MCP OAuth types
|
||||
kit.MCPServer // *server.MCPServer for in-process MCP transport
|
||||
kit.MCPServerConfig // Configuration for an MCP server (stdio, SSE, or in-process)
|
||||
|
||||
@@ -0,0 +1,208 @@
|
||||
package kit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/agent"
|
||||
)
|
||||
|
||||
// TestAgentConfigToInternal verifies that the SDK-side AgentConfig converts
|
||||
// faithfully to the internal agent.AgentConfig representation, preserving
|
||||
// every field consumed by the internal agent layer.
|
||||
//
|
||||
// Regression test for https://github.com/mark3labs/kit/issues/30.
|
||||
func TestAgentConfigToInternal(t *testing.T) {
|
||||
t.Run("nil receiver returns nil", func(t *testing.T) {
|
||||
var c *AgentConfig
|
||||
if got := c.toInternal(); got != nil {
|
||||
t.Errorf("nil.toInternal() = %v, want nil", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("scalar fields round-trip", func(t *testing.T) {
|
||||
c := &AgentConfig{
|
||||
SystemPrompt: "sys",
|
||||
MaxSteps: 7,
|
||||
StreamingEnabled: true,
|
||||
DisableCoreTools: true,
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got == nil {
|
||||
t.Fatal("toInternal() = nil")
|
||||
}
|
||||
if got.SystemPrompt != "sys" {
|
||||
t.Errorf("SystemPrompt = %q, want %q", got.SystemPrompt, "sys")
|
||||
}
|
||||
if got.MaxSteps != 7 {
|
||||
t.Errorf("MaxSteps = %d, want 7", got.MaxSteps)
|
||||
}
|
||||
if !got.StreamingEnabled {
|
||||
t.Error("StreamingEnabled = false, want true")
|
||||
}
|
||||
if !got.DisableCoreTools {
|
||||
t.Error("DisableCoreTools = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool slices propagate without conversion", func(t *testing.T) {
|
||||
// Tool is a type alias for the underlying LLM-tool type, so the
|
||||
// SDK []Tool and internal []fantasy.AgentTool slices share the
|
||||
// same backing array after conversion.
|
||||
tool := NewTool[struct{}]("noop", "noop", nil)
|
||||
c := &AgentConfig{
|
||||
CoreTools: []Tool{tool},
|
||||
ExtraTools: []Tool{tool, tool},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if len(got.CoreTools) != 1 {
|
||||
t.Errorf("CoreTools len = %d, want 1", len(got.CoreTools))
|
||||
}
|
||||
if len(got.ExtraTools) != 2 {
|
||||
t.Errorf("ExtraTools len = %d, want 2", len(got.ExtraTools))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool wrapper is invoked through internal config", func(t *testing.T) {
|
||||
called := false
|
||||
c := &AgentConfig{
|
||||
ToolWrapper: func(in []Tool) []Tool {
|
||||
called = true
|
||||
return in
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.ToolWrapper == nil {
|
||||
t.Fatal("internal ToolWrapper is nil")
|
||||
}
|
||||
_ = got.ToolWrapper(nil)
|
||||
if !called {
|
||||
t.Error("SDK ToolWrapper was not invoked through the internal config")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OnMCPServerLoaded propagates", func(t *testing.T) {
|
||||
var captured string
|
||||
wantErr := errors.New("boom")
|
||||
c := &AgentConfig{
|
||||
OnMCPServerLoaded: func(name string, _ int, _ error) {
|
||||
captured = name
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
got.OnMCPServerLoaded("svr", 3, wantErr)
|
||||
if captured != "svr" {
|
||||
t.Errorf("OnMCPServerLoaded captured = %q, want %q", captured, "svr")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DebugLogger propagates", func(t *testing.T) {
|
||||
dl := &fakeDebugLogger{enabled: true}
|
||||
c := &AgentConfig{DebugLogger: dl}
|
||||
got := c.toInternal()
|
||||
if got.DebugLogger == nil {
|
||||
t.Fatal("internal DebugLogger is nil")
|
||||
}
|
||||
if !got.DebugLogger.IsDebugEnabled() {
|
||||
t.Error("IsDebugEnabled = false, want true")
|
||||
}
|
||||
got.DebugLogger.LogDebug("hello")
|
||||
if len(dl.messages) != 1 || dl.messages[0] != "hello" {
|
||||
t.Errorf("messages = %v, want [hello]", dl.messages)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MCPTaskConfig propagates with mode + progress", func(t *testing.T) {
|
||||
c := &AgentConfig{
|
||||
MCPTaskConfig: MCPTaskConfig{
|
||||
PerServerMode: map[string]MCPTaskMode{
|
||||
"build-svr": MCPTaskModeAlways,
|
||||
},
|
||||
DefaultTTL: 30 * time.Second,
|
||||
PollInterval: 250 * time.Millisecond,
|
||||
MaxPollInterval: 2 * time.Second,
|
||||
Timeout: 5 * time.Minute,
|
||||
Progress: func(_ MCPTaskProgress) {},
|
||||
},
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.MCPTaskConfig.DefaultTTL != 30*time.Second {
|
||||
t.Errorf("DefaultTTL = %v, want 30s", got.MCPTaskConfig.DefaultTTL)
|
||||
}
|
||||
if got.MCPTaskConfig.PollInterval != 250*time.Millisecond {
|
||||
t.Errorf("PollInterval = %v, want 250ms", got.MCPTaskConfig.PollInterval)
|
||||
}
|
||||
if got.MCPTaskConfig.MaxPollInterval != 2*time.Second {
|
||||
t.Errorf("MaxPollInterval = %v, want 2s", got.MCPTaskConfig.MaxPollInterval)
|
||||
}
|
||||
if got.MCPTaskConfig.Timeout != 5*time.Minute {
|
||||
t.Errorf("Timeout = %v, want 5m", got.MCPTaskConfig.Timeout)
|
||||
}
|
||||
mode, ok := got.MCPTaskConfig.PerServerMode["build-svr"]
|
||||
if !ok {
|
||||
t.Fatal("PerServerMode missing 'build-svr'")
|
||||
}
|
||||
if string(mode) != string(MCPTaskModeAlways) {
|
||||
t.Errorf("mode = %q, want %q", mode, MCPTaskModeAlways)
|
||||
}
|
||||
if got.MCPTaskConfig.Progress == nil {
|
||||
t.Fatal("internal Progress handler is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("auth and token store factories are wired", func(t *testing.T) {
|
||||
auth := &fakeAuthHandler{}
|
||||
tokenCalls := 0
|
||||
var tokenServer string
|
||||
factory := MCPTokenStoreFactory(func(server string) (MCPTokenStore, error) {
|
||||
tokenCalls++
|
||||
tokenServer = server
|
||||
return nil, nil
|
||||
})
|
||||
c := &AgentConfig{
|
||||
AuthHandler: auth,
|
||||
TokenStoreFactory: factory,
|
||||
}
|
||||
got := c.toInternal()
|
||||
if got.AuthHandler == nil {
|
||||
t.Fatal("internal AuthHandler is nil")
|
||||
}
|
||||
if got.TokenStoreFactory == nil {
|
||||
t.Fatal("internal TokenStoreFactory is nil")
|
||||
}
|
||||
_, _ = got.TokenStoreFactory("https://example.test")
|
||||
if tokenCalls != 1 {
|
||||
t.Errorf("token factory call count = %d, want 1", tokenCalls)
|
||||
}
|
||||
if tokenServer != "https://example.test" {
|
||||
t.Errorf("token factory server arg = %q", tokenServer)
|
||||
}
|
||||
if got.AuthHandler.RedirectURI() != "redirect" {
|
||||
t.Errorf("RedirectURI = %q, want %q", got.AuthHandler.RedirectURI(), "redirect")
|
||||
}
|
||||
})
|
||||
|
||||
// Compile-time check that the internal type is what we expect.
|
||||
//nolint:staticcheck // QF1011: explicit type asserts the conversion target.
|
||||
var _ *agent.AgentConfig = (&AgentConfig{}).toInternal()
|
||||
}
|
||||
|
||||
// fakeAuthHandler implements both kit.MCPAuthHandler and the structurally
|
||||
// identical tools.MCPAuthHandler used by the internal layer.
|
||||
type fakeAuthHandler struct{}
|
||||
|
||||
func (f *fakeAuthHandler) RedirectURI() string { return "redirect" }
|
||||
func (f *fakeAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// fakeDebugLogger implements kit.DebugLogger for tests.
|
||||
type fakeDebugLogger struct {
|
||||
enabled bool
|
||||
messages []string
|
||||
}
|
||||
|
||||
func (f *fakeDebugLogger) LogDebug(m string) { f.messages = append(f.messages, m) }
|
||||
func (f *fakeDebugLogger) IsDebugEnabled() bool { return f.enabled }
|
||||
+3
-3
@@ -148,9 +148,9 @@ func parseToolArgs(toolArgs string) map[string]any {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Finish reasons reported by the LLM provider on a completed turn. These
|
||||
// mirror fantasy.FinishReason string values so comparisons against
|
||||
// TurnEndEvent.StopReason / TurnResult.StopReason are stable across
|
||||
// providers.
|
||||
// mirror the underlying provider's finish reason string values so
|
||||
// comparisons against TurnEndEvent.StopReason / TurnResult.StopReason are
|
||||
// stable across providers.
|
||||
const (
|
||||
// FinishReasonStop: the model produced a natural stop (e.g. stop sequence
|
||||
// or end-of-turn signal).
|
||||
|
||||
@@ -76,6 +76,22 @@ type ExtensionAPI interface {
|
||||
// Lifecycle
|
||||
Reload() error
|
||||
HasExtensions() bool
|
||||
|
||||
// Loaded returns metadata about the extensions currently loaded.
|
||||
Loaded() []ExtensionInfo
|
||||
}
|
||||
|
||||
// ExtensionInfo describes a single loaded extension for display purposes
|
||||
// (e.g. the startup banner or `kit extensions list`).
|
||||
type ExtensionInfo struct {
|
||||
// Path is the absolute path of the extension's .go file.
|
||||
Path string
|
||||
// ToolCount is the number of tools registered by the extension.
|
||||
ToolCount int
|
||||
// CommandCount is the number of slash commands registered.
|
||||
CommandCount int
|
||||
// HandlerCount is the total number of event handlers registered.
|
||||
HandlerCount int
|
||||
}
|
||||
|
||||
// extensionAPI implements ExtensionAPI by wrapping a Kit instance.
|
||||
@@ -456,3 +472,27 @@ func (e *extensionAPI) Reload() error {
|
||||
func (e *extensionAPI) HasExtensions() bool {
|
||||
return e.kit.extRunner != nil
|
||||
}
|
||||
|
||||
func (e *extensionAPI) Loaded() []ExtensionInfo {
|
||||
if e.kit.extRunner == nil {
|
||||
return nil
|
||||
}
|
||||
exts := e.kit.extRunner.Extensions()
|
||||
if len(exts) == 0 {
|
||||
return nil
|
||||
}
|
||||
infos := make([]ExtensionInfo, 0, len(exts))
|
||||
for _, ex := range exts {
|
||||
handlerCount := 0
|
||||
for _, hs := range ex.Handlers {
|
||||
handlerCount += len(hs)
|
||||
}
|
||||
infos = append(infos, ExtensionInfo{
|
||||
Path: ex.Path,
|
||||
ToolCount: len(ex.Tools),
|
||||
CommandCount: len(ex.Commands),
|
||||
HandlerCount: handlerCount,
|
||||
})
|
||||
}
|
||||
return infos
|
||||
}
|
||||
|
||||
+36
-2
@@ -58,6 +58,9 @@ type Kit struct {
|
||||
// When false, per-model system prompts from modelSettings/customModels
|
||||
// can replace the default prompt on model switch.
|
||||
hasCustomSystemPrompt bool
|
||||
// systemPromptSource holds the raw configured value (file path or text)
|
||||
// when hasCustomSystemPrompt is true; empty when the built-in default is in use.
|
||||
systemPromptSource string
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
@@ -632,6 +635,21 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasCustomSystemPrompt reports whether the user explicitly configured a system
|
||||
// prompt via --system-prompt, a config file entry, or SDK Options.SystemPrompt.
|
||||
// When false, the built-in default (or a per-model override) is in use and can
|
||||
// be replaced transparently on model switch.
|
||||
func (m *Kit) HasCustomSystemPrompt() bool {
|
||||
return m.hasCustomSystemPrompt
|
||||
}
|
||||
|
||||
// GetSystemPromptSource returns the raw configured value — a file path or
|
||||
// inline text — when HasCustomSystemPrompt is true; returns an empty string
|
||||
// when the built-in default prompt is active.
|
||||
func (m *Kit) GetSystemPromptSource() string {
|
||||
return m.systemPromptSource
|
||||
}
|
||||
|
||||
// composeSystemPrompt takes a base system prompt and composes it with the
|
||||
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
|
||||
// This mirrors the composition done during Kit.New() initialization.
|
||||
@@ -1179,6 +1197,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
maxSteps int
|
||||
streaming bool
|
||||
hasCustomSystemPrompt bool
|
||||
systemPromptSource string
|
||||
)
|
||||
|
||||
if err := func() error {
|
||||
@@ -1285,13 +1304,27 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// explicitly set system-prompt, use the per-model prompt as the
|
||||
// base instead of the global default.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
rawPromptInput := viper.GetString("system-prompt")
|
||||
|
||||
// Resolve a file path to its content so PromptBuilder receives the
|
||||
// actual prompt text rather than a literal path string. Without this,
|
||||
// when system-prompt is set to a file path in the config file or via
|
||||
// --system-prompt, the path itself becomes the effective system prompt
|
||||
// sent to the model (LoadSystemPrompt only ran later, after viper had
|
||||
// been overwritten with the augmented base text).
|
||||
basePrompt, _ := config.LoadSystemPrompt(rawPromptInput)
|
||||
if basePrompt == "" {
|
||||
basePrompt = rawPromptInput
|
||||
}
|
||||
|
||||
// Track whether the user explicitly configured a custom system
|
||||
// prompt. When they haven't (basePrompt is the built-in default
|
||||
// or empty), per-model system prompts can replace it on switch.
|
||||
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
|
||||
hasCustomSystemPrompt = userSetSystemPrompt
|
||||
if hasCustomSystemPrompt {
|
||||
systemPromptSource = rawPromptInput
|
||||
}
|
||||
|
||||
// Check for per-model system prompt override when no explicit
|
||||
// global system-prompt was configured by the user.
|
||||
@@ -1456,7 +1489,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
setupOpts.SpinnerFunc = agent.SpinnerFunc(opts.CLI.SpinnerFunc)
|
||||
setupOpts.UseBufferedLogger = opts.CLI.UseBufferedLogger
|
||||
if opts.CLI.ProgressReaderFunc != nil {
|
||||
providerConfig.ProgressReaderFunc = opts.CLI.ProgressReaderFunc
|
||||
@@ -1500,6 +1533,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
opts: opts,
|
||||
mcpConfig: mcpConfig,
|
||||
hasCustomSystemPrompt: hasCustomSystemPrompt,
|
||||
systemPromptSource: systemPromptSource,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
|
||||
@@ -3,6 +3,7 @@ package kit_test
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
@@ -306,3 +307,92 @@ func TestSessionManagement(t *testing.T) {
|
||||
// resetViper wipes viper's global state so a test case doesn't leak
|
||||
// viper.Set() calls into the next one. Used via defer in subtests.
|
||||
func resetViper() { viper.Reset() }
|
||||
|
||||
// TestNewSystemPromptFilePath is a regression test for issue #25.
|
||||
//
|
||||
// When Options.SystemPrompt (or the --system-prompt flag / config entry) is a
|
||||
// file path, Kit must resolve the path to its file contents *before* the
|
||||
// PromptBuilder composes the runtime context. Previously the path string
|
||||
// itself was used verbatim as the base prompt, so the LLM received the path —
|
||||
// not the prompt — as its system message.
|
||||
func TestNewSystemPromptFilePath(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
defer resetViper()
|
||||
|
||||
const promptContent = "You are a strict regression-test persona. Marker: KIT-25-OK"
|
||||
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "kit-system-prompt-*.md")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temp prompt file: %v", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(promptContent); err != nil {
|
||||
t.Fatalf("failed to write temp prompt file: %v", err)
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
t.Fatalf("failed to close temp prompt file: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Kit with system-prompt file: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if !host.HasCustomSystemPrompt() {
|
||||
t.Error("HasCustomSystemPrompt() = false; want true when --system-prompt is set")
|
||||
}
|
||||
if got, want := host.GetSystemPromptSource(), tmpFile.Name(); got != want {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, want)
|
||||
}
|
||||
|
||||
// The composed system prompt is written back to viper after PromptBuilder
|
||||
// runs. It must contain the file's contents, not the file path.
|
||||
composed := viper.GetString("system-prompt")
|
||||
if !strings.Contains(composed, promptContent) {
|
||||
t.Errorf("composed system-prompt does not contain file contents\n composed = %q\n want substring = %q", composed, promptContent)
|
||||
}
|
||||
if strings.TrimSpace(composed) == tmpFile.Name() {
|
||||
t.Errorf("composed system-prompt is the file path verbatim (%q); LoadSystemPrompt was not applied before PromptBuilder", composed)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewSystemPromptInline confirms that inline system-prompt strings still
|
||||
// flow through unchanged after the file-path resolution change.
|
||||
func TestNewSystemPromptInline(t *testing.T) {
|
||||
if os.Getenv("ANTHROPIC_API_KEY") == "" {
|
||||
t.Skip("Skipping test: ANTHROPIC_API_KEY not set")
|
||||
}
|
||||
defer resetViper()
|
||||
|
||||
const inline = "You are a concise inline-prompt persona."
|
||||
|
||||
ctx := context.Background()
|
||||
host, err := kit.New(ctx, &kit.Options{
|
||||
Model: "anthropic/claude-sonnet-4-5-20250929",
|
||||
SystemPrompt: inline,
|
||||
Quiet: true,
|
||||
NoSession: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Kit with inline system-prompt: %v", err)
|
||||
}
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
if !host.HasCustomSystemPrompt() {
|
||||
t.Error("HasCustomSystemPrompt() = false; want true for inline prompt")
|
||||
}
|
||||
if got := host.GetSystemPromptSource(); got != inline {
|
||||
t.Errorf("GetSystemPromptSource() = %q; want %q", got, inline)
|
||||
}
|
||||
if composed := viper.GetString("system-prompt"); !strings.Contains(composed, inline) {
|
||||
t.Errorf("composed system-prompt missing inline content; got %q", composed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,6 +98,70 @@ type MCPTaskProgress struct {
|
||||
// dispatched on a goroutine.
|
||||
type MCPTaskProgressHandler func(MCPTaskProgress)
|
||||
|
||||
// MCPTaskConfig configures task-aware MCP tools/call execution. All fields
|
||||
// are optional; the zero value disables progress callbacks and applies
|
||||
// sensible polling defaults inside the engine.
|
||||
//
|
||||
// For most consumers, the flat [Options] fields (`MCPTaskMode`,
|
||||
// `MCPTaskTTL`, `MCPTaskPollInterval`, `MCPTaskMaxPollInterval`,
|
||||
// `MCPTaskTimeout`, `MCPTaskProgress`) are the preferred entry point.
|
||||
// MCPTaskConfig is exposed for the low-level [AgentConfig] path.
|
||||
type MCPTaskConfig struct {
|
||||
// PerServerMode overrides the per-server task mode resolved from
|
||||
// [MCPServerConfig]. Keys are server names. Missing entries fall back
|
||||
// to the configured value.
|
||||
PerServerMode map[string]MCPTaskMode
|
||||
|
||||
// DefaultTTL is the TTL hint sent in TaskParams when augmenting a
|
||||
// tools/call. Zero means omit the TTL — let the server pick its own.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// PollInterval is the fallback interval between tasks/get requests
|
||||
// when the server does not suggest one. Zero defaults to 1 second.
|
||||
PollInterval time.Duration
|
||||
|
||||
// MaxPollInterval caps the polling interval. Zero defaults to 5 seconds.
|
||||
MaxPollInterval time.Duration
|
||||
|
||||
// Timeout is the maximum wall-clock duration to wait for a task to
|
||||
// reach a terminal state. Zero defaults to 15 minutes. Independent
|
||||
// of the per-call context deadline; whichever fires first wins.
|
||||
Timeout time.Duration
|
||||
|
||||
// Progress, if non-nil, receives every status transition observed by
|
||||
// the polling loop.
|
||||
Progress MCPTaskProgressHandler
|
||||
}
|
||||
|
||||
// toToolsConfig converts the SDK-level [MCPTaskConfig] to the internal
|
||||
// tools-package representation. Keeps the dependency arrow internal-only.
|
||||
func (c MCPTaskConfig) toToolsConfig() tools.MCPTaskConfig {
|
||||
cfg := tools.MCPTaskConfig{
|
||||
DefaultTTL: c.DefaultTTL,
|
||||
PollInterval: c.PollInterval,
|
||||
MaxPollInterval: c.MaxPollInterval,
|
||||
Timeout: c.Timeout,
|
||||
}
|
||||
if len(c.PerServerMode) > 0 {
|
||||
cfg.PerServerMode = make(map[string]tools.MCPTaskMode, len(c.PerServerMode))
|
||||
for k, v := range c.PerServerMode {
|
||||
cfg.PerServerMode[k] = tools.MCPTaskMode(v)
|
||||
}
|
||||
}
|
||||
if c.Progress != nil {
|
||||
h := c.Progress
|
||||
cfg.Progress = func(p tools.MCPTaskProgress) {
|
||||
h(MCPTaskProgress{
|
||||
Server: p.Server,
|
||||
TaskID: p.TaskID,
|
||||
Status: MCPTaskStatus(p.Status),
|
||||
Message: p.Message,
|
||||
})
|
||||
}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// mcpTaskOptions carries SDK consumer configuration into the agent setup.
|
||||
// Stored on Options as a single value so the public surface stays compact;
|
||||
// individual fields are exposed via WithMCP* builder functions.
|
||||
|
||||
+145
-18
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
@@ -75,25 +76,151 @@ type Config = config.Config
|
||||
// local (stdio) and remote (StreamableHTTP/SSE) server types.
|
||||
type MCPServerConfig = config.MCPServerConfig
|
||||
|
||||
// ==== Agent Types (internal/agent/) ====
|
||||
// ==== Agent Types ====
|
||||
|
||||
// AgentConfig holds configuration options for creating a new Agent.
|
||||
type AgentConfig = agent.AgentConfig
|
||||
// DebugLogger is an SDK-owned interface for low-level debug logging from
|
||||
// the engine and MCP tool plumbing. Implementations must be safe for
|
||||
// concurrent use.
|
||||
//
|
||||
// Most consumers do not need to provide one; pass [Options.Debug] = true
|
||||
// to use the default logger. DebugLogger is exposed for the low-level
|
||||
// [AgentConfig] path and for embedders that want to route debug output
|
||||
// into their own logging system.
|
||||
type DebugLogger interface {
|
||||
// LogDebug records a single debug message. Implementations may drop,
|
||||
// buffer, or render the message however they choose.
|
||||
LogDebug(message string)
|
||||
// IsDebugEnabled reports whether debug logging is active. Callers may
|
||||
// check this before doing expensive formatting work.
|
||||
IsDebugEnabled() bool
|
||||
}
|
||||
|
||||
type (
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
ToolCallHandler = agent.ToolCallHandler
|
||||
// ToolExecutionHandler is a function type for handling tool execution start/end events.
|
||||
ToolExecutionHandler = agent.ToolExecutionHandler
|
||||
// ToolResultHandler is a function type for handling tool results.
|
||||
ToolResultHandler = agent.ToolResultHandler
|
||||
// ResponseHandler is a function type for handling LLM responses.
|
||||
ResponseHandler = agent.ResponseHandler
|
||||
// StreamingResponseHandler is a function type for handling streaming LLM responses.
|
||||
StreamingResponseHandler = agent.StreamingResponseHandler
|
||||
// ToolCallContentHandler is a function type for handling content that accompanies tool calls.
|
||||
ToolCallContentHandler = agent.ToolCallContentHandler
|
||||
)
|
||||
// AgentConfig holds configuration options for constructing an agent at the
|
||||
// SDK boundary. All fields use SDK-owned types, so consumers can populate
|
||||
// this struct without importing any underlying LLM-provider package.
|
||||
//
|
||||
// For most use cases, prefer the high-level [New] entry point with
|
||||
// [Options]. AgentConfig is exposed for advanced consumers that need
|
||||
// direct access to the lower-level agent configuration shape.
|
||||
type AgentConfig struct {
|
||||
// ModelConfig holds the LLM provider configuration. A nil value means
|
||||
// that the default provider/model resolution will be used.
|
||||
ModelConfig *ProviderConfig
|
||||
|
||||
// MCPConfig describes any MCP servers whose tools should be loaded
|
||||
// alongside core tools.
|
||||
MCPConfig *Config
|
||||
|
||||
// SystemPrompt is the system prompt sent to the LLM.
|
||||
SystemPrompt string
|
||||
|
||||
// MaxSteps caps the number of LLM iterations per turn. A value of
|
||||
// zero means no cap is applied at this layer.
|
||||
MaxSteps int
|
||||
|
||||
// StreamingEnabled controls whether the agent streams responses.
|
||||
StreamingEnabled bool
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When nil, remote MCP servers requiring OAuth will fail to connect.
|
||||
AuthHandler MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory MCPTokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, [AllTools]
|
||||
// is used. Provide a custom tool set (e.g. [CodingTools] or tools
|
||||
// built with a custom WorkDir) to scope agent capabilities.
|
||||
CoreTools []Tool
|
||||
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// Combined with empty CoreTools this yields a chat-only agent with
|
||||
// no built-in tools.
|
||||
DisableCoreTools bool
|
||||
|
||||
// ExtraTools are additional tools loaded alongside core and MCP tools.
|
||||
ExtraTools []Tool
|
||||
|
||||
// ToolWrapper, if non-nil, wraps the combined tool list before it is
|
||||
// handed to the LLM. Used to intercept tool calls or results.
|
||||
ToolWrapper func([]Tool) []Tool
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is invoked once for each MCP server
|
||||
// when its tools have finished loading (or failed). Called from a
|
||||
// background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// DebugLogger receives low-level debug output from the engine and the
|
||||
// MCP tool plumbing. Nil means no debug output is emitted at this
|
||||
// layer (regardless of [Options.Debug], which feeds the higher-level
|
||||
// [New] entry point). Pass an implementation here when wiring a custom
|
||||
// logger through the lower-level AgentConfig path.
|
||||
DebugLogger DebugLogger
|
||||
|
||||
// MCPTaskConfig configures task-aware MCP tools/call execution — mode
|
||||
// overrides, polling intervals, timeouts, and the progress handler.
|
||||
// The zero value preserves historical synchronous-only behaviour for
|
||||
// any server that didn't advertise task support during initialize.
|
||||
MCPTaskConfig MCPTaskConfig
|
||||
}
|
||||
|
||||
// toInternal converts an AgentConfig to its internal representation.
|
||||
// Slice and function fields convert without allocation because [Tool]
|
||||
// is a type alias for the underlying LLM-tool type.
|
||||
func (c *AgentConfig) toInternal() *agent.AgentConfig {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
out := &agent.AgentConfig{
|
||||
ModelConfig: c.ModelConfig,
|
||||
MCPConfig: c.MCPConfig,
|
||||
SystemPrompt: c.SystemPrompt,
|
||||
MaxSteps: c.MaxSteps,
|
||||
StreamingEnabled: c.StreamingEnabled,
|
||||
CoreTools: c.CoreTools,
|
||||
DisableCoreTools: c.DisableCoreTools,
|
||||
ExtraTools: c.ExtraTools,
|
||||
ToolWrapper: c.ToolWrapper,
|
||||
OnMCPServerLoaded: c.OnMCPServerLoaded,
|
||||
}
|
||||
if c.AuthHandler != nil {
|
||||
out.AuthHandler = c.AuthHandler
|
||||
}
|
||||
if c.TokenStoreFactory != nil {
|
||||
out.TokenStoreFactory = tools.TokenStoreFactory(c.TokenStoreFactory)
|
||||
}
|
||||
if c.DebugLogger != nil {
|
||||
out.DebugLogger = c.DebugLogger
|
||||
}
|
||||
out.MCPTaskConfig = c.MCPTaskConfig.toToolsConfig()
|
||||
return out
|
||||
}
|
||||
|
||||
// ToolCallHandler is invoked when the LLM produces a tool call. It receives
|
||||
// the call ID, tool name, and the JSON-encoded input arguments.
|
||||
type ToolCallHandler func(toolCallID, toolName, toolArgs string)
|
||||
|
||||
// ToolExecutionHandler is invoked at the start and end of tool execution.
|
||||
// The isStarting flag distinguishes the two phases.
|
||||
type ToolExecutionHandler func(toolCallID, toolName, toolArgs string, isStarting bool)
|
||||
|
||||
// ToolResultHandler is invoked after a tool finishes executing. The metadata
|
||||
// parameter carries optional structured data (e.g. file-diff info) from the
|
||||
// tool execution, JSON-encoded; it may be empty.
|
||||
type ToolResultHandler func(toolCallID, toolName, toolArgs, result, metadata string, isError bool)
|
||||
|
||||
// ResponseHandler is invoked with the final assistant text for each turn.
|
||||
type ResponseHandler func(content string)
|
||||
|
||||
// StreamingResponseHandler is invoked with each streamed text delta as it
|
||||
// arrives from the LLM.
|
||||
type StreamingResponseHandler func(content string)
|
||||
|
||||
// ToolCallContentHandler is invoked with any assistant text that accompanies
|
||||
// a tool call within the same step.
|
||||
type ToolCallContentHandler func(content string)
|
||||
|
||||
// ==== Provider & Model Types (internal/models/) ====
|
||||
|
||||
@@ -126,7 +253,7 @@ type ModelsRegistry = models.ModelsRegistry
|
||||
|
||||
// SpinnerFunc wraps a function in a loading spinner animation. Used for
|
||||
// Ollama model loading. Signature: func(fn func() error) error.
|
||||
type SpinnerFunc = agent.SpinnerFunc
|
||||
type SpinnerFunc func(fn func() error) error
|
||||
|
||||
// ==== LLM Types ====
|
||||
//
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
@@ -263,6 +264,101 @@ func TestConvertFromLLMMessage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentConfigNoFantasyImport verifies AgentConfig can be populated with
|
||||
// every field — including CoreTools, ExtraTools, and ToolWrapper — using
|
||||
// only SDK-owned types. This test deliberately does not import
|
||||
// "charm.land/fantasy"; the package compiling at all is the proof that the
|
||||
// SDK no longer leaks the dependency name through AgentConfig.
|
||||
//
|
||||
// Regression test for https://github.com/mark3labs/kit/issues/30.
|
||||
func TestAgentConfigNoFantasyImport(t *testing.T) {
|
||||
myTool := kit.NewTool[struct{}]("noop", "does nothing", func(_ context.Context, _ struct{}) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("ok"), nil
|
||||
})
|
||||
|
||||
wrapperCalled := false
|
||||
cfg := kit.AgentConfig{
|
||||
SystemPrompt: "you are a tester",
|
||||
MaxSteps: 5,
|
||||
StreamingEnabled: true,
|
||||
CoreTools: []kit.Tool{myTool},
|
||||
ExtraTools: []kit.Tool{myTool},
|
||||
DisableCoreTools: false,
|
||||
ToolWrapper: func(in []kit.Tool) []kit.Tool {
|
||||
wrapperCalled = true
|
||||
return in
|
||||
},
|
||||
OnMCPServerLoaded: func(_ string, _ int, _ error) {},
|
||||
}
|
||||
|
||||
if cfg.SystemPrompt != "you are a tester" {
|
||||
t.Errorf("SystemPrompt = %q, want %q", cfg.SystemPrompt, "you are a tester")
|
||||
}
|
||||
if cfg.MaxSteps != 5 {
|
||||
t.Errorf("MaxSteps = %d, want 5", cfg.MaxSteps)
|
||||
}
|
||||
if !cfg.StreamingEnabled {
|
||||
t.Error("StreamingEnabled = false, want true")
|
||||
}
|
||||
if len(cfg.CoreTools) != 1 {
|
||||
t.Errorf("CoreTools len = %d, want 1", len(cfg.CoreTools))
|
||||
}
|
||||
if len(cfg.ExtraTools) != 1 {
|
||||
t.Errorf("ExtraTools len = %d, want 1", len(cfg.ExtraTools))
|
||||
}
|
||||
|
||||
// Exercise the wrapper to confirm the func type is usable.
|
||||
out := cfg.ToolWrapper(cfg.CoreTools)
|
||||
if !wrapperCalled {
|
||||
t.Error("ToolWrapper was not invoked")
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Errorf("wrapped tool list len = %d, want 1", len(out))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentConfigToolWrapperSignature documents that AgentConfig.ToolWrapper
|
||||
// uses kit.Tool (not the underlying provider type) in its signature.
|
||||
func TestAgentConfigToolWrapperSignature(t *testing.T) {
|
||||
//nolint:staticcheck // QF1011: explicit type asserts the SDK-side func signature.
|
||||
var _ func([]kit.Tool) []kit.Tool = func(in []kit.Tool) []kit.Tool { return in }
|
||||
cfg := kit.AgentConfig{
|
||||
ToolWrapper: func(in []kit.Tool) []kit.Tool { return in },
|
||||
}
|
||||
if cfg.ToolWrapper == nil {
|
||||
t.Fatal("ToolWrapper assignment failed")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpinnerFuncSignature verifies SpinnerFunc has the documented signature
|
||||
// and can be constructed without importing any provider package.
|
||||
func TestSpinnerFuncSignature(t *testing.T) {
|
||||
called := false
|
||||
var sp kit.SpinnerFunc = func(fn func() error) error {
|
||||
called = true
|
||||
return fn()
|
||||
}
|
||||
err := sp(func() error { return nil })
|
||||
if err != nil {
|
||||
t.Errorf("SpinnerFunc returned err: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("SpinnerFunc did not invoke fn")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerTypesSignatures verifies the SDK-owned handler function types
|
||||
// can be assigned from plain function literals using only standard library
|
||||
// types in their signatures (no provider-package import required).
|
||||
func TestHandlerTypesSignatures(t *testing.T) {
|
||||
var _ kit.ToolCallHandler = func(_, _, _ string) {}
|
||||
var _ kit.ToolExecutionHandler = func(_, _, _ string, _ bool) {}
|
||||
var _ kit.ToolResultHandler = func(_, _, _, _, _ string, _ bool) {}
|
||||
var _ kit.ResponseHandler = func(_ string) {}
|
||||
var _ kit.StreamingResponseHandler = func(_ string) {}
|
||||
var _ kit.ToolCallContentHandler = func(_ string) {}
|
||||
}
|
||||
|
||||
// containsStr is a tiny helper to avoid importing strings in test.
|
||||
func containsStr(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && indexStr(s, substr) >= 0)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# Specs
|
||||
|
||||
| Spec | Status | Description |
|
||||
|------|--------|-------------|
|
||||
| [unified-bubbletea-architecture](unified-bubbletea-architecture.md) | Draft | Replace micro-program pattern with single Bubble Tea program + thick app layer |
|
||||
Reference in New Issue
Block a user