mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 19:50:13 +00:00
Compare commits
62 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7f366eab84 | |||
| e8e99b19a8 | |||
| ef072f6e59 | |||
| 49f8b485be | |||
| febdc530e1 | |||
| e610bdd2d0 | |||
| 6100e8b3a8 | |||
| 9f125f3400 | |||
| 00eab47218 | |||
| 06bf6d087a | |||
| fd960921ca | |||
| 0b651a8df9 | |||
| 7315c1dea7 | |||
| 0313fa03ad | |||
| d27022bcfb | |||
| ae722d520f | |||
| 7a04bdfeba | |||
| 7e4708f511 | |||
| 1e12102b92 | |||
| ab2a77c95e | |||
| 1e78153b50 | |||
| a613361969 | |||
| 67722b0c24 | |||
| 1a2f6da40f | |||
| 747f5be099 | |||
| d7c4565999 | |||
| bd24f3315c | |||
| 592f8dc84f | |||
| 66c4a1eb15 | |||
| 5104477631 | |||
| 394a4676a1 | |||
| 30f2bc243d | |||
| 922e246098 | |||
| 32b6376515 | |||
| cf194ff89a | |||
| 03006425fa | |||
| a322dfc59a | |||
| b1387d837e | |||
| f561f4cfd9 | |||
| 64caed57d4 | |||
| 975c30a773 | |||
| 35b9360d64 | |||
| 1b8373e133 | |||
| 1a5e4ce7c5 | |||
| 8823977612 | |||
| 24e2ea111c | |||
| 31ea80ec4f | |||
| 99f2680c2e | |||
| da7e05eb87 | |||
| a95714a22d | |||
| c4a2b0f1a3 | |||
| 2016570e2d | |||
| d557f4b870 | |||
| 65054fe3db | |||
| 97d2246375 | |||
| 1e12505741 | |||
| 6755597c9b | |||
| 45689cb30d | |||
| 78570d4188 | |||
| 7cf38b37ee | |||
| 4ef57eec4e | |||
| cbd828e190 |
@@ -1,268 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
diagnosticsTimeout = 20 * time.Second
|
||||
maxOutputBytes = 12_000
|
||||
)
|
||||
|
||||
type toolPathInput struct {
|
||||
Path string `json:"path"`
|
||||
}
|
||||
|
||||
type lintResult struct {
|
||||
Output string
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
}
|
||||
|
||||
absPath, ok := resolveGoFilePath(e.Input, ctx.CWD)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
borderColor = "#f38ba8" // red
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: strings.Join(msgLines, "\n"),
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func isEditOrWrite(toolName string) bool {
|
||||
return strings.EqualFold(toolName, "edit") || strings.EqualFold(toolName, "write")
|
||||
}
|
||||
|
||||
func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
var args toolPathInput
|
||||
if err := json.Unmarshal([]byte(inputJSON), &args); err != nil || args.Path == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
absPath := args.Path
|
||||
if !filepath.IsAbs(absPath) {
|
||||
absPath = filepath.Join(cwd, absPath)
|
||||
}
|
||||
|
||||
if strings.ToLower(filepath.Ext(absPath)) != ".go" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "gopls", "check", absPath)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes), Err: fmt.Errorf("failed to run gopls check: %w", err)}
|
||||
}
|
||||
|
||||
return lintResult{Output: truncate(string(out), maxOutputBytes)}
|
||||
}
|
||||
|
||||
func runGolangCILint(cwd, target string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
|
||||
args := []string{
|
||||
"run",
|
||||
target,
|
||||
"--show-stats=false",
|
||||
"--output.text.path", "stdout",
|
||||
"--output.text.colors=false",
|
||||
"--output.text.print-issued-lines=false",
|
||||
}
|
||||
cmd := exec.CommandContext(ctx, "golangci-lint", args...)
|
||||
cmd.Dir = cwd
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return lintResult{Err: fmt.Errorf("timed out after %s", diagnosticsTimeout)}
|
||||
}
|
||||
|
||||
trimmed := truncate(string(out), maxOutputBytes)
|
||||
if err == nil {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
exitErr, ok := err.(*exec.ExitError)
|
||||
if ok && exitErr.ExitCode() == 1 {
|
||||
return lintResult{Output: trimmed}
|
||||
}
|
||||
|
||||
return lintResult{Output: trimmed, Err: fmt.Errorf("failed to run golangci-lint: %w", err)}
|
||||
}
|
||||
|
||||
func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
var lines []string
|
||||
if res.Err != nil {
|
||||
lines = append(lines, "ERROR: "+res.Err.Error())
|
||||
}
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return emptyFallback
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "\n... output truncated ..."
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
lintCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return goplsCount, lintCount
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
---
|
||||
description: Read-only audit for dead code, duplication, boundary violations, and refactor opportunities
|
||||
---
|
||||
|
||||
Perform a comprehensive **read-only** audit of this repository and report
|
||||
findings. **Do not edit, rename, or delete any files.** Optional focus / scope
|
||||
hints from the user: $@
|
||||
|
||||
## Scope
|
||||
|
||||
If the user supplied focus hints above (a package path, a subsystem name, a
|
||||
concern like "TUI" or "extensions"), scope the audit accordingly. Otherwise
|
||||
audit the whole repo, prioritising the highest-traffic packages first
|
||||
(`cmd/`, `internal/`, `pkg/kit/` for this repo).
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Map the repo first**:
|
||||
- `ls` / `find` the top-level layout and list every Go package
|
||||
- Read `AGENTS.md`, `README.md`, and any `pkg/*/doc.go` to understand the
|
||||
intended architectural boundaries (SDK vs internal vs TUI vs cmd vs
|
||||
extension surface)
|
||||
- Note the public SDK surface (`pkg/kit/`) and any documented invariants
|
||||
(e.g. "no dependency name leakage", "UI never imports extensions
|
||||
directly") — these define what counts as a violation
|
||||
|
||||
2. **Hunt for dead code**:
|
||||
- Run `go vet ./...` and capture warnings
|
||||
- Use `grep` to find exported symbols (`^func [A-Z]`, `^type [A-Z]`,
|
||||
`^var [A-Z]`, `^const [A-Z]`) and cross-reference call sites. Symbols
|
||||
with zero non-test references inside the module are suspects
|
||||
- Check for unreferenced files, `// TODO: remove` markers, commented-out
|
||||
blocks, and `_ = x` discard patterns
|
||||
- If `staticcheck`, `deadcode`, or `unused` are available on PATH, run
|
||||
them and include their output verbatim
|
||||
- **Do not delete anything** — list candidates with file:line and a
|
||||
confidence level (high / medium / low)
|
||||
|
||||
3. **Find unnecessary duplication**:
|
||||
- Look for near-identical function bodies, struct shapes, or switch
|
||||
statements across packages — `grep` for repeated function signatures
|
||||
and copy-pasted string literals / error messages is a fast first pass
|
||||
- Distinguish *coincidental* duplication (two things that happen to look
|
||||
alike but evolve independently) from *unnecessary* duplication (same
|
||||
intent, drifting in lockstep) — only flag the latter
|
||||
- For each cluster, propose where the extracted helper should live
|
||||
(which package, which file) and whether it crosses a boundary
|
||||
|
||||
4. **Check concerns / boundary violations**:
|
||||
- **SDK leakage**: grep `pkg/kit/` for imports of `internal/...` types
|
||||
in exported signatures, and for dependency-name leakage in exported
|
||||
names / godoc (e.g. library jargon appearing in `LLM*` types)
|
||||
- **UI ↔ extensions**: grep `internal/ui/` for any import of
|
||||
`internal/extensions/` — per AGENTS.md the UI must not import
|
||||
extensions directly; converters in `cmd/root.go` should bridge them
|
||||
- **cmd vs internal**: business logic living in `cmd/` that should be
|
||||
in `internal/` (and vice versa)
|
||||
- **Cyclic risk**: packages that import each other transitively or that
|
||||
reach across sibling boundaries unexpectedly
|
||||
- For each violation, cite the offending import / signature with
|
||||
file:line
|
||||
|
||||
5. **Spot refactor opportunities**:
|
||||
- Long functions (>80 lines) doing multiple unrelated things
|
||||
- Deeply nested conditionals that flatten well with early returns
|
||||
- Repeated `if err != nil { return fmt.Errorf("...: %w", err) }` chains
|
||||
that could become helpers — but only where the wrapping context is
|
||||
genuinely uniform
|
||||
- Structs with too many fields that hint at split responsibilities
|
||||
- Exported APIs that would be cleaner with options structs / functional
|
||||
options
|
||||
- Tests that share setup boilerplate ripe for a helper
|
||||
- Flag each with: location, current shape (1-2 lines), proposed shape
|
||||
(1-2 lines), and estimated risk (low / medium / high)
|
||||
|
||||
6. **Cross-check against project rules**:
|
||||
- Re-read `AGENTS.md` "Key Patterns" section and verify nothing in your
|
||||
findings contradicts the documented gotchas (Yaegi interface ban,
|
||||
`prog.Send()` from `Update()`, function-field bug, etc.) — if a
|
||||
"refactor" would reintroduce a known pitfall, drop it from the report
|
||||
and note why
|
||||
|
||||
7. **Write the report** as your final message (do not write it to disk)
|
||||
structured as:
|
||||
|
||||
```
|
||||
# Code Audit Report
|
||||
|
||||
## Summary
|
||||
- N dead-code candidates
|
||||
- N duplication clusters
|
||||
- N boundary violations
|
||||
- N refactor opportunities
|
||||
|
||||
## Dead Code
|
||||
### High confidence
|
||||
- path/to/file.go:LINE — symbol — reason
|
||||
|
||||
### Medium confidence
|
||||
...
|
||||
|
||||
## Duplication
|
||||
### Cluster: <short name>
|
||||
- Sites: file:line, file:line, …
|
||||
- Suggested home: package/path
|
||||
- Notes: …
|
||||
|
||||
## Boundary Violations
|
||||
- Rule: <which rule from AGENTS.md / project convention>
|
||||
- Offender: file:line
|
||||
- Fix sketch: …
|
||||
|
||||
## Refactor Opportunities
|
||||
- Location: file:line
|
||||
- Current: …
|
||||
- Proposed: …
|
||||
- Risk: low/medium/high
|
||||
- Why it's worth it: …
|
||||
|
||||
## Suggested Next Steps
|
||||
1. …
|
||||
2. …
|
||||
```
|
||||
|
||||
8. **End the report with an explicit reminder** that no files were modified,
|
||||
and recommend the user pick the highest-leverage items to act on
|
||||
manually (or via a follow-up `/fix-issue` style prompt) rather than
|
||||
running a sweeping refactor.
|
||||
|
||||
## Guidelines
|
||||
|
||||
- **Read-only, always**: no `edit`, no `write`, no `git commit`, no `go mod
|
||||
tidy`. Use only `read`, `grep`, `find`, `ls`, and read-only `bash`
|
||||
commands (`go vet`, `go build -o /tmp/...`, `staticcheck`, etc.)
|
||||
- **Cite every finding** with `path/to/file.go:LINE` so the user can jump
|
||||
straight to it
|
||||
- **Be honest about confidence**: false positives in a code audit are
|
||||
expensive — prefer "medium confidence, worth a look" over confidently
|
||||
wrong claims
|
||||
- **Quantity isn't quality**: 10 sharp findings beat 100 nitpicks. Cut
|
||||
anything that's purely stylistic unless it directly causes one of the
|
||||
four issue categories above
|
||||
- **Skip generated code** (`*.pb.go`, `*_gen.go`, anything under
|
||||
`vendor/`) and obvious third-party copies
|
||||
- **Don't propose architectural rewrites** — stay within the existing
|
||||
shape of the repo and recommend incremental, reviewable changes
|
||||
@@ -127,6 +127,13 @@ max-tokens: 4096
|
||||
temperature: 0.7
|
||||
stream: true
|
||||
thinking-level: off # off, none, minimal, low, medium, high
|
||||
no-core-tools: false # set to true to disable all built-in core tools
|
||||
|
||||
# Skills — all three keys are optional
|
||||
no-skills: false # set to true to disable all skill loading
|
||||
skill: # explicit skill files/dirs (disables auto-discovery)
|
||||
- /path/to/skill.md
|
||||
skills-dir: "" # override project-local directory for auto-discovery
|
||||
```
|
||||
|
||||
All of the above keys can also be set programmatically via the SDK
|
||||
@@ -195,12 +202,18 @@ mcpServers:
|
||||
--compact Enable compact output mode
|
||||
--auto-compact Auto-compact conversation near context limit
|
||||
|
||||
# Extensions
|
||||
# Extensions and tools
|
||||
--extension, -e Load additional extension file(s) (repeatable)
|
||||
--no-extensions Disable all extensions
|
||||
--no-core-tools Disable all built-in core tools (bash, read, write, edit, grep, find, ls, subagent)
|
||||
--prompt-template Load a specific prompt template by name
|
||||
--no-prompt-templates Disable prompt template loading
|
||||
|
||||
# Skills
|
||||
--skill Load skill file or directory (repeatable)
|
||||
--skills-dir Override the project-local skills directory for auto-discovery
|
||||
--no-skills Disable skill loading (auto-discovery and explicit)
|
||||
|
||||
# Generation parameters
|
||||
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
|
||||
--temperature Randomness 0.0-1.0 (default: 0.7)
|
||||
@@ -226,6 +239,10 @@ kit auth login [provider] --set-default # Set provider's default model as syste
|
||||
kit auth logout [provider] # Remove credentials for provider
|
||||
kit auth status # Check authentication status
|
||||
|
||||
# GitHub Copilot login (experimental; requires active Copilot subscription)
|
||||
kit auth login copilot
|
||||
kit --model copilot/gpt-5.5 "Hello"
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
@@ -306,12 +323,15 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
### Extension Capabilities
|
||||
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
**Lifecycle Events**: OnSessionStart, OnSessionShutdown, OnBeforeAgentStart, OnAgentStart, OnAgentEnd, OnLLMUsage, OnToolCall, OnToolCallInputStart, OnToolCallInputDelta, OnToolCallInputEnd, OnToolExecutionStart, OnToolOutput, OnToolExecutionEnd, OnToolResult, OnInput, OnMessageStart, OnMessageUpdate, OnMessageEnd, OnModelChange, OnContextPrepare, OnBeforeFork, OnBeforeSessionSwitch, OnBeforeCompact, OnCustomEvent, OnSubagentStart, OnSubagentChunk, OnSubagentEnd
|
||||
|
||||
`OnAgentEnd` carries per-turn aggregates (`ToolCallCount`, `ToolNames`, `LLMCallCount`, `InputTokensDelta`, `OutputTokensDelta`, `CostDelta`, `DurationMs`) so observers don't need to maintain parallel bookkeeping. `OnLLMUsage` fires after each LLM provider call with token + cost deltas attributed to that specific call/model — use it for accurate budget enforcement *between* calls instead of waiting for the turn to finish.
|
||||
|
||||
**Custom Components**:
|
||||
- **Tools**: Add new tools the LLM can invoke
|
||||
- **Commands**: Register slash commands (e.g., `/mycommand`)
|
||||
- **Options**: Register configurable extension options
|
||||
- **Session State**: Last-write-wins key-value store via `ctx.SetState` / `GetState` / `DeleteState` / `ListState`, persisted to a per-session sidecar file outside the conversation tree
|
||||
- **Widgets**: Persistent status displays above/below input
|
||||
- **Headers/Footers**: Persistent content above/below the conversation
|
||||
- **Status Bar**: Custom status bar entries
|
||||
@@ -367,6 +387,7 @@ See the `examples/extensions/` directory:
|
||||
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
|
||||
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
|
||||
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
|
||||
- [`usage-budget.go`](examples/extensions/usage-budget.go) - Per-call usage callback (`OnLLMUsage`), session state, and enriched `OnAgentEnd` per-turn report
|
||||
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
|
||||
|
||||
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
@@ -507,6 +528,8 @@ During an interactive session, use these slash commands:
|
||||
|
||||
| Shortcut | Description |
|
||||
|----------|-------------|
|
||||
| `Ctrl+V` | Paste an image from the clipboard — shows an inline low-res thumbnail preview (tmux/zellij-safe) |
|
||||
| `Ctrl+U` | Clear all pending image attachments |
|
||||
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
|
||||
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
|
||||
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
|
||||
@@ -554,7 +577,7 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
SystemPrompt: "You are a helpful bot",
|
||||
ConfigFile: "/path/to/config.yml",
|
||||
MaxSteps: 10,
|
||||
Streaming: true,
|
||||
Streaming: ptr(true), // *bool: nil = unset (default true), &false = off
|
||||
Quiet: true,
|
||||
|
||||
// Generation parameters (override env/config/per-model defaults)
|
||||
@@ -579,7 +602,9 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
// Tool options
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
DisableCoreTools: true, // Disable all built-in core tools; also controllable via
|
||||
// --no-core-tools flag, KIT_NO_CORE_TOOLS env var,
|
||||
// or no-core-tools: true in .kit.yml
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
@@ -599,6 +624,38 @@ are pointer types so explicit `0.0` is distinguishable from "leave alone"; a
|
||||
non-zero `MaxTokens` suppresses automatic right-sizing the same way `--max-tokens`
|
||||
does on the CLI.
|
||||
|
||||
### Functional options (`NewAgent`)
|
||||
|
||||
For simple programmatic setups, `kit.NewAgent` offers an ergonomic
|
||||
functional-options front door over `kit.New`. Streaming is **enabled by
|
||||
default**; pass `kit.WithStreaming(false)` to opt out.
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.WithMaxTokens(8192),
|
||||
kit.WithThinkingLevel("medium"),
|
||||
kit.Ephemeral(), // in-memory session, no persistence
|
||||
)
|
||||
```
|
||||
|
||||
Available options: `WithModel`, `WithSystemPrompt`, `WithStreaming`,
|
||||
`WithMaxTokens`, `WithThinkingLevel`, `WithTools`, `WithExtraTools`,
|
||||
`WithProviderAPIKey`, `WithProviderURL`, `WithConfigFile`, `WithDebug`, and
|
||||
`Ephemeral`. For advanced configuration not covered by the helpers (custom MCP
|
||||
config, in-process MCP servers, session backends, MCP task tuning) construct an
|
||||
`Options` value explicitly and call `kit.New`.
|
||||
|
||||
### Per-instance config isolation
|
||||
|
||||
Each `kit.New` / `kit.NewAgent` call owns an **isolated configuration store**,
|
||||
so constructing multiple Kit instances in the same process is safe: setting the
|
||||
model, thinking level, or generation parameters on one never affects another,
|
||||
and runtime mutators (`SetModel`, `SetThinkingLevel`) only touch the owning
|
||||
instance. This makes subagent spawning and multi-Kit embedding race-free with
|
||||
no external synchronization required.
|
||||
|
||||
### MCP OAuth (remote MCP servers)
|
||||
|
||||
When a remote MCP server returns 401, Kit runs the full OAuth flow (dynamic
|
||||
@@ -756,6 +813,45 @@ host, _ := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Runtime Skills & Context Files
|
||||
|
||||
For multi-tenant hosts (chatbots, per-user agents, web services), the SDK
|
||||
lets you swap skills and `AGENTS.md`-style context files **after** Kit
|
||||
construction. Every mutation recomposes the system prompt and applies it to
|
||||
the agent so the next turn picks up the new instructions — no restart needed.
|
||||
|
||||
```go
|
||||
// Programmatic skill (no file on disk required).
|
||||
host.AddSkill(&kit.Skill{
|
||||
Name: "polite-french",
|
||||
Description: "Respond in French and always greet the user.",
|
||||
Content: "Always reply in French. Open every response with 'Bonjour'.",
|
||||
})
|
||||
|
||||
// Or load one from disk.
|
||||
host.LoadAndAddSkill("/var/skills/refund-policy.md")
|
||||
|
||||
// Per-user AGENTS.md content pulled from a database.
|
||||
host.AddContextFileContent(
|
||||
fmt.Sprintf("session://%s/AGENTS.md", userID),
|
||||
rulesFromDB,
|
||||
)
|
||||
|
||||
// Tear down session-specific state on logout.
|
||||
host.RemoveSkill("polite-french")
|
||||
host.RemoveContextFile(fmt.Sprintf("session://%s/AGENTS.md", userID))
|
||||
|
||||
// Or replace the whole set atomically.
|
||||
host.SetSkills(activeSkillsForUser)
|
||||
host.SetContextFiles(activeContextForUser)
|
||||
```
|
||||
|
||||
Skills dedupe by `Name`, context files dedupe by `Path` (which can be any
|
||||
opaque identifier — it doesn't have to be a real filesystem path). All
|
||||
mutators and readers (`GetSkills`, `GetContextFiles`) are safe to call
|
||||
concurrently from multiple goroutines. See the [SDK overview docs](/sdk/overview#runtime-skills-and-context-files)
|
||||
for the full reference.
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Subagent Pattern
|
||||
@@ -872,6 +968,7 @@ npm/ - NPM package wrapper for distribution
|
||||
|
||||
- **Anthropic** - Claude models (native, prompt caching, OAuth)
|
||||
- **OpenAI** - GPT models
|
||||
- **Copilot** - GitHub Copilot models (`copilot`, requires active Copilot subscription)
|
||||
- **Google** - Gemini models
|
||||
- **Ollama** - Local models
|
||||
- **Azure OpenAI** - Azure-hosted OpenAI
|
||||
@@ -897,6 +994,31 @@ This automatically defaults to `custom/custom` without needing to specify a mode
|
||||
- Reasoning and temperature support
|
||||
- Optional `CUSTOM_API_KEY` environment variable or `--provider-api-key` flag
|
||||
|
||||
### Auto-routed Providers
|
||||
|
||||
Any provider in the [models.dev](https://models.dev) database can be used as
|
||||
`provider/model` without a dedicated native integration. Kit auto-routes the
|
||||
request through the matching **wire protocol** based on the provider's npm package
|
||||
(or per-model override), using its `api` URL as the base:
|
||||
|
||||
| npm package | Wire protocol |
|
||||
|-------------|---------------|
|
||||
| `@ai-sdk/openai` | OpenAI (Responses API) |
|
||||
| `@ai-sdk/openai-compatible` | OpenAI (chat completions) |
|
||||
| `@ai-sdk/anthropic` | Anthropic |
|
||||
| `@ai-sdk/google` | Google Gemini |
|
||||
|
||||
Providers with an `api` URL but an unrecognized npm package fall back to the
|
||||
OpenAI-compatible wire. Because routing follows the wire protocol, aggregator/proxy
|
||||
providers work across all of their models — including Claude, GPT, *and* Gemini
|
||||
routes:
|
||||
|
||||
```bash
|
||||
kit --model opencode/claude-haiku-4-5 "Hello" # → Anthropic wire
|
||||
kit --model opencode/gpt-5 "Hello" # → OpenAI wire
|
||||
kit --model opencode/gemini-3.5-flash "Hello" # → Google wire
|
||||
```
|
||||
|
||||
### Model String Format
|
||||
|
||||
```bash
|
||||
|
||||
+157
-4
@@ -31,10 +31,12 @@ using OAuth flows. Stored credentials take precedence over environment variables
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI API (OAuth and API key)
|
||||
- copilot: GitHub Copilot (GitHub device login)
|
||||
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth login copilot
|
||||
kit auth logout anthropic
|
||||
kit auth status`,
|
||||
}
|
||||
@@ -54,6 +56,7 @@ environment variables when making API calls.
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API (OAuth)
|
||||
- openai: OpenAI ChatGPT Plus/Pro (Codex OAuth)
|
||||
- copilot: GitHub Copilot (GitHub device login, experimental)
|
||||
|
||||
Flags:
|
||||
--set-default Set this provider's default model as the system default
|
||||
@@ -61,7 +64,8 @@ Flags:
|
||||
Examples:
|
||||
kit auth login anthropic
|
||||
kit auth login openai
|
||||
kit auth login openai --set-default`,
|
||||
kit auth login copilot
|
||||
kit auth login copilot --set-default`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogin,
|
||||
}
|
||||
@@ -80,10 +84,12 @@ You will need to use environment variables or command-line flags for authenticat
|
||||
Available providers:
|
||||
- anthropic: Anthropic Claude API
|
||||
- openai: OpenAI API
|
||||
- copilot: GitHub Copilot
|
||||
|
||||
Example:
|
||||
kit auth logout anthropic
|
||||
kit auth logout openai`,
|
||||
kit auth logout openai
|
||||
kit auth logout copilot`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runAuthLogout,
|
||||
}
|
||||
@@ -113,6 +119,7 @@ var (
|
||||
var defaultModels = map[string]string{
|
||||
"anthropic": "anthropic/claude-sonnet-4-5-20250929",
|
||||
"openai": "openai/gpt-5.4",
|
||||
"copilot": "copilot/gpt-5.5",
|
||||
}
|
||||
|
||||
// setDefaultModelIfRequested sets the default model for the given provider
|
||||
@@ -143,6 +150,7 @@ func init() {
|
||||
authLoginCmd.Flags().BoolVar(&loginSetDefault, "set-default", false, "Set this provider's default model as the system default after login")
|
||||
}
|
||||
|
||||
// runAuthLogin dispatches OAuth login to the selected provider.
|
||||
func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
provider := strings.ToLower(args[0])
|
||||
|
||||
@@ -151,8 +159,10 @@ func runAuthLogin(cmd *cobra.Command, args []string) error {
|
||||
return loginAnthropic()
|
||||
case "openai":
|
||||
return loginOpenAI()
|
||||
case "copilot":
|
||||
return loginCopilot(cmd.Context())
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,8 +174,10 @@ func runAuthLogout(cmd *cobra.Command, args []string) error {
|
||||
return logoutAnthropic()
|
||||
case "openai":
|
||||
return logoutOpenAI()
|
||||
case "copilot":
|
||||
return logoutCopilot()
|
||||
default:
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai", provider)
|
||||
return fmt.Errorf("unsupported provider: %s. Available providers: anthropic, openai, copilot", provider)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,9 +256,31 @@ func runAuthStatus(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check GitHub Copilot credentials
|
||||
fmt.Print("\nGitHub Copilot: ")
|
||||
if hasCopilotCreds, err := cm.HasCopilotCredentials(); err != nil {
|
||||
fmt.Printf("Error checking credentials: %v\n", err)
|
||||
} else if hasCopilotCreds {
|
||||
if creds, err := cm.GetCopilotCredentials(); err != nil {
|
||||
fmt.Printf("Error reading credentials: %v\n", err)
|
||||
} else {
|
||||
status := "✓ Authenticated"
|
||||
if creds.IsExpired() {
|
||||
status = "⚠️ Token expired (will refresh automatically)"
|
||||
} else if creds.NeedsRefresh() {
|
||||
status = "⚠️ Token expires soon (will refresh automatically)"
|
||||
}
|
||||
|
||||
fmt.Printf("%s (GitHub OAuth, stored %s)\n", status, creds.CreatedAt.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
} else {
|
||||
fmt.Println("✗ Not authenticated")
|
||||
}
|
||||
|
||||
fmt.Println("\nTo authenticate with a provider:")
|
||||
fmt.Println(" kit auth login anthropic")
|
||||
fmt.Println(" kit auth login openai")
|
||||
fmt.Println(" kit auth login copilot")
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -517,6 +551,85 @@ func loginOpenAI() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// loginCopilot authenticates GitHub Copilot using GitHub device flow.
|
||||
func loginCopilot(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
if hasAuth, err := cm.HasCopilotCredentials(); err == nil && hasAuth {
|
||||
var reauth bool
|
||||
err := huh.NewConfirm().
|
||||
Title("You are already authenticated with GitHub Copilot").
|
||||
Description("Do you want to re-authenticate?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&reauth).
|
||||
Run()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prompt for re-authentication: %w", err)
|
||||
}
|
||||
if !reauth {
|
||||
fmt.Println("Authentication cancelled.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
client := auth.NewCopilotOAuthClient()
|
||||
|
||||
fmt.Println("🔐 Starting GitHub Copilot authentication...")
|
||||
fmt.Println("This uses GitHub device login and requires an active GitHub Copilot subscription.")
|
||||
fmt.Println("Experimental: this uses VS Code Copilot Chat client identifiers.")
|
||||
fmt.Println()
|
||||
|
||||
deviceCode, err := client.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start GitHub device login: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("📱 Open this page and enter the code:")
|
||||
fmt.Printf("\n%s\n\n", deviceCode.VerificationURI)
|
||||
fmt.Printf("Code: %s\n\n", deviceCode.UserCode)
|
||||
auth.TryOpenBrowser(deviceCode.VerificationURI)
|
||||
|
||||
fmt.Println("Waiting for GitHub authorization...")
|
||||
githubToken, err := client.PollDeviceToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to complete GitHub device login: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("\n🔄 Exchanging GitHub token for Copilot access token...")
|
||||
creds, err := client.ExchangeGitHubToken(ctx, githubToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get GitHub Copilot token: %w", err)
|
||||
}
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
|
||||
return fmt.Errorf("failed to store credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Successfully authenticated with GitHub Copilot!")
|
||||
fmt.Printf("📁 Credentials stored in: %s\n", cm.GetCredentialsPath())
|
||||
fmt.Println("\n🎉 Your GitHub Copilot credentials will now be used for copilot/* models.")
|
||||
fmt.Println("💡 You can check your authentication status with: kit auth status")
|
||||
|
||||
if err := setDefaultModelIfRequested("copilot"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !loginSetDefault {
|
||||
fmt.Println("\n💡 To set Copilot as your default model, run:")
|
||||
fmt.Println(" kit auth login copilot --set-default")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// callbackServer holds the HTTP server and channel for receiving the OAuth callback
|
||||
type callbackServer struct {
|
||||
Server *http.Server
|
||||
@@ -635,3 +748,43 @@ func logoutOpenAI() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func logoutCopilot() error {
|
||||
cm, err := kit.NewCredentialManager()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
hasAuth, err := cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check authentication status: %w", err)
|
||||
}
|
||||
|
||||
if !hasAuth {
|
||||
fmt.Println("You are not currently authenticated with GitHub Copilot.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var confirm bool
|
||||
err = huh.NewConfirm().
|
||||
Title("Remove GitHub Copilot credentials").
|
||||
Description("Are you sure you want to remove your stored credentials?").
|
||||
Affirmative("Yes").
|
||||
Negative("No").
|
||||
Value(&confirm).
|
||||
Run()
|
||||
if err != nil || !confirm {
|
||||
fmt.Println("Logout cancelled.")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
return fmt.Errorf("failed to remove credentials: %w", err)
|
||||
}
|
||||
|
||||
fmt.Println("✓ Successfully logged out from GitHub Copilot!")
|
||||
fmt.Println("You will need to authenticate again with 'kit auth login copilot'.")
|
||||
fmt.Println("Tip: this removes local credentials only. Revoke the GitHub OAuth grant at https://github.com/settings/applications")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/ui"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// extensionContextDeps groups the runtime dependencies needed to wire up
|
||||
// an extensions.Context for the interactive TUI mode.
|
||||
type extensionContextDeps struct {
|
||||
ctx context.Context
|
||||
cwd string
|
||||
modelName string
|
||||
interactive bool
|
||||
kitInstance *kit.Kit
|
||||
appInstance *app.App
|
||||
usageTracker *ui.UsageTracker
|
||||
}
|
||||
|
||||
// buildInteractiveExtensionContext returns an extensions.Context with every
|
||||
// field except Print / PrintInfo / PrintError populated. Callers must set
|
||||
// the three print routes appropriately for their phase (startup buffering
|
||||
// vs. live runtime routing).
|
||||
//
|
||||
// The headless half (data access, state, options, tree navigation, skills,
|
||||
// templates, model resolution, subagents) comes from extbridge.BaseContext;
|
||||
// this function overlays the TUI-specific fields and overrides SetModel /
|
||||
// ReloadExtensions with TUI-aware versions.
|
||||
func buildInteractiveExtensionContext(deps extensionContextDeps) extensions.Context {
|
||||
kitInstance := deps.kitInstance
|
||||
appInstance := deps.appInstance
|
||||
usageTracker := deps.usageTracker
|
||||
|
||||
ec := extbridge.BaseContext(deps.ctx, kitInstance)
|
||||
|
||||
ec.CWD = deps.cwd
|
||||
ec.Model = deps.modelName
|
||||
ec.Interactive = deps.interactive
|
||||
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
appInstance.PrintBlockFromExtension(opts)
|
||||
}
|
||||
ec.SendMessage = func(text string) { appInstance.Run(text) }
|
||||
ec.CancelAndSend = func(text string) { appInstance.InterruptAndSend(text) }
|
||||
ec.Abort = func() { appInstance.Abort() }
|
||||
ec.IsIdle = func() bool { return !appInstance.IsBusy() }
|
||||
ec.Compact = func(cfg extensions.CompactConfig) error {
|
||||
return appInstance.CompactAsync(cfg.CustomInstructions, cfg.OnComplete, cfg.OnError)
|
||||
}
|
||||
ec.SendMultimodalMessage = func(text string, files []extensions.FilePart) {
|
||||
parts := make([]kit.LLMFilePart, len(files))
|
||||
for i, f := range files {
|
||||
parts[i] = kit.LLMFilePart{
|
||||
Filename: f.Filename,
|
||||
Data: f.Data,
|
||||
MediaType: f.MediaType,
|
||||
}
|
||||
}
|
||||
appInstance.RunWithFiles(text, parts)
|
||||
}
|
||||
ec.GetSessionUsage = func() extensions.SessionUsage {
|
||||
if usageTracker == nil {
|
||||
return extensions.SessionUsage{}
|
||||
}
|
||||
stats := usageTracker.GetSessionStats()
|
||||
return extensions.SessionUsage{
|
||||
TotalInputTokens: stats.TotalInputTokens,
|
||||
TotalOutputTokens: stats.TotalOutputTokens,
|
||||
TotalCacheReadTokens: stats.TotalCacheReadTokens,
|
||||
TotalCacheWriteTokens: stats.TotalCacheWriteTokens,
|
||||
TotalCost: stats.TotalCost,
|
||||
RequestCount: stats.RequestCount,
|
||||
}
|
||||
}
|
||||
ec.Exit = func() { appInstance.QuitFromExtension() }
|
||||
|
||||
// TUI widgets/chrome — mutate runner state, then notify the TUI.
|
||||
// Always use a goroutine for NotifyWidgetUpdate: prog.Send() deadlocks
|
||||
// if called synchronously from inside BubbleTea's Update() handler.
|
||||
// All call sites use go-routines uniformly.
|
||||
ec.SetWidget = func(config extensions.WidgetConfig) {
|
||||
kitInstance.Extensions().SetWidget(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveWidget = func(id string) {
|
||||
kitInstance.Extensions().RemoveWidget(id)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetHeader = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetHeader(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveHeader = func() {
|
||||
kitInstance.Extensions().RemoveHeader()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetFooter = func(config extensions.HeaderFooterConfig) {
|
||||
kitInstance.Extensions().SetFooter(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveFooter = func() {
|
||||
kitInstance.Extensions().RemoveFooter()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetUIVisibility = func(v extensions.UIVisibility) {
|
||||
kitInstance.Extensions().SetUIVisibility(v)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditor = func(config extensions.EditorConfig) {
|
||||
kitInstance.Extensions().SetEditor(config)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.ResetEditor = func() {
|
||||
kitInstance.Extensions().ResetEditor()
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.SetEditorText = func(text string) {
|
||||
appInstance.SetEditorTextFromExtension(text)
|
||||
}
|
||||
ec.SetStatus = func(key string, text string, priority int) {
|
||||
kitInstance.Extensions().SetStatus(extensions.StatusBarEntry{
|
||||
Key: key,
|
||||
Text: text,
|
||||
Priority: priority,
|
||||
})
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
ec.RemoveStatus = func(key string) {
|
||||
kitInstance.Extensions().RemoveStatus(key)
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
}
|
||||
|
||||
// Interactive prompts — channel-based round trips through the TUI.
|
||||
ec.PromptSelect = func(config extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "select",
|
||||
Message: config.Message,
|
||||
Options: config.Options,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptSelectResult{Value: resp.Value, Index: resp.Index}
|
||||
}
|
||||
ec.PromptConfirm = func(config extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
def := "false"
|
||||
if config.DefaultValue {
|
||||
def = "true"
|
||||
}
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "confirm",
|
||||
Message: config.Message,
|
||||
Default: def,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptConfirmResult{Value: resp.Confirmed}
|
||||
}
|
||||
ec.PromptInput = func(config extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
ch := make(chan app.PromptResponse, 1)
|
||||
appInstance.SendPromptRequest(app.PromptRequestEvent{
|
||||
PromptType: "input",
|
||||
Message: config.Message,
|
||||
Placeholder: config.Placeholder,
|
||||
Default: config.Default,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
return extensions.PromptInputResult{Value: resp.Value}
|
||||
}
|
||||
ec.ShowOverlay = func(config extensions.OverlayConfig) extensions.OverlayResult {
|
||||
ch := make(chan app.OverlayResponse, 1)
|
||||
appInstance.SendOverlayRequest(app.OverlayRequestEvent{
|
||||
Title: config.Title,
|
||||
Content: config.Content.Text,
|
||||
Markdown: config.Content.Markdown,
|
||||
BorderColor: config.Style.BorderColor,
|
||||
Background: config.Style.Background,
|
||||
Width: config.Width,
|
||||
MaxHeight: config.MaxHeight,
|
||||
Anchor: string(config.Anchor),
|
||||
Actions: config.Actions,
|
||||
ResponseCh: ch,
|
||||
})
|
||||
resp := <-ch
|
||||
if resp.Cancelled {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
return extensions.OverlayResult{
|
||||
Action: resp.Action,
|
||||
Index: resp.Index,
|
||||
}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error {
|
||||
return appInstance.SuspendTUI(callback)
|
||||
}
|
||||
|
||||
// TUI-aware model switch: also notifies the TUI status bar and
|
||||
// refreshes the usage tracker for correct token counting.
|
||||
ec.SetModel = func(modelString string) error {
|
||||
// Capture previous model for the ModelChange event.
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
err := kitInstance.SetModel(context.Background(), modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI so it updates model in status bar.
|
||||
p, m, _ := models.ParseModelString(modelString)
|
||||
appInstance.NotifyModelChanged(p, m)
|
||||
// Update the context's Model field so handlers see it.
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
// Fire OnModelChange event to extensions.
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
// Update usage tracker with new model info for correct token counting.
|
||||
ui.UpdateUsageTrackerForModel(usageTracker, modelString, viper.GetString("provider-api-key"))
|
||||
return nil
|
||||
}
|
||||
|
||||
ec.RenderMessage = func(rendererName, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(rendererName)
|
||||
if renderer == nil || renderer.Render == nil {
|
||||
appInstance.PrintFromExtension("", content)
|
||||
return
|
||||
}
|
||||
w, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if w == 0 {
|
||||
w = 80
|
||||
}
|
||||
rendered := renderer.Render(content, w)
|
||||
appInstance.PrintFromExtension("", rendered)
|
||||
}
|
||||
ec.ReloadExtensions = func() error {
|
||||
err := kitInstance.Extensions().Reload()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Notify TUI that widgets/status/commands may have changed.
|
||||
go appInstance.NotifyWidgetUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Theme management (TUI only).
|
||||
ec.RegisterTheme = func(name string, config extensions.ThemeColorConfig) {
|
||||
tc := func(c extensions.ThemeColor) [2]string { return [2]string{c.Light, c.Dark} }
|
||||
ui.RegisterThemeFromConfig(name,
|
||||
tc(config.Primary), tc(config.Secondary),
|
||||
tc(config.Success), tc(config.Warning),
|
||||
tc(config.Error), tc(config.Info),
|
||||
tc(config.Text), tc(config.Muted),
|
||||
tc(config.VeryMuted), tc(config.Background),
|
||||
tc(config.Border), tc(config.MutedBorder),
|
||||
tc(config.System), tc(config.Tool),
|
||||
tc(config.Accent), tc(config.Highlight),
|
||||
tc(config.MdHeading), tc(config.MdLink),
|
||||
tc(config.MdKeyword), tc(config.MdString),
|
||||
tc(config.MdNumber), tc(config.MdComment),
|
||||
)
|
||||
}
|
||||
ec.SetTheme = func(name string) error {
|
||||
return ui.ApplyTheme(name)
|
||||
}
|
||||
ec.ListThemes = func() []string {
|
||||
return ui.ListThemes()
|
||||
}
|
||||
|
||||
// Skill context-injection (drives a new agent turn through the TUI).
|
||||
ec.InjectSkillAsContext = func(skillName string) string {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
for _, s := range skills {
|
||||
if s.Name == skillName {
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("skill not found: %s", skillName)
|
||||
}
|
||||
ec.InjectRawSkillAsContext = func(path string) string {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
if err != "" {
|
||||
return err
|
||||
}
|
||||
appInstance.Run(fmt.Sprintf("<skill name=%q>\n%s\n</skill>", s.Name, s.Content))
|
||||
return ""
|
||||
}
|
||||
|
||||
return ec
|
||||
}
|
||||
+339
-890
File diff suppressed because it is too large
Load Diff
@@ -58,6 +58,7 @@ kit install github.com/mark3labs/kit/examples/extensions --local
|
||||
| `project-rules.go` | Project-specific rules | Session data, file reading |
|
||||
| `protected-paths.go` | Block dangerous operations | `OnToolCall` with blocking |
|
||||
| `permission-gate.go` | Confirm destructive actions | `OnToolCall` with confirmation |
|
||||
| `usage-budget.go` | Soft cost cap + per-turn report | `OnLLMUsage`, `SetState`/`GetState`, enriched `AgentEndEvent` |
|
||||
|
||||
### Tools & Commands
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// without panicking and properly guards nil ctx calls.
|
||||
func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Emit SessionStart - should not panic even with nil ctx functions
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -26,7 +26,7 @@ func TestSubagentMonitor_SessionStart(t *testing.T) {
|
||||
// creates entries and emits widget updates.
|
||||
func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start session
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
@@ -84,7 +84,7 @@ func TestSubagentMonitor_SubagentLifecycle(t *testing.T) {
|
||||
// TestSubagentMonitor_MultipleSubagents verifies multiple parallel subagents.
|
||||
func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -134,7 +134,7 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
@@ -186,7 +186,7 @@ func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
harness.LoadFile("./subagent-monitor.go")
|
||||
|
||||
// Start then shutdown
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// Init demonstrates the three primitives added in issue #53:
|
||||
//
|
||||
// 1. api.OnLLMUsage(...) — per-LLM-call usage callback with token + cost
|
||||
// deltas. Use this for budget enforcement that reacts between calls
|
||||
// within a single agent turn, rather than only at turn boundaries.
|
||||
//
|
||||
// 2. ctx.SetState / ctx.GetState / ctx.DeleteState / ctx.ListState —
|
||||
// last-write-wins, session-scoped key-value store backed by a sidecar
|
||||
// file. Use this for snapshot state (current value of X) instead of
|
||||
// ctx.AppendEntry, which is append-only and bloats branch reads.
|
||||
//
|
||||
// 3. ext.AgentEndEvent.ToolCallCount / .ToolNames / .LLMCallCount /
|
||||
// .InputTokensDelta / .OutputTokensDelta / .CostDelta / .DurationMs —
|
||||
// per-turn aggregates so observer extensions don't need to maintain
|
||||
// parallel bookkeeping.
|
||||
//
|
||||
// Together these support a simple soft-budget cap: warn when the
|
||||
// cumulative cost in this session exceeds a threshold, and print a
|
||||
// per-turn report on AgentEnd.
|
||||
//
|
||||
// Usage: kit -e examples/extensions/usage-budget.go
|
||||
func Init(api ext.API) {
|
||||
const warnAtKey = "usage-budget:warn-at-usd"
|
||||
|
||||
// 1. Print per-LLM-call usage with provider, model, and cost.
|
||||
api.OnLLMUsage(func(e ext.LLMUsageEvent, ctx ext.Context) {
|
||||
ctx.Print(fmt.Sprintf(
|
||||
"[usage] step=%d %s/%s tokens=↑%d ↓%d cache=↑%d/↓%d cost=$%.4f (%s)",
|
||||
e.StepNumber, e.Provider, e.Model,
|
||||
e.InputTokens, e.OutputTokens,
|
||||
e.CacheWriteTokens, e.CacheReadTokens,
|
||||
e.Cost, e.FinishReason,
|
||||
))
|
||||
|
||||
// 2. Persist running total in last-write-wins state.
|
||||
current := 0.0
|
||||
if raw, ok := ctx.GetState("usage-budget:total-cost"); ok {
|
||||
current, _ = strconv.ParseFloat(raw, 64)
|
||||
}
|
||||
current += e.Cost
|
||||
ctx.SetState("usage-budget:total-cost", strconv.FormatFloat(current, 'f', 6, 64))
|
||||
|
||||
// Soft warn-at threshold (configurable via state).
|
||||
warnAt := 0.50
|
||||
if raw, ok := ctx.GetState(warnAtKey); ok {
|
||||
if v, err := strconv.ParseFloat(raw, 64); err == nil {
|
||||
warnAt = v
|
||||
}
|
||||
}
|
||||
if current > warnAt {
|
||||
ctx.PrintError(fmt.Sprintf(
|
||||
"[usage] session cost $%.4f exceeds soft cap $%.2f",
|
||||
current, warnAt,
|
||||
))
|
||||
}
|
||||
})
|
||||
|
||||
// 3. Print a per-turn summary using the enriched AgentEndEvent.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
ctx.Print(fmt.Sprintf(
|
||||
"[turn] stop=%s tools=%d llm-calls=%d tokens=↑%d ↓%d cost=$%.4f duration=%dms",
|
||||
e.StopReason, e.ToolCallCount, e.LLMCallCount,
|
||||
e.InputTokensDelta, e.OutputTokensDelta, e.CostDelta, e.DurationMs,
|
||||
))
|
||||
if len(e.ToolNames) > 0 {
|
||||
ctx.Print(fmt.Sprintf("[turn] tool order: %v", e.ToolNames))
|
||||
}
|
||||
})
|
||||
|
||||
// Bootstrap default soft cap once per session.
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
if _, ok := ctx.GetState(warnAtKey); !ok {
|
||||
ctx.SetState(warnAtKey, "0.50")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -42,4 +42,14 @@ defer host.Close()
|
||||
response, err := host.Prompt(ctx, "Hello!")
|
||||
```
|
||||
|
||||
Or use the functional-options constructor for quick setups (streaming defaults on):
|
||||
|
||||
```go
|
||||
host, err := kit.NewAgent(ctx,
|
||||
kit.WithModel("anthropic/claude-sonnet-4-5-20250929"),
|
||||
kit.WithSystemPrompt("You are a helpful assistant."),
|
||||
kit.Ephemeral(),
|
||||
)
|
||||
```
|
||||
|
||||
See the [SDK README](../../pkg/kit/README.md) for the full API reference.
|
||||
|
||||
@@ -1,32 +1,34 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.2
|
||||
go 1.26.3
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.6
|
||||
charm.land/fantasy v0.23.0
|
||||
charm.land/bubbletea/v2 v2.0.7
|
||||
charm.land/fantasy v0.25.0
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.24.1
|
||||
github.com/alecthomas/chroma/v2 v2.26.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/colorprofile v0.4.3
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.12.2
|
||||
github.com/coder/acp-go-sdk v0.13.5
|
||||
github.com/fsnotify/fsnotify v1.10.1
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.51.0
|
||||
github.com/mark3labs/mcp-go v0.54.1
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.42.0
|
||||
golang.org/x/image v0.41.0
|
||||
golang.org/x/term v0.43.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -37,39 +39,39 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 // indirect
|
||||
github.com/aws/smithy-go v1.25.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 // indirect
|
||||
github.com/aws/smithy-go v1.26.0 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.3 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/dlclark/regexp2 v1.12.0 // indirect
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -79,13 +81,13 @@ require (
|
||||
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.7 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.5 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.13 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
@@ -97,7 +99,7 @@ require (
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/gjson v1.19.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
@@ -105,21 +107,21 @@ require (
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 // indirect
|
||||
go.opentelemetry.io/otel v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.44.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.44.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/crypto v0.52.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect
|
||||
golang.org/x/net v0.55.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.277.0 // indirect
|
||||
google.golang.org/genai v1.55.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 // indirect
|
||||
google.golang.org/grpc v1.81.0 // indirect
|
||||
google.golang.org/api v0.282.0 // indirect
|
||||
google.golang.org/genai v1.58.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
|
||||
google.golang.org/grpc v1.81.1 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
@@ -131,12 +133,12 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.24 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
golang.org/x/sys v0.45.0 // indirect
|
||||
golang.org/x/text v0.37.0
|
||||
)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
cel.dev/expr v0.25.2/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.6 h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo=
|
||||
charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g=
|
||||
charm.land/fantasy v0.23.0 h1:pocjwC5CxfEg1Bpwb0raML2d5ijo3op33Mmd6hYJyo4=
|
||||
charm.land/fantasy v0.23.0/go.mod h1:4yzSsd9XmFEVjRnF1P0LTEbLTmQX6OLnPkrHaf7iruo=
|
||||
charm.land/bubbletea/v2 v2.0.7 h1:7qw2tTAVar7m7klOPBYfTB0mniv/RuexsYwMRNxSeL0=
|
||||
charm.land/bubbletea/v2 v2.0.7/go.mod h1:DGW2q8gvzHnOpMpZTORs0aySVHCox5C+2Svk0fci1qs=
|
||||
charm.land/fantasy v0.25.0 h1:oXOWY1ivmTSnhYGzAolscF8zKtavWZyBWv0LHRSwN5Q=
|
||||
charm.land/fantasy v0.25.0/go.mod h1:8QrWUzIcKwZQP+aAnC9vLu3iID6hu9/Jt+rPMiieBkc=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
|
||||
charm.land/x/vcr v0.1.1/go.mod h1:eByq2gqzWvcct/8XE2XO5KznoWEBiXH56+y2gphbltM=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
@@ -16,6 +18,11 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
cloud.google.com/go/iam v1.11.0/go.mod h1:KP+nKGugNJW4LcLx1uEZcq1ok5sQHFaQehQNl4QDgV4=
|
||||
cloud.google.com/go/longrunning v0.5.6/go.mod h1:vUaDrWYOMKRuhiv6JBnn49YxCPz2Ayn9GqyjaBT8/mA=
|
||||
cloud.google.com/go/monitoring v1.29.0/go.mod h1:72NOVjJXHY/HBfoLT0+qlCZBT059+9VXLeAnL2PeeVM=
|
||||
cloud.google.com/go/storage v1.62.1/go.mod h1:cpYz/kRVZ+UQAF1uHeea10/9ewcRbxGoGNKsS9daSXA=
|
||||
cloud.google.com/go/translate v1.10.3/go.mod h1:GW0vC1qvPtd3pgtypCv4k4U8B7EdgK9/QEF2aJEUovs=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1/go.mod h1:pzBXCYn05zvYIrwLgtK8Ap8QcjRg+0i76tMQdWN6wOk=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
@@ -24,52 +31,68 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6Xu
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.32.0/go.mod h1:RD2SsorTmYhF6HkTmDw7KmPYQk8OBYwTkuasChwv7R4=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.56.0/go.mod h1:hEpiGU18xf70qb3jbTcIggWAiEfX/cOIVc2OTe4OegA=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.56.0/go.mod h1:6ZZMQhZKDvUvkJw2rc+oDP90tMMzuU/J+5HG1ZmPOmE=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/Rhymond/go-money v1.0.15/go.mod h1:iHvCuIvitxu2JIlAlhF0g9jHqjRSr+rpdOs7Omqlupg=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1 h1:m5ffpfZbIb++k8AqFEKy9uVgY12xIQtBsQlc6DfZJQM=
|
||||
github.com/alecthomas/chroma/v2 v2.24.1/go.mod h1:l+ohZ9xRXIbGe7cIW+YZgOGbvuVLjMps/FYN/CwuabI=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1 h1:2X21EdxGZNv5GF9mG5u+uzc02GCFyGxbcBm3Grd9A78=
|
||||
github.com/alecthomas/chroma/v2 v2.26.1/go.mod h1:lxhRRa9H4hPmRLOOdYga4zkQIQjq3dtrrdwQeCfu78Y=
|
||||
github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs=
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/ardanlabs/jinja v1.2.0/go.mod h1:aXXzlJfjA+T3XNKA/YT5ZtDq2VJxt5a5siZ8cl9B35Q=
|
||||
github.com/ardanlabs/kronk v1.25.2/go.mod h1:b5Gg4jDqvHDklkeHNB8+7treZRxUiCFsV65zphrTloY=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7 h1:DWpAJt66FmnnaRIOT/8ASTucrvuDPZASqhhLey6tLY8=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.7/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8 h1:sRs7nG6/RiEBZ/K5UO2sNw0w40U02Nmz1VtARloTZXk=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.8/go.mod h1:4LAfZOPHNVNQEckOACQx60Y8pSRjIkNZQz1w92xpMJc=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10 h1:gx1AwW1Iyk9Z9dD9F4akX5gnN3QZwUB20GGKH/I+Rho=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.10/go.mod h1:qqY157uZoqm5OXq/amuaBJyC9hgBCBQnsaWnPe905GY=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17 h1:FpL4/758/diKwqbytU0prpuiu60fgXKUWCpDJtApclU=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.17/go.mod h1:OXqUMzgXytfoF9JaKkhrOYsyh72t9G+MJH8mMRaexOE=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16 h1:r3RJBuU7X9ibt8RHbMjWE6y60QbKBiII6wSrXnapxSU=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.16/go.mod h1:6cx7zqDENJDbBIIWX6P8s0h6hqHC8Avbjh9Dseo27ug=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23 h1:UuSfcORqNSz/ey3VPRS8TcVH2Ikf0/sC+Hdj400QI6U=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.23/go.mod h1:+G/OSGiOFnSOkYloKj/9M35s74LgVAdJBSD5lsFfqKg=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23 h1:GpT/TrnBYuE5gan2cZbTtvP+JlHsutdmlV2YfEyNde0=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.23/go.mod h1:xYWD6BS9ywC5bS3sz9Xh04whO/hzK2plt2Zkyrp4JuA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23 h1:bpd8vxhlQi2r1hiueOw02f/duEPTMK59Q4QMAoTTtTo=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.23/go.mod h1:15DfR2nw+CRHIk0tqNyifu3G1YdAOy68RftkhMDDwYk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24 h1:OQqn11BtaYv1WLUowvcA30MpzIu8Ti4pcLPIIyoKZrA=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.24/go.mod h1:X5ZJyfwVrWA96GzPmUCWFQaEARPR7gCrpq2E92PJwAE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19 h1:qRhIJMbevHUvIE7X4TK8N8zye5+5AhapcslPrvB+qKE=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.19/go.mod h1:RbJ24nfoya63+Mf5VI+CGCGk9vEdv28xPeii+gojRYs=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18 h1:GcXQz2M/0ZvMo0v5DakUqbDBeBM1ZNaivkolEF4Esgw=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.18/go.mod h1:sHJ06tMGcD3ZpmMyJqV+VBsGilhSIZPIN+ZFy5Dg0C4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24 h1:FQm5ApnyzkuJdXLGskPce83CK1CQKC4RUnIHKVe4BU4=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.24/go.mod h1:JsC7dqQc55MlZ5mvNsDMMge71u8pVcSzU3RNz2h/5yQ=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24 h1:u6kJU2i0va1AgtJsH3RdWKWqHULlTh7zHwb35Womf74=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.24/go.mod h1:7GY+xLcXOFUpCkNwDReft9qOAVg54A4/AnjHIU7sSAY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24 h1:Xhbcf3KugX6vX7SDyUK205Oicyfg7EGuvoVNyP5L6DM=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.24/go.mod h1:rwDgb2HNOGZsnTHylOUedM7Vnl+bCfnXDqUNPsFWYfk=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25 h1:54CTMmlJ71Rk2dYvM9qZOob+39wjlVja2zDLxCu69Ew=
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.25/go.mod h1:BZaHqxsS9vN1fvV5EfEl0OBLOk5+AajWsMu6MjqnZB4=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9 h1:FLudkZLt5ci0ozzgkVo8BJGwvqNaZbTWb3UcucAateA=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.9/go.mod h1:w7wZ/s9qK7c8g4al+UyoF1Sp/Z45UwMGcqIzLWVQHWk=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23 h1:pbrxO/kuIwgEsOPLkaHu0O+m4fNgLU8B3vxQ+72jTPw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.23/go.mod h1:/CMNUqoj46HpS3MNRDEDIwcgEnrtZlKRaHNaHxIFpNA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11 h1:TdJ+HdzOBhU8+iVAOGUTU63VXopcumCOF1paFulHWZc=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.11/go.mod h1:R82ZRExE/nheo0N+T8zHPcLRTcH8MGsnR3BiVGX0TwI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17 h1:7byT8HUWrgoRp6sXjxtZwgOKfhss5fW6SkLBtqzgRoE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.17/go.mod h1:xNWknVi4Ezm1vg1QsB/5EWpAJURq22uqd38U8qKvOJc=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21 h1:+1Kl1zx6bWi4X7cKi3VYh29h8BvsCoHQEQ6ST9X8w7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.21/go.mod h1:4vIRDq+CJB2xFAXZ+YgGUTiEft7oAQlhIs71xcSeuVg=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1 h1:F/M5Y9I3nwr2IEpshZgh1GeHpOItExNM9L1euNuh/fk=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.1/go.mod h1:mTNxImtovCOEEuD65mKW7DCsL+2gjEH+RPEAexAzAio=
|
||||
github.com/aws/smithy-go v1.25.1 h1:J8ERsGSU7d+aCmdQur5Txg6bVoYelvQJgtZehD12GkI=
|
||||
github.com/aws/smithy-go v1.25.1/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.15/go.mod h1:e3IzZvQ3kAWNykvE0Tr0RDZCMFInMvhku3qNpcIQXhM=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24 h1:CQW2FTrflfoslYWLf3fv7vG28Q219+v8YJS5QTQb2+Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.24/go.mod h1:Xfx13T+u3nH6EEzgl9fBSO6nDRmze1FvnZNYkctQ2zw=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.23/go.mod h1:M8l3mwgx5ToK7wot2sBBce/ojzgnPzZXUV445gTSyE8=
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.101.0/go.mod h1:L2dcoOgS2VSgbPLvpak2NyUPsO1TBN7M45Z4H7DlRc4=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0 h1:yQo3eZ5qFaL1sJWqs1nL6j3yPHA2/R7c6tQ4T+0IO10=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.1.0/go.mod h1:3Zzou41Qt/ueXfIzHvTEjDNuR5IjCUBVF01SNhrt1e8=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18 h1:ApLTFdAZfDhZSiY5uskwECKHkSNNF83y2Ru2r7SezWA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.18/go.mod h1:A9K9qx2l6nK89hp+a350FdGfRkrkH5HdiEjHbiy/Q/c=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1 h1:4VD7TIZOGzehrgQ8vDE+1c6BQW4ErZPGY8ohZT5LXEE=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.36.1/go.mod h1:er0SFJfdV89Rit5hIJu/EXtv+qC2XMnxoksLmcUFkqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2 h1:XKnxlM4KZH1gktcsh3zSWc7GW4KivEv/OkifmHOhCUY=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.42.2/go.mod h1:KJYmkQaFB3SUW2j3aBkPsxNmAb4ZsSOvbvCpuxzHJA0=
|
||||
github.com/aws/smithy-go v1.26.0 h1:9ouqbi+NyKP7fV3Te7UElCwdAb6Y8uk7LGwPE5tVe/s=
|
||||
github.com/aws/smithy-go v1.26.0/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bgentry/go-netrc v0.0.0-20140422174119-9fd32a8b3d3d/go.mod h1:6QX/PXZ00z/TKoufEY6K/a0k6AhaJrQKdFe6OfVXsa4=
|
||||
github.com/bits-and-blooms/bitset v1.24.4/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab h1:J7XQLgl9sefgTnTGrmX3xqvp5o6MCiBzEjGv5igAlc4=
|
||||
@@ -86,8 +109,8 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be h1:j7w8VP/D4lu5+/4GamMmFy8nrtadcl82/fjvDgSHwLo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260428153724-66037269d7be/go.mod h1:3YdTxlnV/L0bQ3VN8WOSw8doF7LZV/xawUQ4MuAPDvo=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f h1:vKsPSlO4g4jKfJ9enESgNZ45BkbHngTIq3UxNOzic74=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260601155805-6cf7526a1b3f/go.mod h1:hFpumms29Smx3LStRfku8vcCTBe1Kq8aCXtHUJa3mjY=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
@@ -98,14 +121,14 @@ github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIR
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310 h1:rByFKh9JgQScu7oy0+TlUbC2e93woW/QNZmNXbbbw/E=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260503005035-c113ba3d2310/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a h1:aVvnksCVgxB2igk7jERL9ARIkbDXccp1gXCFqhGlamQ=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260602025833-85a30b5e440a/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310 h1:PMjHdSo8Vpq9psUw9BoHo9JLPMkm9Hqb+Whk64n3AQQ=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260503005035-c113ba3d2310/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d h1:RxcAR+vJCoD8QqT1cqLtkQKw+1cqvjqnu5IpPqYzPco=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260527151214-009e6338d40d/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -120,12 +143,13 @@ github.com/charmbracelet/x/xpty v0.1.3 h1:eGSitii4suhzrISYH50ZfufV3v085BXQwIytcO
|
||||
github.com/charmbracelet/x/xpty v0.1.3/go.mod h1:poPYpWuLDBFCKmKLDnhBp51ATa0ooD8FhypRwEFtH3Y=
|
||||
github.com/clipperhouse/displaywidth v0.11.0 h1:lBc6kY44VFw+TDx4I8opi/EtL9m20WSEFgwIwO+UVM8=
|
||||
github.com/clipperhouse/displaywidth v0.11.0/go.mod h1:bkrFNkf81G8HyVqmKGxsPufD3JhNl3dSqnGhOoSD/o0=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 h1:+gs4oBZ2gPfVrKPthwbMzWZDaAFPGYK72F0NJv2v7Vk=
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2 h1:aBangftG7EVZoUb69Os8IaYg++6uMOdKK83QtkkvJik=
|
||||
github.com/cncf/xds/go v0.0.0-20260202195803-dba9d589def2/go.mod h1:qwXFYgsP6T7XnJtbKlf1HP8AjxZZyzxMmc+Lq5GjlU4=
|
||||
github.com/coder/acp-go-sdk v0.12.2 h1:fpRJ8Z5HMSr5cZ5IywzFlFZcIxZOsto+laNVu7XelFA=
|
||||
github.com/coder/acp-go-sdk v0.12.2/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/coder/acp-go-sdk v0.13.5 h1:LI9jq5xon7xslaYlnoktvTVyDlE37yIk2daT7N9ASYk=
|
||||
github.com/coder/acp-go-sdk v0.13.5/go.mod h1:yKzM/3R9uELp4+nBAwwtkS0aN1FOFjo11CNPy37yFko=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
@@ -133,13 +157,20 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.12.0 h1:0j4c5qQmnC6XOWNjP3PIXURXN2gWx76rd3KvgdPkCz8=
|
||||
github.com/dlclark/regexp2 v1.12.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1 h1:LCUGyd9Wf+r+VVOl8Ny38JTpWJcAsdVnCIuhhtthmKw=
|
||||
github.com/dlclark/regexp2/v2 v2.1.1/go.mod h1:avUrQvPaLz2DrFNHJF0taWAFFX2C1GMSSoeiqFjcBmU=
|
||||
github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI=
|
||||
github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ=
|
||||
github.com/dromara/carbon/v2 v2.6.16/go.mod h1:NGo3reeV5vhWCYWcSqbJRZm46MEwyfYI5EJRdVFoLJo=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/ebitengine/purego v0.10.0/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||
github.com/eliben/go-sentencepiece v0.6.0/go.mod h1:nNYk4aMzgBoI6QFp4LUG8Eu1uO9fHD9L5ZEre93o9+c=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA=
|
||||
github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0 h1:u3riX6BoYRfF4Dr7dwSOroNfdSbEPe9Yyl09/B6wBrQ=
|
||||
github.com/envoyproxy/go-control-plane/envoy v1.37.0/go.mod h1:DReE9MMrmecPy+YvQOAOHNYMALuowAnbjjEMkkWOi6A=
|
||||
github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3 h1:MVQghNeW+LZcmXe7SY1V36Z+WFMDjpqGAGacLe2T0ds=
|
||||
github.com/envoyproxy/protoc-gen-validate v1.3.3/go.mod h1:TsndJ/ngyIdQRhMcVVGDDHINPLWB7C82oDArY51KfB0=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
@@ -148,8 +179,9 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4 h1:2WmHkJINIjgXXYDGik8d3oJvFA3DAwPy00csDJ3vo+o=
|
||||
github.com/go-json-experiment/json v0.0.0-20260430182902-b6187a392ed4/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686 h1:NZBJxCpbHS1gzS6xAmyxbJznosZIIPk9IB42v62UvKA=
|
||||
github.com/go-json-experiment/json v0.0.0-20260520185125-572e7c383686/go.mod h1:tphK2c80bpPhMOI4v6bIc2xWywPfbqi1Z06+RcrMkDg=
|
||||
github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE=
|
||||
github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
@@ -163,38 +195,53 @@ github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-pkcs11 v0.3.0/go.mod h1:6eQoGcuNJpa7jnd5pMGdkSaQpNDYvPlXWMcjXXThLlY=
|
||||
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0=
|
||||
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
|
||||
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15 h1:xolVQTEXusUcAA5UgtyRLjelpFFHWlPQ4XfWGc7MBas=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.15/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16 h1:F/VPrx0YPBdksZJQdCAp0WUsqnNmZpUZszzfYt0M5Dw=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.16/go.mod h1:9Yb0eAkH/Xqhvv3zbeKf/+wMJqCeocWc6KIhDvEAuYE=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0 h1:PjIWBpgGIVKGoCXuiCoP64altEJCj3/Ei+kSU5vlZD4=
|
||||
github.com/googleapis/gax-go/v2 v2.22.0/go.mod h1:irWBbALSr0Sk3qlqb9SyJ1h68WjgeFuiOzI4Rqw5+aY=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0/go.mod h1:Hyl3n6Twe1hvtd9XUXDec4pTvgMSEixRuQKPTMH2bNs=
|
||||
github.com/hashicorp/aws-sdk-go-base/v2 v2.0.0-beta.72/go.mod h1:Vn+BBgKQHVQYdVQ4NZDICE1Brb+JfaONyDHr3q07oQc=
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
|
||||
github.com/hashicorp/go-getter v1.8.6/go.mod h1:nVH12eOV2P58dIiL3rsU6Fh3wLeJEKBOJzhMmzlSWoo=
|
||||
github.com/hashicorp/go-version v1.9.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/hybridgroup/yzma v1.13.0/go.mod h1:zrzMgv/KVQz23+s6l16b+vJ+9uJVBdWtGcGkwRTMeiQ=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.7 h1:apjIIZHnGRyrkiX3vHj07F1BF6D0JLmV+VGSr1781Jc=
|
||||
github.com/kaptinlin/go-i18n v0.4.7/go.mod h1:+i1J0pFq/9i9ESC5qRMVkKwC+mdQTABhhBExpYOlbeM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21 h1:WVkwQbeerbHFcoXG7Yo/mlQhhZjWiTnagECEfwDXXa0=
|
||||
github.com/kaptinlin/jsonpointer v0.4.21/go.mod h1:Mo7+DX8RlQTFqS4dnYJl0izSP4ob+Rl5xO/mGDETgaU=
|
||||
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/jupiterrider/ffi v0.7.0/go.mod h1:9dauhpOfNqrqk28fxuu0kkdeFtT9Qr4vbfigiuIXN7c=
|
||||
github.com/kaptinlin/go-i18n v0.4.5 h1:9tIlo5A0RXth+yZJO2MG7Bhpu/X9PlzQnGz/qyYWNoY=
|
||||
github.com/kaptinlin/go-i18n v0.4.5/go.mod h1:mU/7BH4molY5lGZYBwBRKAaiJ70dWRHuqmQ0/pFLGno=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25 h1:iJ197e8n+WwqaqBsa53FqG3rPJCg5oijyFXEXNWWC3E=
|
||||
github.com/kaptinlin/jsonpointer v0.4.25/go.mod h1:wVOBaXGGnP42YsMb6zev/3W5POTvspdNfh8DXzf8XS8=
|
||||
github.com/kaptinlin/jsonschema v0.7.13 h1:kahVXTy/rURL0XJjyQ9WELm59wEmXi6IY0TWswQEFvU=
|
||||
github.com/kaptinlin/jsonschema v0.7.13/go.mod h1:Uh0aUBusnhXDCEXJ2oimL/hx7YTo7F+sKniE+tM0ERc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3 h1:m9ZE/fCjnsk8bdkv7Qs56L/ZoHbmQqhz9mRZSAQLU5g=
|
||||
github.com/kaptinlin/messageformat-go v0.6.3/go.mod h1:2KOZ/hgo/SveZ+uyi7vPUpUXieX65Mppzbc3VpGyqKs=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0 h1:D6jiXFsKW4/JG2CMddv/F6Rev9KVbCRKEzzV5QOAcpc=
|
||||
github.com/kaptinlin/messageformat-go v0.6.0/go.mod h1:NKjwS6e9u7DRhAK+vydjDDwJ7UbdHhYjk/yk2WPuZPs=
|
||||
github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -203,12 +250,14 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.51.0 h1:e8AhEfxzcYt7XqYzwT7uzWNhnqpu3H1Tn7dEJB9Ygj8=
|
||||
github.com/mark3labs/mcp-go v0.51.0/go.mod h1:Zg9cB2HdwdMMVgY0xtTzq3KvYIOJQDsaut+jWjwDaQY=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mark3labs/mcp-go v0.54.1 h1:Ap/ptEB9FtWzFKM8NDsTA7QDxerQOC06eZigrTldVj0=
|
||||
github.com/mark3labs/mcp-go v0.54.1/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
|
||||
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mattn/go-runewidth v0.0.24 h1:cpokDiIn0MGnhdHwuWnJBITySJ20QyNGnY2kR/ay2DU=
|
||||
github.com/mattn/go-runewidth v0.0.24/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -223,6 +272,7 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
@@ -231,6 +281,10 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgm
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -238,8 +292,10 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -251,13 +307,14 @@ github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
@@ -268,67 +325,87 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E=
|
||||
github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY=
|
||||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4=
|
||||
go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.opentelemetry.io/contrib/detectors/gcp v1.43.0/go.mod h1:RyaZMFY7yi1kAs45S6mbFGz8O8rqB0dTY14uzvG4LCs=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0 h1:2yEATaop1/a1I4psnSLgWVPLWwCzkqWakgJy7xTDVy0=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.69.0/go.mod h1:D7J12YRapIekYyPWgGPlA/23pRmpSEZC5xJC/TTLI9U=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0 h1:8tvICD4vSTOOsNrsI4Ljf6C+6UKvpTEH5XY3JMoyPoo=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.69.0/go.mod h1:z9+yiacE0IHRqM4qFfkbt/JYlmYXgss8GY/jXoNuPJI=
|
||||
go.opentelemetry.io/otel v1.44.0 h1:JjwHmHpA4iZ3wBxluu2fbbE7j4kqlE8jXyAyPXH7HqU=
|
||||
go.opentelemetry.io/otel v1.44.0/go.mod h1:BMgjTHL9WPRlRjL2oZCBTL4whCGtXch2H4BhOPIAyYc=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
|
||||
go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSYCdDdmc=
|
||||
go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0 h1:nHYwb9lK+fJPU/dnT6s7W7Z8itMWyqrnVfbheVYrZ58=
|
||||
go.opentelemetry.io/otel/sdk v1.44.0/go.mod h1:Osuydd3Se74nqjAKxid74N5eC+jfEqfTegHRnq58oK0=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0 h1:3LlKgI+VjbVsjNRFZJZAJ30WjXC5VkNRks6si09iEfI=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.44.0/go.mod h1:5B5pMARnXxKhltooO4xUuCBorl65a4EpnTalObqOigA=
|
||||
go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk=
|
||||
go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE=
|
||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
go.yaml.in/yaml/v4 v4.0.0-rc.3/go.mod h1:aZqd9kCMsGL7AuUv/m/PvWLdg5sjJsZ4oHDEnfPPfY0=
|
||||
golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988=
|
||||
golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc=
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b h1:v1uXiEBHo8QA0LiGCo7UgHMzHT4Kdfpl2zmtH5vaP1Q=
|
||||
golang.org/x/exp v0.0.0-20260603202125-055de637280b/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
|
||||
golang.org/x/image v0.41.0 h1:8wS72eGJMJaBxK6okTzd4WaXumUlTVlb753MlsSvTCo=
|
||||
golang.org/x/image v0.41.0/go.mod h1:uIc348UZMSvS5Z65CVZ7iDPaNobNFEPeJ4kbqTOszmA=
|
||||
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY=
|
||||
golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
|
||||
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0=
|
||||
golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
|
||||
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.277.0 h1:HJfyJUiNeBBUMai7ez8u14wkp/gH/I4wpGbbO9o+cSk=
|
||||
google.golang.org/api v0.277.0/go.mod h1:B9TqLBwJqVjp1mtt7WeoQwWRwvu/400y5lETOql+giQ=
|
||||
google.golang.org/genai v1.55.0 h1:iLHGk4Bj/IZ/GNNZb7hYqwSJMRBvqLeu2Hb6YQ+rYGw=
|
||||
google.golang.org/genai v1.55.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4 h1:2iMJZntwvmfgtse+s744JY7v7PgEdSBuFYXucvpOHNM=
|
||||
google.golang.org/genproto v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:v14kaaboYyXQ1Gsu489Q+Hg/oN4B33mWtuOhF1HCeXA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4 h1:yOzSCGPx+cp5VO7IxvZ9SBFF7j1tZVcNtlHR2iYKtVo=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:Q9HWtNeE7tM9npdIsEvqXj1QJIvVoeAV3rtXtS715Cw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4 h1:tEkOQcXgF6dH1G+MVKZrfpYvozGrzb91k6ha7jireSM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260427160629-7cedc36a6bc4/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.0 h1:W3G9N3KQf3BU+YuCtGKJk0CmxQNbAISICD/9AORxLIw=
|
||||
google.golang.org/grpc v1.81.0/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/api v0.282.0 h1:WmJiSVqUnKqJCpJOx7YADbXaC+9DDsnGSfllFSj7R2I=
|
||||
google.golang.org/api v0.282.0/go.mod h1:6Wssta4c5n9qHq5CBhmlai5h/PUa1djdDAIhYEHyvcM=
|
||||
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
|
||||
google.golang.org/genai v1.58.0 h1:MNA3ZkRyr7MnRwZ9RNZ60p4+UMKV3yYRw6pyHq4pp0U=
|
||||
google.golang.org/genai v1.58.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348 h1:JjVGDZYWkJWZcxveJGzfkXC5myDVWAd4dZdgbzrDUv8=
|
||||
google.golang.org/genproto v0.0.0-20260504160031-60b97b32f348/go.mod h1:95PqD4xM+AdOcBGsmgfaofXsiA37uXDtDufVbntT3TU=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348 h1:U8orV30l6KpDsi9dxU0CoJZGbjS8EEpw+6ba+XwGPQA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260504160031-60b97b32f348/go.mod h1:Yzdzr5OOZFgSsEV2D/Xi9NL3bszpXFAg0hFJiRohcD8=
|
||||
google.golang.org/genproto/googleapis/bytestream v0.0.0-20260523011958-0a33c5d7ca68/go.mod h1:6TABGosqSqU2l1+fJ3jdvOYPPVryeKybxYF0cCZkTBE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/dnaeon/go-vcr.v4 v4.0.6-0.20251110073552-01de4eb40290/go.mod h1:sbq5oMEcM4PXngbcNbHhzfCP9OdZodLhrbRYoyg09HY=
|
||||
gopkg.in/ini.v1 v1.67.1/go.mod h1:x/cyOwCgZqOkJoDIJ3c1KNHMo10+nLGAhh+kn3Zizss=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -61,6 +61,12 @@ func (a *Agent) Authenticate(_ context.Context, _ acp.AuthenticateRequest) (acp.
|
||||
return acp.AuthenticateResponse{}, nil
|
||||
}
|
||||
|
||||
// Logout handles logout requests. Kit doesn't require auth for local stdio
|
||||
// usage, so this is a no-op.
|
||||
func (a *Agent) Logout(_ context.Context, _ acp.LogoutRequest) (acp.LogoutResponse, error) {
|
||||
return acp.LogoutResponse{}, nil
|
||||
}
|
||||
|
||||
// Initialize negotiates capabilities with the ACP client.
|
||||
func (a *Agent) Initialize(_ context.Context, params acp.InitializeRequest) (acp.InitializeResponse, error) {
|
||||
log.Debug("acp: initialize", "protocol_version", params.ProtocolVersion)
|
||||
|
||||
+72
-168
@@ -7,7 +7,9 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extbridge"
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -37,10 +39,21 @@ func newSessionRegistry() *sessionRegistry {
|
||||
// given working directory. The Kit-generated session ID is used as the ACP
|
||||
// session ID so the mapping is 1:1.
|
||||
func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession, error) {
|
||||
// Each ACP session gets its own isolated config store (CLI is left nil) so
|
||||
// per-session SetModel / SetThinkingLevel calls cannot race or bleed across
|
||||
// the sessionRegistry. We seed the relevant root-command flag values from
|
||||
// the process-global store (which cobra populated from flags) so launching
|
||||
// `kit acp -m <model> [--thinking-level ...] [--provider-url ...]` is still
|
||||
// honored; .kit.yml and KIT_* env vars are loaded per session by kit.New.
|
||||
streamOn := true
|
||||
kitInstance, err := kit.New(ctx, &kit.Options{
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: true,
|
||||
SessionDir: cwd,
|
||||
Quiet: true,
|
||||
Streaming: &streamOn,
|
||||
Model: viper.GetString("model"),
|
||||
ThinkingLevel: viper.GetString("thinking-level"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
})
|
||||
if err != nil {
|
||||
// Provide actionable guidance for provider auth errors, which are
|
||||
@@ -60,142 +73,70 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Wire extension context with headless implementations so extensions
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
// become no-ops or return cancelled; all data/model/tool APIs come from
|
||||
// extbridge.BaseContext and work identically to interactive mode.
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
Interactive: false,
|
||||
// Use a background context for subagent spawns: the create() ctx is
|
||||
// request-scoped and may be cancelled before extensions spawn anything.
|
||||
ec := extbridge.BaseContext(context.Background(), kitInstance)
|
||||
|
||||
// Output — route through structured logger.
|
||||
Print: func(text string) { log.Debug("extension: print", "text", text) },
|
||||
PrintInfo: func(text string) { log.Info("extension: info", "text", text) },
|
||||
PrintError: func(text string) { log.Error("extension: error", "text", text) },
|
||||
PrintBlock: func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
},
|
||||
ec.SessionID = sessionID
|
||||
ec.CWD = cwd
|
||||
ec.Model = kitInstance.GetModelString()
|
||||
ec.Interactive = false
|
||||
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
SendMessage: func(string) {},
|
||||
CancelAndSend: func(string) {},
|
||||
Exit: func() {},
|
||||
// Output — route through structured logger.
|
||||
ec.Print = func(text string) { log.Debug("extension: print", "text", text) }
|
||||
ec.PrintInfo = func(text string) { log.Info("extension: info", "text", text) }
|
||||
ec.PrintError = func(text string) { log.Error("extension: error", "text", text) }
|
||||
ec.PrintBlock = func(opts extensions.PrintBlockOpts) {
|
||||
log.Info("extension: block", "subtitle", opts.Subtitle, "text", opts.Text)
|
||||
}
|
||||
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
SetWidget: func(extensions.WidgetConfig) {},
|
||||
RemoveWidget: func(string) {},
|
||||
SetHeader: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveHeader: func() {},
|
||||
SetFooter: func(extensions.HeaderFooterConfig) {},
|
||||
RemoveFooter: func() {},
|
||||
SetEditor: func(extensions.EditorConfig) {},
|
||||
ResetEditor: func() {},
|
||||
SetEditorText: func(string) {},
|
||||
SetUIVisibility: func(extensions.UIVisibility) {},
|
||||
SetStatus: func(string, string, int) {},
|
||||
RemoveStatus: func(string) {},
|
||||
// Message injection — no-ops for now; ACP clients drive prompts.
|
||||
ec.SendMessage = func(string) {}
|
||||
ec.CancelAndSend = func(string) {}
|
||||
ec.Exit = func() {}
|
||||
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
PromptSelect: func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
},
|
||||
PromptConfirm: func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
},
|
||||
PromptInput: func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
},
|
||||
ShowOverlay: func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
},
|
||||
SuspendTUI: func(callback func()) error { callback(); return nil },
|
||||
// TUI widgets/chrome — silent no-ops (no TUI in ACP).
|
||||
ec.SetWidget = func(extensions.WidgetConfig) {}
|
||||
ec.RemoveWidget = func(string) {}
|
||||
ec.SetHeader = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveHeader = func() {}
|
||||
ec.SetFooter = func(extensions.HeaderFooterConfig) {}
|
||||
ec.RemoveFooter = func() {}
|
||||
ec.SetEditor = func(extensions.EditorConfig) {}
|
||||
ec.ResetEditor = func() {}
|
||||
ec.SetEditorText = func(string) {}
|
||||
ec.SetUIVisibility = func(extensions.UIVisibility) {}
|
||||
ec.SetStatus = func(string, string, int) {}
|
||||
ec.RemoveStatus = func(string) {}
|
||||
|
||||
// Data access — delegate to Kit instance.
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
// Interactive prompts — return cancelled (no user to prompt).
|
||||
ec.PromptSelect = func(extensions.PromptSelectConfig) extensions.PromptSelectResult {
|
||||
return extensions.PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptConfirm = func(extensions.PromptConfirmConfig) extensions.PromptConfirmResult {
|
||||
return extensions.PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
ec.PromptInput = func(extensions.PromptInputConfig) extensions.PromptInputResult {
|
||||
return extensions.PromptInputResult{Cancelled: true}
|
||||
}
|
||||
ec.ShowOverlay = func(extensions.OverlayConfig) extensions.OverlayResult {
|
||||
return extensions.OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
ec.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
// Render — fall back to logging.
|
||||
ec.RenderMessage = func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
}
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: config.Prompt,
|
||||
Model: config.Model,
|
||||
SystemPrompt: config.SystemPrompt,
|
||||
Timeout: config.Timeout,
|
||||
NoSession: config.NoSession,
|
||||
}
|
||||
if config.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := sdkEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
config.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := kitInstance.Subagent(context.Background(), sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
},
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.Extensions().SetContext(ec)
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
@@ -269,40 +210,3 @@ func (s *acpSession) clearCancel() {
|
||||
defer s.cancelMu.Unlock()
|
||||
s.cancelFn = nil
|
||||
}
|
||||
|
||||
// sdkEventToSubagentEvent converts an SDK event to an extension SubagentEvent.
|
||||
func sdkEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
+47
-46
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -168,9 +169,9 @@ type RetryHandler func(attempt int, err error)
|
||||
type PrepareStepHandler func(stepNumber int, messages []fantasy.Message) []fantasy.Message
|
||||
|
||||
// GenerateCallbacks consolidates all callback functions for
|
||||
// GenerateWithLoopAndStreaming into a single struct. This replaces the previous
|
||||
// 16+ positional callback parameters, making it easier to add new callbacks
|
||||
// without breaking existing callers (new fields default to nil).
|
||||
// GenerateWithCallbacks into a single struct, replacing what was previously
|
||||
// 16+ positional callback parameters. New fields default to nil, so adding
|
||||
// new callbacks does not break existing callers.
|
||||
type GenerateCallbacks struct {
|
||||
OnToolCall ToolCallHandler
|
||||
OnToolExecution ToolExecutionHandler
|
||||
@@ -245,6 +246,12 @@ type Agent struct {
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
|
||||
// promptMu serializes runtime updates to systemPrompt and the
|
||||
// accompanying fantasy agent rebuild so concurrent SetSystemPrompt
|
||||
// callers (e.g. Kit.applyComposedSystemPrompt invoked from multiple
|
||||
// goroutines) don't race on a.systemPrompt / a.fantasyAgent.
|
||||
promptMu sync.Mutex
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -515,44 +522,6 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally.
|
||||
//
|
||||
// Deprecated: Use GenerateWithCallbacks instead, which takes a GenerateCallbacks
|
||||
// struct and is easier to extend with new callbacks.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
onPasswordPrompt PasswordPromptHandler,
|
||||
onToolCallStart ToolCallStartHandler,
|
||||
onToolCallDelta ToolCallDeltaHandler,
|
||||
onToolCallEnd ToolCallEndHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithCallbacks(ctx, messages, GenerateCallbacks{
|
||||
OnToolCall: onToolCall,
|
||||
OnToolExecution: onToolExecution,
|
||||
OnToolResult: onToolResult,
|
||||
OnResponse: onResponse,
|
||||
OnToolCallContent: onToolCallContent,
|
||||
OnStreamingResponse: onStreamingResponse,
|
||||
OnReasoningDelta: onReasoningDelta,
|
||||
OnReasoningComplete: onReasoningComplete,
|
||||
OnToolOutput: onToolOutput,
|
||||
OnStepMessages: onStepMessages,
|
||||
OnStepUsage: onStepUsage,
|
||||
OnPasswordPrompt: onPasswordPrompt,
|
||||
OnToolCallStart: onToolCallStart,
|
||||
OnToolCallDelta: onToolCallDelta,
|
||||
OnToolCallEnd: onToolCallEnd,
|
||||
})
|
||||
}
|
||||
|
||||
// GenerateWithCallbacks processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
@@ -585,8 +554,13 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
// Track tool call args per-ToolCallID so parallel tool calls in a single
|
||||
// step don't clobber each other. Without this, OnToolResult callbacks would
|
||||
// all see the args of the last OnToolCall in the step. The mutex guards
|
||||
// against the possibility that the underlying streaming layer dispatches
|
||||
// callbacks from multiple goroutines.
|
||||
toolCallArgs := make(map[string]string)
|
||||
var toolCallArgsMu sync.Mutex
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
@@ -773,7 +747,9 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolArgs = tc.Input
|
||||
toolCallArgsMu.Lock()
|
||||
toolCallArgs[tc.ToolCallID] = tc.Input
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify about the tool call
|
||||
if cb.OnToolCall != nil {
|
||||
@@ -793,15 +769,22 @@ func (a *Agent) GenerateWithCallbacks(ctx context.Context, messages []fantasy.Me
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
// Look up the args recorded for this specific tool call. Delete
|
||||
// the entry so the map doesn't accumulate across steps.
|
||||
toolCallArgsMu.Lock()
|
||||
args := toolCallArgs[tr.ToolCallID]
|
||||
delete(toolCallArgs, tr.ToolCallID)
|
||||
toolCallArgsMu.Unlock()
|
||||
|
||||
// Notify tool execution finished
|
||||
if cb.OnToolExecution != nil {
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, currentToolArgs, false)
|
||||
cb.OnToolExecution(tr.ToolCallID, tr.ToolName, args, false)
|
||||
}
|
||||
|
||||
if cb.OnToolResult != nil {
|
||||
// Extract result text and error status
|
||||
resultText, isError := extractToolResultText(tr)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, currentToolArgs, resultText, tr.ClientMetadata, isError)
|
||||
cb.OnToolResult(tr.ToolCallID, tr.ToolName, args, resultText, tr.ClientMetadata, isError)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1303,6 +1286,24 @@ func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// SetSystemPrompt updates the agent's system prompt and rebuilds the underlying
|
||||
// fantasy agent so subsequent turns use the new prompt. Safe to call while the
|
||||
// agent is idle; if invoked during an in-flight turn the new prompt takes
|
||||
// effect on the next LLM call.
|
||||
func (a *Agent) SetSystemPrompt(prompt string) {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
a.systemPrompt = prompt
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// GetSystemPrompt returns the agent's current system prompt.
|
||||
func (a *Agent) GetSystemPrompt() string {
|
||||
a.promptMu.Lock()
|
||||
defer a.promptMu.Unlock()
|
||||
return a.systemPrompt
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the effective max output tokens the agent currently
|
||||
// sends to the LLM provider, after per-model defaults, right-sizing, and any
|
||||
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// fakeParallelAgent simulates a provider that emits two parallel tool_use
|
||||
// blocks in a single step. It invokes the streaming callbacks in the order:
|
||||
//
|
||||
// OnToolCall(A) -> OnToolCall(B) -> OnToolResult(A) -> OnToolResult(B)
|
||||
//
|
||||
// Before the fix in #33 the agent-layer wrapper recorded a single
|
||||
// `currentToolArgs` variable that was clobbered by the second OnToolCall, so
|
||||
// both OnToolResult callbacks received B's args instead of their own.
|
||||
type fakeParallelAgent struct {
|
||||
calls []fantasy.ToolCallContent
|
||||
results []fantasy.ToolResultContent
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Generate(_ context.Context, _ fantasy.AgentCall) (*fantasy.AgentResult, error) {
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
func (f *fakeParallelAgent) Stream(_ context.Context, opts fantasy.AgentStreamCall) (*fantasy.AgentResult, error) {
|
||||
for _, tc := range f.calls {
|
||||
if opts.OnToolCall != nil {
|
||||
if err := opts.OnToolCall(tc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tr := range f.results {
|
||||
if opts.OnToolResult != nil {
|
||||
if err := opts.OnToolResult(tr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return &fantasy.AgentResult{}, nil
|
||||
}
|
||||
|
||||
// TestGenerateWithCallbacks_ParallelToolArgs is the regression test for #33.
|
||||
// It drives the streaming-callback wiring inside GenerateWithCallbacks with a
|
||||
// fake fantasy.Agent that emits two parallel tool calls before either result.
|
||||
// Each OnToolResult must receive the args of its own tool call (matched by
|
||||
// ToolCallID), not the args of the last OnToolCall in the step.
|
||||
func TestGenerateWithCallbacks_ParallelToolArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
argsA := `{"name":"scheduled_jobs"}`
|
||||
argsB := `{"name":"gmail_trigger"}`
|
||||
|
||||
fake := &fakeParallelAgent{
|
||||
calls: []fantasy.ToolCallContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Input: argsA},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Input: argsB},
|
||||
},
|
||||
results: []fantasy.ToolResultContent{
|
||||
{ToolCallID: "kit-A", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-A"}},
|
||||
{ToolCallID: "kit-B", ToolName: "load_skill", Result: fantasy.ToolResultOutputContentText{Text: "ok-B"}},
|
||||
},
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fake,
|
||||
streamingEnabled: false, // exercise the "hasCallbacks" branch
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
resultArgs := map[string]string{}
|
||||
executionArgs := map[string]string{} // captured when running == false
|
||||
|
||||
cb := GenerateCallbacks{
|
||||
OnToolExecution: func(id, _, args string, running bool) {
|
||||
if running {
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
executionArgs[id] = args
|
||||
},
|
||||
OnToolResult: func(id, _, args, _, _ string, _ bool) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
resultArgs[id] = args
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := a.GenerateWithCallbacks(context.Background(), nil, cb); err != nil {
|
||||
t.Fatalf("GenerateWithCallbacks returned error: %v", err)
|
||||
}
|
||||
|
||||
if got, want := resultArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolResult for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := resultArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolResult for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-A"], argsA; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-A: args = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := executionArgs["kit-B"], argsB; got != want {
|
||||
t.Errorf("OnToolExecution(finish) for kit-B: args = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -9,12 +9,19 @@ import (
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpExecutor is the subset of *tools.MCPToolManager that the adapter
|
||||
// actually uses. Extracted as an interface so the adapter is unit-testable
|
||||
// without constructing a full manager + connection pool.
|
||||
type mcpExecutor interface {
|
||||
ExecuteTool(ctx context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error)
|
||||
}
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
manager *tools.MCPToolManager
|
||||
exec mcpExecutor
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
@@ -29,10 +36,26 @@ func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
//
|
||||
// MCP-side failures (JSON-RPC protocol errors, transport failures, schema
|
||||
// validation rejections from the server) are surfaced to the model as soft
|
||||
// tool errors rather than escalated to a critical agent error. This matches
|
||||
// the contract that native Kit tools follow via kit.ErrorResult(...) and
|
||||
// lets the model self-correct (e.g. retry with a fixed argument shape) or
|
||||
// give up gracefully rather than aborting the turn mid-run.
|
||||
//
|
||||
// Context cancellation is the one exception: if the caller cancelled the
|
||||
// context the turn was aborted intentionally, so we propagate the ctx error
|
||||
// to let the agent loop unwind cleanly.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
result, err := t.exec.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
return fantasy.ToolResponse{}, ctxErr
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(
|
||||
fmt.Sprintf("MCP tool %q failed: %s", t.tool.Name, err.Error()),
|
||||
), nil
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
@@ -57,8 +80,8 @@ func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManage
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
manager: manager,
|
||||
tool: t,
|
||||
exec: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// stubExecutor lets each test script the (result, err) pair returned by
|
||||
// ExecuteTool. The adapter holds an mcpExecutor interface, so this is the
|
||||
// only seam the tests need.
|
||||
type stubExecutor struct {
|
||||
result *tools.MCPToolResult
|
||||
err error
|
||||
// called records the last invocation for assertion.
|
||||
called bool
|
||||
name string
|
||||
input string
|
||||
}
|
||||
|
||||
func (s *stubExecutor) ExecuteTool(_ context.Context, prefixedName, inputJSON string) (*tools.MCPToolResult, error) {
|
||||
s.called = true
|
||||
s.name = prefixedName
|
||||
s.input = inputJSON
|
||||
return s.result, s.err
|
||||
}
|
||||
|
||||
func newMCPAgentTool(exec mcpExecutor, name string) *mcpAgentTool {
|
||||
return &mcpAgentTool{
|
||||
tool: tools.MCPTool{Name: name},
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager-side Go errors (JSON-RPC protocol errors, transport failures,
|
||||
// schema validation rejections from the MCP server) must be surfaced to
|
||||
// the model as soft tool errors so the agent loop can keep going. Aborting
|
||||
// the turn would discard all prior tool results — see issue #N.
|
||||
func TestMCPAgentTool_RPCErrorBecomesSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
err: errors.New("MCP error -32602: Invalid params: missing field \"task\""),
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{
|
||||
ID: "call-1",
|
||||
Name: "pubmed__search",
|
||||
Input: `{"query":"foo"}`,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft), got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "pubmed__search") {
|
||||
t.Errorf("expected tool name in error content, got %q", resp.Content)
|
||||
}
|
||||
if !strings.Contains(resp.Content, "-32602") {
|
||||
t.Errorf("expected underlying error text in content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Context cancellation is the one error that must remain critical: it
|
||||
// means the caller intentionally aborted, and the agent loop needs to
|
||||
// unwind cleanly rather than burning more steps.
|
||||
func TestMCPAgentTool_CtxCancelStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
// Real managers typically return ctx.Err() (or a wrapper) when the
|
||||
// context is cancelled mid-call.
|
||||
err: context.Canceled,
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Deadline-exceeded behaves the same as cancellation: ctx.Err() is
|
||||
// non-nil, so the adapter must propagate the critical error rather than
|
||||
// converting the executor's error into a soft response.
|
||||
func TestMCPAgentTool_CtxDeadlineStaysCritical(t *testing.T) {
|
||||
exec := &stubExecutor{err: context.DeadlineExceeded}
|
||||
tool := newMCPAgentTool(exec, "slow__tool")
|
||||
|
||||
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Second))
|
||||
defer cancel()
|
||||
|
||||
resp, err := tool.Run(ctx, fantasy.ToolCall{Name: "slow__tool"})
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
if resp.IsError || resp.Content != "" {
|
||||
t.Errorf("expected empty response on critical error, got IsError=%v Content=%q", resp.IsError, resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side soft errors (CallToolResult{ isError: true }) must continue
|
||||
// to flow through as soft errors — this was the existing behavior and
|
||||
// must not regress.
|
||||
func TestMCPAgentTool_ServerIsErrorRemainsSoftError(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: true,
|
||||
Content: "search service is rate limited; try again in 30s",
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error, got %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Fatalf("expected IsError=true, got false")
|
||||
}
|
||||
if resp.Content != "search service is rate limited; try again in 30s" {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Happy path: ordinary successful tool result is passed through unchanged.
|
||||
func TestMCPAgentTool_SuccessIsPassthrough(t *testing.T) {
|
||||
exec := &stubExecutor{
|
||||
result: &tools.MCPToolResult{
|
||||
IsError: false,
|
||||
Content: `{"hits":3}`,
|
||||
},
|
||||
}
|
||||
tool := newMCPAgentTool(exec, "pubmed__search")
|
||||
|
||||
resp, err := tool.Run(context.Background(), fantasy.ToolCall{Name: "pubmed__search"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("expected IsError=false")
|
||||
}
|
||||
if resp.Content != `{"hits":3}` {
|
||||
t.Errorf("expected pass-through content, got %q", resp.Content)
|
||||
}
|
||||
}
|
||||
+217
-33
@@ -13,6 +13,7 @@ import (
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
@@ -70,14 +71,24 @@ type App struct {
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
// widgetUpdatePending is set to true while a WidgetUpdateEvent burst is
|
||||
// being coalesced. The leading edge fires immediately; subsequent calls
|
||||
// within the debounce window set widgetUpdateTrailing so a final event
|
||||
// is delivered with the latest runner state at the end of the window.
|
||||
// Without the trailing send, a rapid SetWidget→RemoveWidget pair (e.g.
|
||||
// SubagentEnd pushing a final frame then removing the widget) would let
|
||||
// the second call get silently dropped, leaving the TUI's layout stuck
|
||||
// on the pre-removal widget height — visible as empty rows below the
|
||||
// status bar after the widget disappears.
|
||||
widgetUpdatePending atomic.Bool
|
||||
widgetUpdateTrailing atomic.Bool
|
||||
|
||||
// steerDrainFn is the test seam used by releaseBusyAfterCompact to pull
|
||||
// any steer messages that arrived during compaction. In production it is
|
||||
// nil and the helper falls back to a.opts.Kit.DrainSteer(); tests that
|
||||
// need to exercise the steer-drain path without standing up a full
|
||||
// *kit.Kit can set this field directly to inject fake items.
|
||||
steerDrainFn func() []queueItem
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
@@ -333,6 +344,90 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
}
|
||||
}
|
||||
|
||||
// PopLastUserMessage truncates the tree session back to the parent of the
|
||||
// most recent user message on the current branch, syncs the in-memory
|
||||
// message store, and returns the user prompt text plus any image file
|
||||
// parts so the caller can resubmit via Run/RunWithFiles.
|
||||
//
|
||||
// This is the building block for /retry: the user message and any orphaned
|
||||
// assistant/tool entries produced by a failed turn become unreachable on
|
||||
// the current branch (they remain in the session file under a different
|
||||
// leaf) and are excluded from the next LLM context.
|
||||
//
|
||||
// Returns an error when:
|
||||
// - the agent is currently working (busy)
|
||||
// - the app has been closed
|
||||
// - no tree session is active (sessions disabled via --no-session)
|
||||
// - no user message exists on the current branch
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return "", nil, fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return "", nil, fmt.Errorf("cannot retry while the agent is working")
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
ts := a.opts.TreeSession
|
||||
if ts == nil {
|
||||
return "", nil, fmt.Errorf("no tree session active; /retry requires a session")
|
||||
}
|
||||
|
||||
// Walk the current branch backwards to find the most recent user message.
|
||||
branch := ts.GetBranch("")
|
||||
var target *session.MessageEntry
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
me, ok := branch[i].(*session.MessageEntry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if me.Role == string(message.RoleUser) {
|
||||
target = me
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return "", nil, fmt.Errorf("no user message to retry")
|
||||
}
|
||||
|
||||
// Extract the prompt text and any image parts from the target entry.
|
||||
msg, err := target.ToMessage()
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("decode user message: %w", err)
|
||||
}
|
||||
prompt := msg.Content()
|
||||
var files []kit.LLMFilePart
|
||||
for _, part := range msg.Parts {
|
||||
if ic, ok := part.(message.ImageContent); ok {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Data: ic.Data,
|
||||
MediaType: ic.MediaType,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Move the leaf to the parent of the user message. The failed turn's
|
||||
// entries (user message + any partial assistant/tool entries) are still
|
||||
// in the tree file but no longer on the active branch, so they will not
|
||||
// be re-sent to the LLM. runTurn() will append a fresh user message on
|
||||
// the next call.
|
||||
if err := ts.Branch(target.ParentID); err != nil {
|
||||
return "", nil, fmt.Errorf("branch to parent: %w", err)
|
||||
}
|
||||
|
||||
// Sync the in-memory store with the new branch position so subsequent
|
||||
// reads (and ReloadMessagesFromTree() consumers) see the truncated view.
|
||||
a.store.Clear()
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
|
||||
return prompt, files, nil
|
||||
}
|
||||
|
||||
// AddContextMessage adds a user-role message to the conversation history
|
||||
// without triggering an LLM response. Used by the ! shell command prefix
|
||||
// to inject command output into context so the LLM can reference it in
|
||||
@@ -356,6 +451,10 @@ func (a *App) AddContextMessage(text string) {
|
||||
// tea.Program. customInstructions is optional text appended to the summary
|
||||
// prompt (e.g. "Focus on the API design decisions").
|
||||
//
|
||||
// Any prompts queued via Run/RunWithFiles or steering messages injected via
|
||||
// Steer/SteerWithFiles while compaction is running are flushed automatically
|
||||
// once compaction completes (see releaseBusyAfterCompact).
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Lock()
|
||||
@@ -377,11 +476,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -420,6 +515,9 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
//
|
||||
// Like CompactConversation, any prompts/steer messages received during
|
||||
// compaction are flushed automatically once compaction finishes.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
@@ -440,11 +538,7 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
defer a.releaseBusyAfterCompact()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
@@ -489,6 +583,81 @@ func (a *App) CompactAsync(customInstructions string, onComplete func(), onError
|
||||
return nil
|
||||
}
|
||||
|
||||
// releaseBusyAfterCompact is the deferred tail that runs at the end of every
|
||||
// compaction goroutine (success, error, or panic-after-recover paths). It
|
||||
// flips a.busy back to false, but before doing so it checks whether any
|
||||
// prompts piled up while compaction was running:
|
||||
//
|
||||
// - Run/RunWithFiles append to a.queue when a.busy is set.
|
||||
// - Steer/SteerWithFiles deposit messages into the SDK steer channel via
|
||||
// Kit.InjectSteerWithFiles when a.busy is set.
|
||||
//
|
||||
// Without this hand-off the queue would sit idle until the user submits
|
||||
// another prompt — see issue #27. If we find anything pending we keep busy
|
||||
// set, splice the steer messages to the front of the queue, and start a
|
||||
// fresh drainQueue goroutine to deliver them as a single batched turn.
|
||||
func (a *App) releaseBusyAfterCompact() {
|
||||
// Pull steer messages outside the app mutex; DrainSteer takes its own
|
||||
// internal lock and we don't want to nest the two. The test seam
|
||||
// (a.steerDrainFn) takes precedence so unit tests can inject fake
|
||||
// steer items without a real *kit.Kit.
|
||||
var steerItems []queueItem
|
||||
switch {
|
||||
case a.steerDrainFn != nil:
|
||||
steerItems = a.steerDrainFn()
|
||||
case a.opts.Kit != nil:
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
steerItems = make([]queueItem, len(leftover))
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.mu.Lock()
|
||||
// If the app was closed while compaction was running, drop everything
|
||||
// and just clear busy. Run/Steer would have rejected new items already
|
||||
// after Close(), but this guards against in-flight items that slipped
|
||||
// in just before closed was set.
|
||||
if a.closed {
|
||||
a.queue = a.queue[:0]
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Combine steer-channel items (front) with the in-memory queue (back).
|
||||
// Steer messages are placed first so they retain their "act now"
|
||||
// semantics relative to ordinary queued prompts that arrived later.
|
||||
pending := append(steerItems, a.queue...)
|
||||
a.queue = a.queue[:0]
|
||||
|
||||
if len(pending) == 0 {
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hand off to drainQueue: it will pick up the first item directly and
|
||||
// scoop the rest from a.queue on its first iteration.
|
||||
first := pending[0]
|
||||
if len(pending) > 1 {
|
||||
a.queue = append(a.queue, pending[1:]...)
|
||||
}
|
||||
// Stay busy across the goroutine swap.
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Notify the UI that steer-channel messages were consumed so the
|
||||
// steering badge can clear; ordinary queued prompts will be reflected
|
||||
// by the QueueUpdatedEvent that drainQueue emits as it picks them up.
|
||||
if len(steerItems) > 0 {
|
||||
a.sendEvent(SteerConsumedEvent{})
|
||||
}
|
||||
|
||||
go a.drainQueue(first)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -1076,32 +1245,47 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
// Coalescing (leading + trailing edge): the first call in an idle period
|
||||
// fires immediately for responsiveness. Subsequent calls within a ~16 ms
|
||||
// debounce window are batched into a single trailing event delivered at
|
||||
// the end of the window. The trailing send is essential for correctness:
|
||||
// extensions routinely make tight SetWidget→RemoveWidget pairs (e.g. on
|
||||
// SubagentEnd) and silently dropping the second call would leave the TUI's
|
||||
// layout stuck on stale widget dimensions until some other event happens
|
||||
// to trigger a re-render.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
// A leading-edge event is already in flight — mark that the runner
|
||||
// state has changed again so the trailing send below picks it up.
|
||||
a.widgetUpdateTrailing.Store(true)
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
if prog == nil {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
return
|
||||
}
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
// If any extra calls came in during the debounce window, deliver
|
||||
// one trailing event so the TUI sees the latest widget state. We
|
||||
// swap-and-test instead of plain-load so concurrent calls after
|
||||
// the trailing send still race correctly with the pending reset.
|
||||
if a.widgetUpdateTrailing.Swap(false) {
|
||||
a.mu.Lock()
|
||||
p := a.program
|
||||
a.mu.Unlock()
|
||||
if p != nil {
|
||||
p.Send(WidgetUpdateEvent{})
|
||||
}
|
||||
}
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
|
||||
@@ -9,7 +9,10 @@ import (
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -763,3 +766,352 @@ func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// releaseBusyAfterCompact (issue #27)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestReleaseBusyAfterCompact_flushesQueuedMessages is a regression test for
|
||||
// issue #27: messages queued via Run() while /compact is running used to sit
|
||||
// in a.queue indefinitely until the user typed another prompt. After the fix
|
||||
// the deferred releaseBusyAfterCompact tail picks up any pending items and
|
||||
// dispatches drainQueue automatically.
|
||||
//
|
||||
// We simulate the compaction completion path directly (bypassing the SDK)
|
||||
// by toggling busy=true, populating the queue exactly as Run() would have
|
||||
// during compaction, and then invoking releaseBusyAfterCompact.
|
||||
func TestReleaseBusyAfterCompact_flushesQueuedMessages(t *testing.T) {
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("compacted then drained"), nil
|
||||
},
|
||||
)
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate the state at the start of the compaction tail: busy is set
|
||||
// and a couple of prompts have piled up in the queue while we were
|
||||
// summarising. (Run() would have appended them and returned a queue
|
||||
// length > 0 to the caller.)
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued during compact #1"},
|
||||
queueItem{Prompt: "queued during compact #2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
// Invoke the deferred tail directly. It should kick off drainQueue.
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// drainQueue runs in a goroutine. Wait for the app to come back to idle.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after releaseBusyAfterCompact: queue not drained")
|
||||
}
|
||||
|
||||
// Wait for any in-flight goroutine to finish before reading state.
|
||||
app.wg.Wait()
|
||||
|
||||
if got := app.QueueLength(); got != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d", got)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatalf("expected stub PromptFunc to fire at least once after compact, got %d calls", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_idleWhenQueueEmpty verifies that with no
|
||||
// pending messages the helper just clears busy and does NOT spawn a
|
||||
// drainQueue goroutine (no spurious agent turn).
|
||||
func TestReleaseBusyAfterCompact_idleWhenQueueEmpty(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false after releaseBusyAfterCompact with empty queue")
|
||||
}
|
||||
|
||||
// Give any rogue goroutine a moment to (incorrectly) call PromptFunc.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls when queue empty, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue exercises the SDK
|
||||
// steer-drain branch of releaseBusyAfterCompact (issue #27 follow-up).
|
||||
//
|
||||
// Production wires a.opts.Kit.DrainSteer() to pull messages that arrived via
|
||||
// Steer/SteerWithFiles during compaction, but Options.Kit is *kit.Kit (a
|
||||
// concrete struct) so unit tests cannot stand up a real instance without a
|
||||
// full LLM backend. The test uses the unexported steerDrainFn seam to inject
|
||||
// fake steer items, then asserts that:
|
||||
//
|
||||
// - Steer items are dispatched ahead of any prompts that piled up in
|
||||
// a.queue (steer retains "act now" priority over ordinary queued
|
||||
// prompts), and
|
||||
// - the helper still hands off to drainQueue so the steer item actually
|
||||
// fires (the previous behaviour left them stranded — see #27).
|
||||
func TestReleaseBusyAfterCompact_splicesSteerAheadOfQueue(t *testing.T) {
|
||||
var pmu sync.Mutex
|
||||
var firstPrompt string
|
||||
stub := newStubWithFuncs(
|
||||
func(ctx context.Context) (*kit.TurnResult, error) {
|
||||
return turnResult("steer dispatched"), nil
|
||||
},
|
||||
)
|
||||
// Wrap PromptFunc so we can capture the prompt text the stub receives
|
||||
// (newStubWithFuncs's fns ignore prompt; we need it to verify ordering).
|
||||
capturingPrompt := func(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
pmu.Lock()
|
||||
if firstPrompt == "" {
|
||||
firstPrompt = prompt
|
||||
}
|
||||
pmu.Unlock()
|
||||
return stub.fn(ctx, prompt)
|
||||
}
|
||||
app := New(Options{PromptFunc: capturingPrompt}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Inject fake steer items via the test seam. In production the same
|
||||
// items would have been delivered through Kit.InjectSteerWithFiles
|
||||
// during /compact and pulled by DrainSteer here.
|
||||
app.steerDrainFn = func() []queueItem {
|
||||
return []queueItem{
|
||||
{Prompt: "steer-1"},
|
||||
{Prompt: "steer-2"},
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate the state at the end of compaction: busy is set and a couple
|
||||
// of regular Run() prompts have piled up after the steer messages.
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue,
|
||||
queueItem{Prompt: "queued-1"},
|
||||
queueItem{Prompt: "queued-2"},
|
||||
)
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
// Wait for the dispatched batch to complete.
|
||||
ok := waitForCondition(2*time.Second, func() bool {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return !app.busy
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("app did not become idle after steer-spliced releaseBusyAfterCompact")
|
||||
}
|
||||
app.wg.Wait()
|
||||
|
||||
// drainQueue picks up `first` directly and batches the rest. With
|
||||
// PromptFunc set, executeBatch invokes us with items[0] only — that
|
||||
// item must be the first steer message, proving steer items were
|
||||
// spliced ahead of the previously queued prompts.
|
||||
pmu.Lock()
|
||||
got := firstPrompt
|
||||
pmu.Unlock()
|
||||
if got != "steer-1" {
|
||||
t.Fatalf("expected first dispatched prompt to be steer item %q (steer items must come before queued prompts), got %q",
|
||||
"steer-1", got)
|
||||
}
|
||||
|
||||
// Queue should be fully drained and PromptFunc must have actually fired.
|
||||
if n := app.QueueLength(); n != 0 {
|
||||
t.Fatalf("expected empty queue after drain, got %d entries", n)
|
||||
}
|
||||
if n := stub.callCount(); n == 0 {
|
||||
t.Fatal("expected stub PromptFunc to fire at least once after splice")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReleaseBusyAfterCompact_dropsQueueWhenClosed verifies that if the app
|
||||
// was closed during compaction the helper discards any pending items rather
|
||||
// than spawning drainQueue against a torn-down App.
|
||||
func TestReleaseBusyAfterCompact_dropsQueueWhenClosed(t *testing.T) {
|
||||
stub := newStub()
|
||||
app := newTestApp(stub)
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.queue = append(app.queue, queueItem{Prompt: "would have run"})
|
||||
app.closed = true
|
||||
app.mu.Unlock()
|
||||
|
||||
app.releaseBusyAfterCompact()
|
||||
|
||||
app.mu.Lock()
|
||||
busy := app.busy
|
||||
qLen := len(app.queue)
|
||||
app.mu.Unlock()
|
||||
if busy {
|
||||
t.Fatal("expected busy=false even when closed")
|
||||
}
|
||||
if qLen != 0 {
|
||||
t.Fatalf("expected queue cleared on closed app, got %d entries", qLen)
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if n := stub.callCount(); n != 0 {
|
||||
t.Fatalf("expected 0 PromptFunc calls on closed app, got %d", n)
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// PopLastUserMessage (/retry building block)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// TestPopLastUserMessage_NoTreeSession verifies that PopLastUserMessage
|
||||
// returns an error when no tree session is active.
|
||||
func TestPopLastUserMessage_NoTreeSession(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
defer app.Close()
|
||||
|
||||
prompt, files, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no tree session is active")
|
||||
}
|
||||
if prompt != "" || files != nil {
|
||||
t.Fatalf("expected zero values on error, got prompt=%q files=%v", prompt, files)
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_WhileBusy verifies that PopLastUserMessage
|
||||
// refuses to truncate while the agent is busy (would race with executeBatch).
|
||||
func TestPopLastUserMessage_WhileBusy(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
defer app.Close()
|
||||
|
||||
app.mu.Lock()
|
||||
app.busy = true
|
||||
app.mu.Unlock()
|
||||
|
||||
_, _, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when agent is busy")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "working") {
|
||||
t.Fatalf("expected error mentioning busy/working, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_WhenClosed verifies that PopLastUserMessage
|
||||
// returns an error after Close().
|
||||
func TestPopLastUserMessage_WhenClosed(t *testing.T) {
|
||||
app := newTestApp(newStub())
|
||||
app.Close()
|
||||
|
||||
_, _, err := app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error on closed app")
|
||||
}
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_TruncatesAndReturnsPrompt verifies the happy path:
|
||||
// a real tree session with user→assistant→user→assistant entries is
|
||||
// truncated back to before the most recent user message, and that user's
|
||||
// text is returned.
|
||||
func TestPopLastUserMessage_TruncatesAndReturnsPrompt(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ts, err := session.CreateTreeSession(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("create tree session: %v", err)
|
||||
}
|
||||
defer func() { _ = ts.Close() }()
|
||||
|
||||
// Build history: user "first" → assistant "ack 1" → user "second" → assistant "ack 2".
|
||||
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("first")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 1"}},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.NewUserMessage("second")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := ts.AppendLLMMessage(fantasy.Message{
|
||||
Role: fantasy.MessageRoleAssistant,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "ack 2"}},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
|
||||
defer app.Close()
|
||||
|
||||
prompt, files, err := app.PopLastUserMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("PopLastUserMessage: %v", err)
|
||||
}
|
||||
if prompt != "second" {
|
||||
t.Fatalf("expected prompt=%q, got %q", "second", prompt)
|
||||
}
|
||||
if files != nil {
|
||||
t.Fatalf("expected no files, got %v", files)
|
||||
}
|
||||
|
||||
// After truncation the branch should only contain the first user
|
||||
// message and its assistant response (the "second" turn is orphaned).
|
||||
msgs := ts.GetLLMMessages()
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("expected 2 messages on truncated branch, got %d", len(msgs))
|
||||
}
|
||||
if got := messageText(msgs[0]); got != "first" {
|
||||
t.Fatalf("expected first message %q, got %q", "first", got)
|
||||
}
|
||||
if got := messageText(msgs[1]); got != "ack 1" {
|
||||
t.Fatalf("expected second message %q, got %q", "ack 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
// messageText extracts concatenated TextPart content from a fantasy.Message.
|
||||
func messageText(m fantasy.Message) string {
|
||||
var out strings.Builder
|
||||
for _, p := range m.Content {
|
||||
if tp, ok := p.(fantasy.TextPart); ok {
|
||||
out.WriteString(tp.Text)
|
||||
}
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
// TestPopLastUserMessage_NoUserOnBranch verifies that an empty tree (no
|
||||
// user messages at all) returns a friendly error rather than panicking.
|
||||
func TestPopLastUserMessage_NoUserOnBranch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
ts, err := session.CreateTreeSession(dir)
|
||||
if err != nil {
|
||||
t.Fatalf("create tree session: %v", err)
|
||||
}
|
||||
defer func() { _ = ts.Close() }()
|
||||
|
||||
app := New(Options{TreeSession: ts, PromptFunc: newStub().fn}, nil)
|
||||
defer app.Close()
|
||||
|
||||
_, _, err = app.PopLastUserMessage()
|
||||
if err == nil {
|
||||
t.Fatal("expected error when no user message exists on branch")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no user message") {
|
||||
t.Fatalf("expected error mentioning missing user message, got %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,11 +13,6 @@ type MessageStore struct {
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
func NewMessageStore() *MessageStore {
|
||||
return &MessageStore{}
|
||||
}
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
|
||||
@@ -29,7 +29,7 @@ func textOf(msg kit.LLMMessage) string {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestNewMessageStore_empty(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
if s == nil {
|
||||
t.Fatal("expected non-nil store")
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestAdd_appendsMessage(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "first"))
|
||||
s.Add(makeTextMsg("assistant", "second"))
|
||||
|
||||
@@ -82,7 +82,7 @@ func TestAdd_appendsMessage(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAdd_preservesOrder(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
texts := []string{"a", "b", "c"}
|
||||
for _, t2 := range texts {
|
||||
s.Add(makeTextMsg("user", t2))
|
||||
@@ -100,7 +100,7 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []kit.LLMMessage{
|
||||
@@ -120,7 +120,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
@@ -137,7 +137,7 @@ func TestReplace_isolatesInput(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestGetAll_returnsCopy(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "hello"))
|
||||
|
||||
got := s.GetAll()
|
||||
@@ -151,7 +151,7 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetAll_emptyStore(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
got := s.GetAll()
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty slice, got %d elements", len(got))
|
||||
@@ -163,7 +163,7 @@ func TestGetAll_emptyStore(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestClear_removesAllMessages(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "a"))
|
||||
s.Add(makeTextMsg("user", "b"))
|
||||
s.Clear()
|
||||
@@ -174,7 +174,7 @@ func TestClear_removesAllMessages(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
s.Add(makeTextMsg("user", "before"))
|
||||
s.Clear()
|
||||
s.Add(makeTextMsg("user", "after"))
|
||||
@@ -193,7 +193,7 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s := NewMessageStoreWithMessages(nil)
|
||||
done := make(chan struct{})
|
||||
|
||||
// Writer goroutine.
|
||||
|
||||
+135
-45
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -9,11 +10,11 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// CredentialStore holds all stored credentials for various providers.
|
||||
// Currently supports Anthropic and OpenAI credentials with both OAuth and API key authentication methods.
|
||||
// CredentialStore holds stored credentials for Anthropic, OpenAI, and GitHub Copilot.
|
||||
type CredentialStore struct {
|
||||
Anthropic *AnthropicCredentials `json:"anthropic,omitempty"`
|
||||
OpenAI *OpenAICredentials `json:"openai,omitempty"`
|
||||
Copilot *CopilotCredentials `json:"copilot,omitempty"`
|
||||
}
|
||||
|
||||
// AnthropicCredentials holds Anthropic API credentials supporting both OAuth
|
||||
@@ -43,6 +44,16 @@ type OpenAICredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// CopilotCredentials holds GitHub OAuth credentials and the short-lived
|
||||
// GitHub Copilot API token derived from them.
|
||||
type CopilotCredentials struct {
|
||||
Type string `json:"type"` // "oauth"
|
||||
GitHubToken string `json:"github_token,omitempty"` // GitHub device-flow OAuth token
|
||||
CopilotAccessToken string `json:"copilot_access_token,omitempty"` // Short-lived Copilot API token
|
||||
ExpiresAt int64 `json:"expires_at,omitempty"` // Copilot token expiry
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
@@ -91,6 +102,16 @@ func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the Copilot API token is expired.
|
||||
func (c *CopilotCredentials) IsExpired() bool {
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh reports whether the Copilot API token should be renewed.
|
||||
func (c *CopilotCredentials) NeedsRefresh() bool {
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
// It manages a JSON file stored in the user's config directory with appropriate
|
||||
// file permissions for security.
|
||||
@@ -222,7 +243,7 @@ func (cm *CredentialManager) RemoveAnthropicCredentials() error {
|
||||
store.Anthropic = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil {
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
@@ -255,29 +276,6 @@ func (cm *CredentialManager) HasAnthropicCredentials() (bool, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetOpenAICredentials stores OpenAI API key credentials. It validates the
|
||||
// API key format before storing. The API key must start with "sk-" and be
|
||||
// at least 20 characters long. Returns an error if the API key is invalid or
|
||||
// if storage fails.
|
||||
func (cm *CredentialManager) SetOpenAICredentials(apiKey string) error {
|
||||
if err := validateOpenAIAPIKey(apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.OpenAI = &OpenAICredentials{
|
||||
Type: "api_key",
|
||||
APIKey: apiKey,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetOpenAICredentials retrieves stored OpenAI credentials. Returns nil if
|
||||
// no credentials are stored. The returned credentials may be either OAuth or API
|
||||
// key type, check the Type field to determine which.
|
||||
@@ -302,7 +300,7 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
store.OpenAI = nil
|
||||
|
||||
// If store is empty, remove the file entirely
|
||||
if store.Anthropic == nil && store.OpenAI == nil {
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
@@ -312,6 +310,104 @@ func (cm *CredentialManager) RemoveOpenAICredentials() error {
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetCopilotCredentials retrieves stored GitHub Copilot credentials.
|
||||
func (cm *CredentialManager) GetCopilotCredentials() (*CopilotCredentials, error) {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.Copilot, nil
|
||||
}
|
||||
|
||||
// RemoveCopilotCredentials removes stored GitHub Copilot credentials.
|
||||
func (cm *CredentialManager) RemoveCopilotCredentials() error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.Copilot = nil
|
||||
|
||||
if store.Anthropic == nil && store.OpenAI == nil && store.Copilot == nil {
|
||||
if err := os.Remove(cm.credentialsPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to remove credentials file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// HasCopilotCredentials checks if valid GitHub Copilot credentials are stored.
|
||||
func (cm *CredentialManager) HasCopilotCredentials() (bool, error) {
|
||||
creds, err := cm.GetCopilotCredentials()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if creds == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return creds.Type == "oauth" && creds.GitHubToken != "", nil
|
||||
}
|
||||
|
||||
// SetCopilotOAuthCredentials stores GitHub Copilot OAuth credentials.
|
||||
func (cm *CredentialManager) SetCopilotOAuthCredentials(creds *CopilotCredentials) error {
|
||||
store, err := cm.LoadCredentials()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.Copilot = creds
|
||||
return cm.SaveCredentials(store)
|
||||
}
|
||||
|
||||
// GetValidCopilotAccessToken returns a fresh Copilot API token, renewing it
|
||||
// with the stored GitHub OAuth token when needed.
|
||||
func (cm *CredentialManager) GetValidCopilotAccessToken() (string, error) {
|
||||
return cm.GetValidCopilotAccessTokenContext(context.Background())
|
||||
}
|
||||
|
||||
// GetValidCopilotAccessTokenContext returns a fresh Copilot API token, renewing
|
||||
// it with the stored GitHub OAuth token when needed.
|
||||
func (cm *CredentialManager) GetValidCopilotAccessTokenContext(ctx context.Context) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
creds, err := cm.GetCopilotCredentials()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if creds == nil {
|
||||
return "", fmt.Errorf("no Copilot credentials found")
|
||||
}
|
||||
if creds.Type != "oauth" {
|
||||
return "", fmt.Errorf("unknown credential type: %s", creds.Type)
|
||||
}
|
||||
if creds.GitHubToken == "" {
|
||||
return "", fmt.Errorf("GitHub OAuth token missing from Copilot credentials")
|
||||
}
|
||||
|
||||
if creds.CopilotAccessToken == "" || creds.NeedsRefresh() {
|
||||
client := NewCopilotOAuthClient()
|
||||
newCreds, err := client.RefreshCopilotToken(ctx, creds.GitHubToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh Copilot token: %w", err)
|
||||
}
|
||||
newCreds.CreatedAt = creds.CreatedAt
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(newCreds); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed Copilot token: %w", err)
|
||||
}
|
||||
|
||||
return newCreds.CopilotAccessToken, nil
|
||||
}
|
||||
|
||||
return creds.CopilotAccessToken, nil
|
||||
}
|
||||
|
||||
// HasOpenAICredentials checks if valid OpenAI credentials are stored.
|
||||
// Returns true if either a non-empty OAuth access token or API key is present,
|
||||
// false otherwise. Returns an error if credentials cannot be loaded.
|
||||
@@ -417,24 +513,18 @@ func validateAnthropicAPIKey(apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateOpenAIAPIKey validates the format of an OpenAI API key
|
||||
func validateOpenAIAPIKey(apiKey string) error {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
// CredentialSourceOAuth is the source description returned by
|
||||
// GetAnthropicAPIKey when the key resolves to stored OAuth credentials.
|
||||
// Consumers should compare against this constant (or use IsAnthropicOAuth)
|
||||
// rather than matching the string literal.
|
||||
const CredentialSourceOAuth = "stored OAuth credentials"
|
||||
|
||||
if apiKey == "" {
|
||||
return fmt.Errorf("API key cannot be empty")
|
||||
}
|
||||
|
||||
// OpenAI API keys typically start with "sk-" and are quite long
|
||||
if !strings.HasPrefix(apiKey, "sk-") {
|
||||
return fmt.Errorf("invalid OpenAI API key format (should start with 'sk-')")
|
||||
}
|
||||
|
||||
if len(apiKey) < 20 {
|
||||
return fmt.Errorf("API key appears to be too short")
|
||||
}
|
||||
|
||||
return nil
|
||||
// IsAnthropicOAuth reports whether the active Anthropic credential resolves
|
||||
// to a stored OAuth token (in which case the user is not billed per-token).
|
||||
// flagValue is the --provider-api-key flag value (may be empty).
|
||||
func IsAnthropicOAuth(flagValue string) bool {
|
||||
_, source, err := GetAnthropicAPIKey(flagValue)
|
||||
return err == nil && source == CredentialSourceOAuth
|
||||
}
|
||||
|
||||
// GetAnthropicAPIKey retrieves an Anthropic API key from multiple sources in priority order:
|
||||
@@ -459,7 +549,7 @@ func GetAnthropicAPIKey(flagValue string) (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
return token, "stored OAuth credentials", nil
|
||||
return token, CredentialSourceOAuth, nil
|
||||
} else if creds.Type == "api_key" && creds.APIKey != "" {
|
||||
return creds.APIKey, "stored API key", nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCredentialManager(t *testing.T) {
|
||||
@@ -215,6 +216,7 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
credentialsPath := filepath.Join(tempDir, "credentials.json")
|
||||
@@ -252,3 +254,98 @@ func TestCredentialStorePersistence(t *testing.T) {
|
||||
t.Errorf("Expected file permissions 0600, got %v", info.Mode().Perm())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotCredentials(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "kit-auth-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
cm := &CredentialManager{
|
||||
credentialsPath: filepath.Join(tempDir, "credentials.json"),
|
||||
}
|
||||
|
||||
creds := &CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: "github-token",
|
||||
CopilotAccessToken: "copilot-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := cm.SetCopilotOAuthCredentials(creds); err != nil {
|
||||
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
hasAuth, err := cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasCopilotCredentials failed: %v", err)
|
||||
}
|
||||
if !hasAuth {
|
||||
t.Fatal("Expected Copilot credentials")
|
||||
}
|
||||
|
||||
token, err := cm.GetValidCopilotAccessToken()
|
||||
if err != nil {
|
||||
t.Fatalf("GetValidCopilotAccessToken failed: %v", err)
|
||||
}
|
||||
if token != creds.CopilotAccessToken {
|
||||
t.Fatalf("Expected Copilot token %q, got %q", creds.CopilotAccessToken, token)
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
|
||||
}
|
||||
hasAuth, err = cm.HasCopilotCredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasCopilotCredentials after removal failed: %v", err)
|
||||
}
|
||||
if hasAuth {
|
||||
t.Fatal("Expected no Copilot credentials after removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveCredentialsPreservesOtherProviders(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "kit-auth-test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(tempDir) }()
|
||||
|
||||
cm := &CredentialManager{
|
||||
credentialsPath: filepath.Join(tempDir, "credentials.json"),
|
||||
}
|
||||
|
||||
if err := cm.SetOpenAIOAuthCredentials(&OpenAICredentials{
|
||||
Type: "oauth",
|
||||
AccessToken: "openai-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
AccountID: "account",
|
||||
CreatedAt: time.Now(),
|
||||
}); err != nil {
|
||||
t.Fatalf("SetOpenAIOAuthCredentials failed: %v", err)
|
||||
}
|
||||
if err := cm.SetCopilotOAuthCredentials(&CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: "github-token",
|
||||
CopilotAccessToken: "copilot-token",
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
CreatedAt: time.Now(),
|
||||
}); err != nil {
|
||||
t.Fatalf("SetCopilotOAuthCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
if err := cm.RemoveCopilotCredentials(); err != nil {
|
||||
t.Fatalf("RemoveCopilotCredentials failed: %v", err)
|
||||
}
|
||||
|
||||
hasOpenAI, err := cm.HasOpenAICredentials()
|
||||
if err != nil {
|
||||
t.Fatalf("HasOpenAICredentials failed: %v", err)
|
||||
}
|
||||
if !hasOpenAI {
|
||||
t.Fatal("Expected OpenAI credentials to remain after removing Copilot credentials")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -211,6 +212,262 @@ type OpenAIOAuthClient struct {
|
||||
Scopes string
|
||||
}
|
||||
|
||||
// CopilotOAuthClient handles GitHub device-flow OAuth and exchanges the
|
||||
// GitHub token for a short-lived GitHub Copilot API token.
|
||||
//
|
||||
// The GitHub token comes from GitHub's OAuth device flow. It is then presented
|
||||
// to GitHub's internal Copilot token endpoint, which returns the bearer token
|
||||
// used by api.githubcopilot.com.
|
||||
type CopilotOAuthClient struct {
|
||||
ClientID string
|
||||
DeviceURL string
|
||||
TokenURL string
|
||||
CopilotURL string
|
||||
Scopes string
|
||||
PollTimeout time.Duration
|
||||
ClientTimeout time.Duration
|
||||
}
|
||||
|
||||
// CopilotDeviceCode contains data returned by GitHub's device-code endpoint.
|
||||
type CopilotDeviceCode struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// NewCopilotOAuthClient creates a GitHub Copilot OAuth client.
|
||||
func NewCopilotOAuthClient() *CopilotOAuthClient {
|
||||
return &CopilotOAuthClient{
|
||||
ClientID: "Iv1.b507a08c87ecfe98",
|
||||
DeviceURL: "https://github.com/login/device/code",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
CopilotURL: "https://api.github.com/copilot_internal/v2/token",
|
||||
Scopes: "read:user",
|
||||
PollTimeout: 15 * time.Minute,
|
||||
ClientTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow requests a GitHub device code for browser login.
|
||||
//
|
||||
// The returned user code and verification URI are displayed by loginCopilot.
|
||||
// GitHub's response may omit interval, so this method normalizes it to the
|
||||
// documented five-second default.
|
||||
func (c *CopilotOAuthClient) StartDeviceFlow(ctx context.Context) (*CopilotDeviceCode, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
"scope": {c.Scopes},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.DeviceURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create device-code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request device code: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("device-code request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var code CopilotDeviceCode
|
||||
if err := json.NewDecoder(resp.Body).Decode(&code); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode device-code response: %w", err)
|
||||
}
|
||||
if code.DeviceCode == "" || code.UserCode == "" || code.VerificationURI == "" {
|
||||
return nil, fmt.Errorf("device-code response missing required fields")
|
||||
}
|
||||
if code.Interval <= 0 {
|
||||
code.Interval = 5
|
||||
}
|
||||
return &code, nil
|
||||
}
|
||||
|
||||
// PollDeviceToken waits until the user authorizes the device code and returns
|
||||
// the resulting GitHub OAuth token.
|
||||
//
|
||||
// It follows GitHub's device-flow polling contract: authorization_pending keeps
|
||||
// polling, slow_down increases the interval, and polling stops at the earlier of
|
||||
// the client timeout or the device-code expiry.
|
||||
func (c *CopilotOAuthClient) PollDeviceToken(ctx context.Context, deviceCode *CopilotDeviceCode) (string, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if deviceCode == nil || deviceCode.DeviceCode == "" {
|
||||
return "", fmt.Errorf("device code missing")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(c.PollTimeout)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
expiresAt := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if expiresAt.Before(deadline) {
|
||||
deadline = expiresAt
|
||||
}
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = 5 * time.Second
|
||||
}
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
wait := interval
|
||||
if remaining := time.Until(deadline); remaining < wait {
|
||||
wait = remaining
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
case <-time.After(wait):
|
||||
}
|
||||
|
||||
data := url.Values{
|
||||
"client_id": {c.ClientID},
|
||||
"device_code": {deviceCode.DeviceCode},
|
||||
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.TokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create device-token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to poll device token: %w", err)
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description"`
|
||||
}
|
||||
decodeErr := json.NewDecoder(resp.Body).Decode(&tokenResp)
|
||||
_ = resp.Body.Close()
|
||||
if decodeErr != nil {
|
||||
return "", fmt.Errorf("failed to decode device-token response: %w", decodeErr)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken != "" {
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
switch tokenResp.Error {
|
||||
case "authorization_pending":
|
||||
continue
|
||||
case "slow_down":
|
||||
interval += 5 * time.Second
|
||||
continue
|
||||
case "expired_token":
|
||||
return "", fmt.Errorf("device code expired; restart login")
|
||||
case "access_denied":
|
||||
return "", fmt.Errorf("github login denied")
|
||||
case "":
|
||||
return "", fmt.Errorf("device-token request failed with status %d", resp.StatusCode)
|
||||
default:
|
||||
if tokenResp.Description != "" {
|
||||
return "", fmt.Errorf("device-token request failed: %s: %s", tokenResp.Error, tokenResp.Description)
|
||||
}
|
||||
return "", fmt.Errorf("device-token request failed: %s", tokenResp.Error)
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("timed out waiting for github device authorization")
|
||||
}
|
||||
|
||||
// ExchangeGitHubToken converts a GitHub OAuth token into a Copilot API token.
|
||||
// It is a semantic wrapper over RefreshCopilotToken used by the login flow.
|
||||
func (c *CopilotOAuthClient) ExchangeGitHubToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
|
||||
return c.RefreshCopilotToken(ctx, githubToken)
|
||||
}
|
||||
|
||||
// RefreshCopilotToken obtains a fresh short-lived Copilot token from GitHub.
|
||||
//
|
||||
// GitHub may return expires_at as either a Unix timestamp or RFC3339 string.
|
||||
// parseCopilotExpiry handles both forms and falls back to a conservative
|
||||
// 20-minute lifetime when the field is absent or unrecognized.
|
||||
func (c *CopilotOAuthClient) RefreshCopilotToken(ctx context.Context, githubToken string) (*CopilotCredentials, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", c.CopilotURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create copilot token request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "token "+githubToken)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "kit")
|
||||
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
|
||||
|
||||
resp, err := (&http.Client{Timeout: c.ClientTimeout}).Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request copilot token: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("copilot token request failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt any `json:"expires_at"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode copilot token response: %w", err)
|
||||
}
|
||||
if tokenResp.Token == "" {
|
||||
return nil, fmt.Errorf("copilot token response missing token")
|
||||
}
|
||||
|
||||
expiresAt := parseCopilotExpiry(tokenResp.ExpiresAt)
|
||||
if expiresAt == 0 {
|
||||
expiresAt = time.Now().Add(20 * time.Minute).Unix()
|
||||
}
|
||||
|
||||
return &CopilotCredentials{
|
||||
Type: "oauth",
|
||||
GitHubToken: githubToken,
|
||||
CopilotAccessToken: tokenResp.Token,
|
||||
ExpiresAt: expiresAt,
|
||||
CreatedAt: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseCopilotExpiry normalizes GitHub's expires_at variants to a Unix second.
|
||||
func parseCopilotExpiry(value any) int64 {
|
||||
switch v := value.(type) {
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return parsed
|
||||
}
|
||||
if parsed, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
return parsed.Unix()
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OAuth client configured for OpenAI Codex OAuth.
|
||||
// This uses the public client ID for CLI applications with PKCE for security.
|
||||
func NewOpenAIOAuthClient() *OpenAIOAuthClient {
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCopilotStartDeviceFlow(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm failed: %v", err)
|
||||
}
|
||||
if r.Form.Get("client_id") != "client-id" {
|
||||
t.Fatalf("expected client id, got %q", r.Form.Get("client_id"))
|
||||
}
|
||||
if r.Form.Get("scope") != "read:user" {
|
||||
t.Fatalf("expected scope, got %q", r.Form.Get("scope"))
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"device_code": "device-code",
|
||||
"user_code": "USER-CODE",
|
||||
"verification_uri": "https://github.com/login/device",
|
||||
"expires_in": 600,
|
||||
"interval": 1,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.ClientID = "client-id"
|
||||
client.DeviceURL = server.URL
|
||||
|
||||
code, err := client.StartDeviceFlow(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("StartDeviceFlow failed: %v", err)
|
||||
}
|
||||
if code.DeviceCode != "device-code" || code.UserCode != "USER-CODE" || code.Interval != 1 {
|
||||
t.Fatalf("unexpected device code: %#v", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotPollDeviceToken(t *testing.T) {
|
||||
polls := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
polls++
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm failed: %v", err)
|
||||
}
|
||||
if r.Form.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
|
||||
t.Fatalf("unexpected grant type: %q", r.Form.Get("grant_type"))
|
||||
}
|
||||
if polls == 1 {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"error": "authorization_pending"})
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "github-token"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.ClientID = "client-id"
|
||||
client.TokenURL = server.URL
|
||||
client.PollTimeout = 5 * time.Second
|
||||
client.ClientTimeout = time.Second
|
||||
|
||||
token, err := client.PollDeviceToken(context.Background(), &CopilotDeviceCode{
|
||||
DeviceCode: "device-code",
|
||||
ExpiresIn: 10,
|
||||
Interval: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("PollDeviceToken failed: %v", err)
|
||||
}
|
||||
if token != "github-token" {
|
||||
t.Fatalf("expected github-token, got %q", token)
|
||||
}
|
||||
if polls != 2 {
|
||||
t.Fatalf("expected 2 polls, got %d", polls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotRefreshToken(t *testing.T) {
|
||||
expiresAt := time.Now().Add(time.Hour).Unix()
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
t.Fatalf("expected GET, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Authorization") != "token github-token" {
|
||||
t.Fatalf("unexpected authorization header: %q", r.Header.Get("Authorization"))
|
||||
}
|
||||
if r.Header.Get("User-Agent") != "kit" {
|
||||
t.Fatalf("unexpected user agent: %q", r.Header.Get("User-Agent"))
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"token": "copilot-token",
|
||||
"expires_at": expiresAt,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewCopilotOAuthClient()
|
||||
client.CopilotURL = server.URL
|
||||
|
||||
creds, err := client.RefreshCopilotToken(context.Background(), "github-token")
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshCopilotToken failed: %v", err)
|
||||
}
|
||||
if creds.GitHubToken != "github-token" || creds.CopilotAccessToken != "copilot-token" {
|
||||
t.Fatalf("unexpected credentials: %#v", creds)
|
||||
}
|
||||
if creds.ExpiresAt != expiresAt {
|
||||
t.Fatalf("expected expires_at %d, got %d", expiresAt, creds.ExpiresAt)
|
||||
}
|
||||
}
|
||||
@@ -493,6 +493,12 @@ mcpServers:
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# Skills configuration (all optional)
|
||||
# no-skills: false # Set to true to disable all skill loading
|
||||
# skill: # Explicit skill files/dirs (disables auto-discovery)
|
||||
# - "/path/to/skill.md"
|
||||
# skills-dir: "/path/to/skills" # Override project-local directory for auto-discovery
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
@@ -205,6 +205,9 @@ func TestEnsureConfigExists(t *testing.T) {
|
||||
"type: \"local\"",
|
||||
"type: \"remote\"",
|
||||
"Core tools",
|
||||
"# Skills configuration",
|
||||
"no-skills:",
|
||||
"skills-dir:",
|
||||
}
|
||||
|
||||
for _, expected := range expectedSections {
|
||||
|
||||
@@ -7,32 +7,48 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// LoadAndValidateConfig loads configuration from viper, fixes environment variable
|
||||
// casing issues, and validates the configuration. Returns an error if loading or
|
||||
// validation fails.
|
||||
// LoadAndValidateConfig loads configuration from the process-global viper
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. Returns an error if loading or validation fails.
|
||||
//
|
||||
// This is a convenience wrapper around [LoadAndValidateConfigFrom] using the
|
||||
// shared global store; it is retained for the CLI and other callers that rely
|
||||
// on viper's process-global state.
|
||||
func LoadAndValidateConfig() (*Config, error) {
|
||||
return LoadAndValidateConfigFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// LoadAndValidateConfigFrom loads configuration from the supplied per-instance
|
||||
// store, fixes environment variable casing issues, and validates the
|
||||
// configuration. When v is nil, the process-global store is used. Threading an
|
||||
// explicit store lets each Kit instance own an isolated configuration without
|
||||
// clobbering other instances in the same process.
|
||||
func LoadAndValidateConfigFrom(v *viper.Viper) (*Config, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
config := &Config{
|
||||
MCPServers: make(map[string]MCPServerConfig),
|
||||
}
|
||||
if err := viper.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %v", err)
|
||||
if err := v.Unmarshal(config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
// Fix environment variable case sensitivity issue
|
||||
// Viper lowercases all keys, but we need to preserve the original case for environment variables
|
||||
fixEnvironmentCase(config)
|
||||
fixEnvironmentCase(v, config)
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// fixEnvironmentCase fixes the case of environment variable keys that were lowercased by Viper
|
||||
func fixEnvironmentCase(config *Config) {
|
||||
func fixEnvironmentCase(v *viper.Viper, config *Config) {
|
||||
// Get the raw config data from viper
|
||||
rawConfig := viper.AllSettings()
|
||||
rawConfig := v.AllSettings()
|
||||
|
||||
// Check if we have mcpServers in the raw config
|
||||
if mcpServersRaw, ok := rawConfig["mcpservers"]; ok {
|
||||
|
||||
@@ -56,9 +56,3 @@ func (e *EnvSubstituter) SubstituteEnvVars(content string) (string, error) {
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// HasEnvVars checks if content contains environment variable patterns (${env://...}).
|
||||
// This is useful for determining if substitution is needed before processing.
|
||||
func HasEnvVars(content string) bool {
|
||||
return envVarPattern.MatchString(content)
|
||||
}
|
||||
|
||||
@@ -187,41 +187,3 @@ func TestEnvSubstituter_SubstituteEnvVars(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasEnvVars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has env vars",
|
||||
content: `{"token": "${env://GITHUB_TOKEN}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has env vars with default",
|
||||
content: `{"debug": "${env://DEBUG:-false}"}`,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no env vars",
|
||||
content: `{"name": "${username}", "normal": "value"}`,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasEnvVars(tt.content)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+57
-78
@@ -59,12 +59,6 @@ func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
@@ -160,15 +154,6 @@ func rewriteSudoForStdin(command string) string {
|
||||
return result
|
||||
}
|
||||
|
||||
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
|
||||
// This is stored in tool response metadata to signal the TUI to prompt for password.
|
||||
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
|
||||
|
||||
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
|
||||
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
|
||||
return resp.Metadata == SudoPasswordRequiredMetadata
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
@@ -258,34 +243,37 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
// setupBashPipes opens stdout/stderr pipes (plus an optional sudo stdin),
|
||||
// starts the command, and asynchronously writes the sudo password if any.
|
||||
// Returns the readers ready for the caller to consume. If setup fails,
|
||||
// errResp is non-nil and the readers must not be used; the caller should
|
||||
// return the response directly.
|
||||
func setupBashPipes(cmd *exec.Cmd, sudoPassword string) (stdout, stderr io.Reader, errResp *fantasy.ToolResponse) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stdout pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stderr pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe and write the password
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
r := fantasy.NewTextErrorResponse("failed to create stdin pipe")
|
||||
return nil, nil, &r
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
r := fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err))
|
||||
return nil, nil, &r
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
@@ -293,19 +281,49 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
}()
|
||||
}
|
||||
|
||||
return stdoutPipe, stderrPipe, nil
|
||||
}
|
||||
|
||||
// interpretBashExit decodes cmd.Wait()'s error into an exit code, mapping
|
||||
// context-deadline-exceeded to a friendly "command timed out" response.
|
||||
// errResp is non-nil only when the caller should short-circuit and return
|
||||
// it directly (e.g. timeout).
|
||||
func interpretBashExit(waitErr error, cmdCtx context.Context) (exitCode int, errResp *fantasy.ToolResponse) {
|
||||
if waitErr == nil {
|
||||
return 0, nil
|
||||
}
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
return exitErr.ExitCode(), nil
|
||||
}
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
r := fantasy.NewTextErrorResponse("command timed out")
|
||||
return 0, &r
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, _ fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
var stdoutErr, stderrErr error
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stdoutErr = io.Copy(&stdout, stdoutPipe)
|
||||
_, _ = io.Copy(&stdout, stdoutPipe)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, stderrErr = io.Copy(&stderr, stderrPipe)
|
||||
_, _ = io.Copy(&stderr, stderrPipe)
|
||||
}()
|
||||
|
||||
// Wait for the process to exit first. cmd.WaitDelay ensures that if
|
||||
@@ -316,18 +334,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
// Wait for pipe readers to finish draining.
|
||||
wg.Wait()
|
||||
|
||||
// Ignore pipe read errors caused by WaitDelay force-closing —
|
||||
// we still have whatever was read before the close.
|
||||
_ = stdoutErr
|
||||
_ = stderrErr
|
||||
|
||||
exitCode := 0
|
||||
if waitErr != nil {
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
return buildBashResponse(stdout.String(), stderr.String(), exitCode)
|
||||
@@ -335,35 +344,9 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
stdoutPipe, stderrPipe, errResp := setupBashPipes(cmd, sudoPassword)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
@@ -400,20 +383,16 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
|
||||
// Wait for the process to exit. cmd.WaitDelay ensures that if pipes
|
||||
// remain open (held by grandchild processes), they'll be forcibly closed
|
||||
// after the grace period, which unblocks the scanners above.
|
||||
err = cmd.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Wait for the pipe readers to finish draining. This will complete
|
||||
// quickly since cmd.Wait() (with WaitDelay) has already ensured
|
||||
// the pipes are closed.
|
||||
wg.Wait()
|
||||
|
||||
exitCode := 0
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
exitCode = exitErr.ExitCode()
|
||||
} else if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fantasy.NewTextErrorResponse("command timed out"), nil
|
||||
}
|
||||
exitCode, errResp := interpretBashExit(waitErr, cmdCtx)
|
||||
if errResp != nil {
|
||||
return *errResp, nil
|
||||
}
|
||||
|
||||
return buildBashResponse(strings.Join(stdoutChunks, "\n"), strings.Join(stderrChunks, "\n"), exitCode)
|
||||
|
||||
@@ -183,7 +183,7 @@ func TestRewriteSudoForStdin(t *testing.T) {
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
ctx := context.WithValue(context.Background(), sudoPasswordKey, "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
|
||||
@@ -83,6 +83,9 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
|
||||
@@ -42,6 +42,9 @@ func NewLsTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeLs(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args lsArgs
|
||||
_ = parseArgs(call.Input, &args) // optional args
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ func NewReadTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args readArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
|
||||
@@ -41,6 +41,9 @@ func NewWriteTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
|
||||
func executeWrite(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return fantasy.ToolResponse{}, err
|
||||
}
|
||||
var args writeArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path and content parameters are required"), nil
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// BaseContext returns an extensions.Context populated with the headless,
|
||||
// TUI-independent delegation fields: data access, state, options,
|
||||
// model/tool management, completions, subagents, tree navigation, skills,
|
||||
// template parsing, and model resolution.
|
||||
//
|
||||
// Callers overlay their UI-specific fields (print routes, widgets, prompts,
|
||||
// editor, TUI-aware SetModel/ReloadExtensions, etc.) on the returned value:
|
||||
// cmd/extension_context.go for the interactive TUI and
|
||||
// internal/acpserver/session.go for headless ACP mode. Keeping the shared
|
||||
// half here means a new data-access Context field only has to be wired once.
|
||||
//
|
||||
// ctx is used for subagent spawns; pass a long-lived context (not a
|
||||
// per-request one) so later spawns aren't cancelled prematurely.
|
||||
func BaseContext(ctx context.Context, kitInstance *kit.Kit) extensions.Context {
|
||||
return extensions.Context{
|
||||
// -------------------------------------------------------------------
|
||||
// Data access
|
||||
// -------------------------------------------------------------------
|
||||
GetContextStats: func() extensions.ContextStats {
|
||||
s := kitInstance.GetContextStats()
|
||||
return extensions.ContextStats{
|
||||
EstimatedTokens: s.EstimatedTokens,
|
||||
ContextLimit: s.ContextLimit,
|
||||
UsagePercent: s.UsagePercent,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage {
|
||||
return kitInstance.Extensions().GetSessionMessages()
|
||||
},
|
||||
GetSessionPath: func() string {
|
||||
return kitInstance.GetSessionPath()
|
||||
},
|
||||
AppendEntry: func(entryType string, data string) (string, error) {
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Extension state
|
||||
// -------------------------------------------------------------------
|
||||
SetState: func(key string, value string) {
|
||||
kitInstance.Extensions().SetState(key, value)
|
||||
},
|
||||
GetState: func(key string) (string, bool) {
|
||||
return kitInstance.Extensions().GetState(key)
|
||||
},
|
||||
DeleteState: func(key string) {
|
||||
kitInstance.Extensions().DeleteState(key)
|
||||
},
|
||||
ListState: func() []string {
|
||||
return kitInstance.Extensions().ListState()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Options, model, and tool management
|
||||
// -------------------------------------------------------------------
|
||||
GetOption: func(name string) string {
|
||||
return kitInstance.Extensions().GetOption(name)
|
||||
},
|
||||
SetOption: func(name string, value string) {
|
||||
kitInstance.Extensions().SetOption(name, value)
|
||||
},
|
||||
// Headless model switch. The interactive TUI overrides this with a
|
||||
// version that also notifies the TUI and refreshes the usage tracker.
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry {
|
||||
return kitInstance.GetAvailableModels()
|
||||
},
|
||||
EmitCustomEvent: func(name string, data string) {
|
||||
kitInstance.Extensions().EmitCustomEvent(name, data)
|
||||
},
|
||||
GetAllTools: func() []extensions.ToolInfo {
|
||||
return kitInstance.Extensions().GetToolInfos()
|
||||
},
|
||||
SetActiveTools: func(names []string) {
|
||||
kitInstance.Extensions().SetActiveTools(names)
|
||||
},
|
||||
// Headless reload. The interactive TUI overrides this to also
|
||||
// refresh widgets/status/commands.
|
||||
ReloadExtensions: func() error {
|
||||
return kitInstance.Extensions().Reload()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// LLM completions and subagents
|
||||
// -------------------------------------------------------------------
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
return kitInstance.ExecuteCompletion(context.Background(), req)
|
||||
},
|
||||
SpawnSubagent: func(config extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
return SpawnSubagent(ctx, kitInstance, config)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Tree Navigation API
|
||||
// -------------------------------------------------------------------
|
||||
GetTreeNode: func(entryID string) *extensions.TreeNode {
|
||||
node := kitInstance.GetTreeNode(entryID)
|
||||
if node == nil {
|
||||
return nil
|
||||
}
|
||||
return &extensions.TreeNode{
|
||||
ID: node.ID,
|
||||
ParentID: node.ParentID,
|
||||
Type: node.Type,
|
||||
Role: node.Role,
|
||||
Content: node.Content,
|
||||
Model: node.Model,
|
||||
Provider: node.Provider,
|
||||
Timestamp: node.Timestamp,
|
||||
Children: node.Children,
|
||||
}
|
||||
},
|
||||
GetCurrentBranch: func() []extensions.TreeNode {
|
||||
nodes := kitInstance.GetCurrentBranch()
|
||||
result := make([]extensions.TreeNode, len(nodes))
|
||||
for i, n := range nodes {
|
||||
result[i] = extensions.TreeNode{
|
||||
ID: n.ID,
|
||||
ParentID: n.ParentID,
|
||||
Type: n.Type,
|
||||
Role: n.Role,
|
||||
Content: n.Content,
|
||||
Model: n.Model,
|
||||
Provider: n.Provider,
|
||||
Timestamp: n.Timestamp,
|
||||
Children: n.Children,
|
||||
}
|
||||
}
|
||||
return result
|
||||
},
|
||||
GetChildren: func(parentID string) []string {
|
||||
return kitInstance.GetChildren(parentID)
|
||||
},
|
||||
NavigateTo: func(entryID string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.NavigateTo(entryID)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
SummarizeBranch: func(fromID, toID string) string {
|
||||
summary, _ := kitInstance.SummarizeBranch(fromID, toID)
|
||||
return summary
|
||||
},
|
||||
CollapseBranch: func(fromID, toID, summary string) extensions.TreeNavigationResult {
|
||||
err := kitInstance.CollapseBranch(fromID, toID, summary)
|
||||
if err != nil {
|
||||
return extensions.TreeNavigationResult{Success: false, Error: err.Error()}
|
||||
}
|
||||
return extensions.TreeNavigationResult{Success: true}
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Skill Loading API (context-injection variants are TUI-specific and
|
||||
// wired by the interactive overlay)
|
||||
// -------------------------------------------------------------------
|
||||
LoadSkill: func(path string) (*extensions.Skill, string) {
|
||||
s, err := kitInstance.LoadSkillForExtension(path)
|
||||
return s, err
|
||||
},
|
||||
LoadSkillsFromDir: func(dir string) extensions.SkillLoadResult {
|
||||
return kitInstance.LoadSkillsFromDirForExtension(dir)
|
||||
},
|
||||
DiscoverSkills: func() extensions.SkillLoadResult {
|
||||
skills := kitInstance.DiscoverSkillsForExtension()
|
||||
return extensions.SkillLoadResult{Skills: skills}
|
||||
},
|
||||
GetAvailableSkills: func() []extensions.Skill {
|
||||
return kitInstance.DiscoverSkillsForExtension()
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Template Parsing API
|
||||
// -------------------------------------------------------------------
|
||||
ParseTemplate: func(name, content string) extensions.PromptTemplate {
|
||||
return kit.ParseTemplate(name, content)
|
||||
},
|
||||
RenderTemplate: func(tpl extensions.PromptTemplate, vars map[string]string) string {
|
||||
return kit.RenderTemplate(tpl, vars)
|
||||
},
|
||||
ParseArguments: func(input string, pattern extensions.ArgumentPattern) extensions.ParseResult {
|
||||
return kit.ParseArguments(input, pattern)
|
||||
},
|
||||
SimpleParseArguments: func(input string, count int) []string {
|
||||
return kit.SimpleParseArguments(input, count)
|
||||
},
|
||||
EvaluateModelConditional: func(condition string) bool {
|
||||
return kit.EvaluateModelConditional(kitInstance.Extensions().GetContext().Model, condition)
|
||||
},
|
||||
RenderWithModelConditionals: func(content string) string {
|
||||
return kit.RenderWithModelConditionals(content, kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
|
||||
// -------------------------------------------------------------------
|
||||
// Model Resolution API
|
||||
// -------------------------------------------------------------------
|
||||
ResolveModelChain: func(preferences []string) extensions.ModelResolutionResult {
|
||||
return kit.ResolveModelChain(preferences)
|
||||
},
|
||||
GetModelCapabilities: func(model string) (extensions.ModelCapabilities, string) {
|
||||
return kit.GetModelCapabilities(model)
|
||||
},
|
||||
CheckModelAvailable: func(model string) bool {
|
||||
return kit.CheckModelAvailable(model)
|
||||
},
|
||||
GetCurrentProvider: func() string {
|
||||
return kit.GetCurrentProvider(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
GetCurrentModelID: func() string {
|
||||
return kit.GetCurrentModelID(kitInstance.Extensions().GetContext().Model)
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Package extbridge wires the public Kit SDK to the internal extensions
|
||||
// package. It exists so that cmd/ and internal/acpserver/ don't both
|
||||
// reimplement the same SDK→extension event/subagent conversions.
|
||||
package extbridge
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// SDKEventToSubagentEvent converts an SDK [kit.Event] into the
|
||||
// extension-facing [extensions.SubagentEvent]. Returns a zero-value event
|
||||
// (Type=="") for events that don't map to anything useful — callers should
|
||||
// drop those.
|
||||
func SDKEventToSubagentEvent(e kit.Event) extensions.SubagentEvent {
|
||||
switch ev := e.(type) {
|
||||
case kit.MessageUpdateEvent:
|
||||
return extensions.SubagentEvent{Type: "text", Content: ev.Chunk}
|
||||
case kit.ReasoningDeltaEvent:
|
||||
return extensions.SubagentEvent{Type: "reasoning", Content: ev.Delta}
|
||||
case kit.ToolCallEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_call", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind, ToolArgs: ev.ToolArgs,
|
||||
}
|
||||
case kit.ToolExecutionStartEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_start", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolExecutionEndEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_execution_end", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
}
|
||||
case kit.ToolResultEvent:
|
||||
return extensions.SubagentEvent{
|
||||
Type: "tool_result", ToolCallID: ev.ToolCallID,
|
||||
ToolName: ev.ToolName, ToolKind: ev.ToolKind,
|
||||
ToolResult: ev.Result, IsError: ev.IsError,
|
||||
}
|
||||
case kit.TurnStartEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_start"}
|
||||
case kit.TurnEndEvent:
|
||||
return extensions.SubagentEvent{Type: "turn_end"}
|
||||
default:
|
||||
return extensions.SubagentEvent{}
|
||||
}
|
||||
}
|
||||
|
||||
// SpawnSubagent runs a subagent in-process via the Kit SDK and translates
|
||||
// the result/events back into the extension-facing types. The returned
|
||||
// handle is always nil — the SDK path runs synchronously and does not
|
||||
// expose a separate process handle. Callers that need non-blocking
|
||||
// behaviour should run this in their own goroutine.
|
||||
//
|
||||
// This function consolidates the previously-duplicated wiring in
|
||||
// cmd/root.go (interactive + runtime contexts) and
|
||||
// internal/acpserver/session.go.
|
||||
func SpawnSubagent(ctx context.Context, k *kit.Kit, cfg extensions.SubagentConfig) (*extensions.SubagentHandle, *extensions.SubagentResult, error) {
|
||||
sdkCfg := kit.SubagentConfig{
|
||||
Prompt: cfg.Prompt,
|
||||
Model: cfg.Model,
|
||||
SystemPrompt: cfg.SystemPrompt,
|
||||
Timeout: cfg.Timeout,
|
||||
NoSession: cfg.NoSession,
|
||||
Tools: k.GetToolsForSubagent(),
|
||||
}
|
||||
if cfg.OnEvent != nil {
|
||||
sdkCfg.OnEvent = func(e kit.Event) {
|
||||
se := SDKEventToSubagentEvent(e)
|
||||
if se.Type != "" {
|
||||
cfg.OnEvent(se)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result, err := k.Subagent(ctx, sdkCfg)
|
||||
if result == nil {
|
||||
return nil, &extensions.SubagentResult{Error: err}, err
|
||||
}
|
||||
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
extResult.Usage = &extensions.SubagentUsage{
|
||||
InputTokens: result.Usage.InputTokens,
|
||||
OutputTokens: result.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
return nil, extResult, err
|
||||
}
|
||||
+135
-1
@@ -341,6 +341,13 @@ type Context struct {
|
||||
// The data survives across session restarts and can be retrieved via
|
||||
// GetEntries. Use entryType to namespace your data (e.g. "myext:state").
|
||||
//
|
||||
// AppendEntry is append-only and lives in the conversation tree, which
|
||||
// makes it the right tool for audit logs and event histories. For
|
||||
// last-write-wins snapshot state — "what's the current value of X?" —
|
||||
// prefer SetState / GetState instead. Those primitives store data in a
|
||||
// sidecar file outside the conversation tree, are O(1) to read/write,
|
||||
// and do not bloat branch reads or duplicate on fork.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := json.Marshal(myState)
|
||||
@@ -360,6 +367,45 @@ type Context struct {
|
||||
// }
|
||||
GetEntries func(entryType string) []ExtensionEntry
|
||||
|
||||
// SetState stores a key-value pair in session-scoped, last-write-wins
|
||||
// extension state. Unlike AppendEntry the value is kept in a sidecar
|
||||
// file outside the conversation tree, so:
|
||||
// - reads are O(1) (no branch walk)
|
||||
// - writes don't bloat the session JSONL
|
||||
// - state is not duplicated on fork (branches share the sidecar)
|
||||
// - state is invisible to the LLM
|
||||
//
|
||||
// Use SetState for snapshot state ("current value of X"); use
|
||||
// AppendEntry for audit logs and event histories. Namespace keys with
|
||||
// your extension name to avoid collisions (e.g. "myext:budget-cap").
|
||||
//
|
||||
// State persists for the lifetime of the session. For ephemeral or
|
||||
// in-memory sessions the state lives only in memory.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.SetState("myext:budget-cap", "10.00")
|
||||
SetState func(key string, value string)
|
||||
|
||||
// GetState returns the value previously stored via SetState. The bool
|
||||
// is false when the key was never written. Returns ("", false) when
|
||||
// state is unavailable.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if cap, ok := ctx.GetState("myext:budget-cap"); ok {
|
||||
// fmt.Println("current cap:", cap)
|
||||
// }
|
||||
GetState func(key string) (string, bool)
|
||||
|
||||
// DeleteState removes a key from session-scoped extension state.
|
||||
// No-op when the key is missing.
|
||||
DeleteState func(key string)
|
||||
|
||||
// ListState returns all keys currently stored in session-scoped
|
||||
// extension state, in unspecified order.
|
||||
ListState func() []string
|
||||
|
||||
// SetEditorText sets the text content of the input editor. This can
|
||||
// be used to pre-fill the editor with suggested text (e.g. extracted
|
||||
// questions, handoff prompts). The cursor is moved to the end.
|
||||
@@ -1102,6 +1148,7 @@ type API struct {
|
||||
onError func(func(ErrorEvent, Context))
|
||||
onRetry func(func(RetryEvent, Context))
|
||||
onPrepareStep func(func(PrepareStepEvent, Context) *PrepareStepResult)
|
||||
onLLMUsage func(func(LLMUsageEvent, Context))
|
||||
}
|
||||
|
||||
// OnToolCall registers a handler that fires before a tool executes.
|
||||
@@ -1359,6 +1406,19 @@ func (a *API) OnPrepareStep(handler func(PrepareStepEvent, Context) *PrepareStep
|
||||
a.onPrepareStep(handler)
|
||||
}
|
||||
|
||||
// OnLLMUsage registers a handler that fires after each LLM provider call
|
||||
// with the token and cost deltas for that single call. Use this for
|
||||
// per-call usage attribution, real-time budget enforcement, and cost
|
||||
// dashboards that need to react between calls within a single agent turn.
|
||||
//
|
||||
// Handlers receive an LLMUsageEvent describing the call's input/output
|
||||
// tokens, cache tokens, computed cost, model, and provider. A single agent
|
||||
// turn typically fires multiple LLMUsageEvents (one per tool-loop
|
||||
// iteration).
|
||||
func (a *API) OnLLMUsage(handler func(LLMUsageEvent, Context)) {
|
||||
a.onLLMUsage(handler)
|
||||
}
|
||||
|
||||
// RegisterToolRenderer registers a custom renderer for a specific tool's
|
||||
// display in the TUI. The renderer controls the header (parameter summary)
|
||||
// and/or body (result display) of the tool's output block. If multiple
|
||||
@@ -2091,10 +2151,47 @@ type AgentStartEvent struct {
|
||||
|
||||
func (e AgentStartEvent) Type() EventType { return AgentStart }
|
||||
|
||||
// AgentEndEvent fires when the agent finishes responding.
|
||||
// AgentEndEvent fires when the agent finishes responding. In addition to the
|
||||
// final response and stop reason, the event carries per-turn aggregates so
|
||||
// observer-style extensions don't have to maintain parallel bookkeeping in
|
||||
// OnToolResult / OnStepFinish handlers.
|
||||
type AgentEndEvent struct {
|
||||
Response string
|
||||
StopReason string // "completed", "cancelled", "error"
|
||||
|
||||
// ToolCallCount is the total number of tool invocations observed during
|
||||
// this turn (sum across all steps).
|
||||
ToolCallCount int
|
||||
|
||||
// ToolNames lists the tool names invoked during this turn, in call order.
|
||||
// Duplicates are preserved (e.g. two bash calls produce ["bash", "bash"]).
|
||||
ToolNames []string
|
||||
|
||||
// LLMCallCount is the number of LLM round-trips (tool-loop iterations)
|
||||
// performed during this turn. Always >= 1 for a successful turn.
|
||||
LLMCallCount int
|
||||
|
||||
// InputTokensDelta is the sum of input tokens consumed during this turn
|
||||
// across every LLM call (including cache-hit input tokens).
|
||||
InputTokensDelta int
|
||||
|
||||
// OutputTokensDelta is the sum of output tokens generated during this turn.
|
||||
OutputTokensDelta int
|
||||
|
||||
// CacheReadTokensDelta is the sum of cache-read tokens during this turn.
|
||||
CacheReadTokensDelta int
|
||||
|
||||
// CacheWriteTokensDelta is the sum of cache-write tokens during this turn.
|
||||
CacheWriteTokensDelta int
|
||||
|
||||
// CostDelta is the total cost in USD attributable to this turn. Computed
|
||||
// from per-step usage and current model pricing. Zero when pricing is
|
||||
// unknown or OAuth credentials are in use.
|
||||
CostDelta float64
|
||||
|
||||
// DurationMs is the elapsed wall-clock time from AgentStart to AgentEnd,
|
||||
// in milliseconds.
|
||||
DurationMs int64
|
||||
}
|
||||
|
||||
func (e AgentEndEvent) Type() EventType { return AgentEnd }
|
||||
@@ -2403,6 +2500,43 @@ type PrepareStepResult struct {
|
||||
|
||||
func (PrepareStepResult) isResult() {}
|
||||
|
||||
// LLMUsageEvent fires after each LLM provider call with the per-call token
|
||||
// and cost deltas. Use this for accurate budget tracking, cost dashboards,
|
||||
// and any logic that needs to react between LLM calls within a single agent
|
||||
// turn (rather than only at turn boundaries).
|
||||
//
|
||||
// A single agent turn typically produces multiple LLMUsageEvents (one per
|
||||
// tool-loop iteration). The Model and Provider fields reflect the model used
|
||||
// for that specific call, which may differ from earlier calls if the
|
||||
// extension switched models mid-turn via ctx.SetModel().
|
||||
type LLMUsageEvent struct {
|
||||
// InputTokens is the number of input tokens for this call.
|
||||
InputTokens int
|
||||
// OutputTokens is the number of output tokens generated by this call.
|
||||
OutputTokens int
|
||||
// CacheReadTokens is the number of cache-hit input tokens (provider-specific).
|
||||
CacheReadTokens int
|
||||
// CacheWriteTokens is the number of cache-write tokens.
|
||||
CacheWriteTokens int
|
||||
// Cost is the USD cost of this call computed from the model's per-token
|
||||
// pricing. Zero when pricing is unknown or OAuth credentials are in use.
|
||||
Cost float64
|
||||
// Model is the model identifier used for this call (e.g. "claude-sonnet-4-5-20250929").
|
||||
Model string
|
||||
// Provider is the provider identifier (e.g. "anthropic", "openai").
|
||||
Provider string
|
||||
// RequestID is an optional correlation id for the underlying provider
|
||||
// call. May be empty when the provider does not surface one.
|
||||
RequestID string
|
||||
// StepNumber is the zero-based step index within the current agent turn.
|
||||
StepNumber int
|
||||
// FinishReason mirrors the provider's finish reason for this call
|
||||
// (e.g. "stop", "tool_calls", "length"). May be empty.
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
func (e LLMUsageEvent) Type() EventType { return LLMUsage }
|
||||
|
||||
// ThemeColor is an adaptive color pair with light and dark hex values.
|
||||
// Either field may be empty to inherit from the default theme.
|
||||
type ThemeColor struct {
|
||||
|
||||
@@ -125,6 +125,11 @@ const (
|
||||
// after steering messages are injected and before messages are sent
|
||||
// to the LLM. Handlers can replace the context window for this step.
|
||||
PrepareStep EventType = "prepare_step"
|
||||
|
||||
// LLMUsage fires after each LLM provider call with the token and cost
|
||||
// deltas for that single call. Extensions use it to attribute usage to
|
||||
// specific calls/models and to drive budget enforcement between calls.
|
||||
LLMUsage EventType = "llm_usage"
|
||||
)
|
||||
|
||||
// AllEventTypes returns every supported event type.
|
||||
@@ -139,7 +144,7 @@ func AllEventTypes() []EventType {
|
||||
BeforeFork, BeforeSessionSwitch, BeforeCompact,
|
||||
SubagentStart, SubagentChunk, SubagentEnd,
|
||||
StepStart, StepFinish, ReasoningStart, Warnings, Source, Error, Retry,
|
||||
PrepareStep,
|
||||
PrepareStep, LLMUsage,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import "testing"
|
||||
|
||||
func TestAllEventTypes_Count(t *testing.T) {
|
||||
all := AllEventTypes()
|
||||
if len(all) != 32 {
|
||||
t.Fatalf("expected 32 event types, got %d", len(all))
|
||||
if len(all) != 33 {
|
||||
t.Fatalf("expected 33 event types, got %d", len(all))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -450,25 +450,6 @@ func globalGitInstallRoot() string {
|
||||
return filepath.Join(base, "kit", "git")
|
||||
}
|
||||
|
||||
// GetInstalledPackages returns all installed packages from both scopes.
|
||||
func (i *Installer) GetInstalledPackages() ([]ManifestEntry, error) {
|
||||
var all []ManifestEntry
|
||||
|
||||
global, err := i.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading global manifest: %w", err)
|
||||
}
|
||||
all = append(all, global.Packages...)
|
||||
|
||||
project, err := i.loadManifest(ScopeProject)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading project manifest: %w", err)
|
||||
}
|
||||
all = append(all, project.Packages...)
|
||||
|
||||
return all, nil
|
||||
}
|
||||
|
||||
// IsInstalled checks if a package is installed in either scope.
|
||||
// Returns (scope, true) if installed, ("", false) otherwise.
|
||||
func (i *Installer) IsInstalled(source *GitSource) (InstallScope, bool) {
|
||||
|
||||
@@ -245,14 +245,21 @@ func TestManifestEntryIdentity(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadAndSaveManifest exercises the live *Installer.loadManifest /
|
||||
// saveManifest round-trip against a temp directory, ensuring an absent
|
||||
// manifest loads as empty and a saved manifest reads back identically.
|
||||
func TestLoadAndSaveManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
manifestPath := filepath.Join(tempDir, "packages.json")
|
||||
|
||||
// Test loading non-existent manifest
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected empty packages, got %d", len(manifest.Packages))
|
||||
@@ -273,15 +280,20 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Save it
|
||||
err = saveManifestToPath(manifest, manifestPath)
|
||||
err = installer.saveManifest(manifest, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("saveManifestToPath() error = %v", err)
|
||||
t.Fatalf("saveManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was written to expected path
|
||||
if _, err := os.Stat(manifestPath); err != nil {
|
||||
t.Fatalf("manifest file not created: %v", err)
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded, err := loadManifestFromPath(manifestPath)
|
||||
loaded, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(loaded.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(loaded.Packages))
|
||||
@@ -291,21 +303,15 @@ func TestLoadAndSaveManifest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddAndRemoveFromManifest verifies that *Installer.addToManifest
|
||||
// followed by removeFromManifest leaves the manifest in its original
|
||||
// (empty) state, using a temp-directory installer scope.
|
||||
func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Set up environment for manifest path
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
installer := &Installer{
|
||||
projectGitRoot: tempDir,
|
||||
globalGitRoot: tempDir,
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// The manifest path when XDG_DATA_HOME is set
|
||||
manifestPath := filepath.Join(tempDir, "kit", "git", "packages.json")
|
||||
|
||||
// Add an entry
|
||||
entry := ManifestEntry{
|
||||
@@ -315,58 +321,51 @@ func TestAddAndRemoveFromManifest(t *testing.T) {
|
||||
Scope: ScopeGlobal,
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
if err := installer.addToManifest(entry, ScopeGlobal); err != nil {
|
||||
t.Fatalf("addToManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was added
|
||||
manifest, err := loadManifestFromPath(manifestPath)
|
||||
manifest, err := installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 1 {
|
||||
t.Errorf("Expected 1 package, got %d", len(manifest.Packages))
|
||||
}
|
||||
|
||||
// Remove it
|
||||
err = removeEntryFromManifest("github.com/user/repo", ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("removeEntryFromManifest() error = %v", err)
|
||||
if err := installer.removeFromManifest("github.com/user/repo", ScopeGlobal); err != nil {
|
||||
t.Fatalf("removeFromManifest() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify it was removed
|
||||
manifest, err = loadManifestFromPath(manifestPath)
|
||||
manifest, err = installer.loadManifest(ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("loadManifestFromPath() error = %v", err)
|
||||
t.Fatalf("loadManifest() error = %v", err)
|
||||
}
|
||||
if len(manifest.Packages) != 0 {
|
||||
t.Errorf("Expected 0 packages, got %d", len(manifest.Packages))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFindInManifest writes a manifest file directly to the path
|
||||
// resolved by the package-level manifestPathForScope helper and then
|
||||
// confirms FindInManifest locates the entry by identity (and returns
|
||||
// nil for a non-existent identity).
|
||||
func TestFindInManifest(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
if err := os.Setenv("XDG_DATA_HOME", tempDir); err != nil {
|
||||
t.Fatalf("Setenv() error = %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Unsetenv("XDG_DATA_HOME"); err != nil {
|
||||
t.Logf("Unsetenv() error = %v", err)
|
||||
}
|
||||
}()
|
||||
t.Setenv("XDG_DATA_HOME", tempDir)
|
||||
|
||||
// Add an entry to global manifest
|
||||
entry := ManifestEntry{
|
||||
Source: "git:github.com/user/repo",
|
||||
Host: "github.com",
|
||||
Path: "user/repo",
|
||||
Scope: ScopeGlobal,
|
||||
// Write a manifest entry directly via the package-level path resolver
|
||||
// so FindInManifest (which uses manifestPathForScope) can read it back.
|
||||
manifestPath := manifestPathForScope(ScopeGlobal)
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
err := addEntryToManifest(entry, ScopeGlobal)
|
||||
if err != nil {
|
||||
t.Fatalf("addEntryToManifest() error = %v", err)
|
||||
data := []byte(`{"packages":[{"source":"git:github.com/user/repo","repo":"","host":"github.com","path":"user/repo","pinned":false,"scope":"global","installed":"0001-01-01T00:00:00Z"}]}`)
|
||||
if err := os.WriteFile(manifestPath, data, 0644); err != nil {
|
||||
t.Fatalf("WriteFile() error = %v", err)
|
||||
}
|
||||
|
||||
// Find it
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
package extensions
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRunner_EmitLLMUsage(t *testing.T) {
|
||||
var got LLMUsageEvent
|
||||
var called bool
|
||||
ext := makeHandlerExt("llmusage.go", map[EventType][]HandlerFunc{
|
||||
LLMUsage: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(LLMUsageEvent)
|
||||
called = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(LLMUsageEvent{
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
Cost: 0.0012,
|
||||
Model: "claude-sonnet-4-5-20250929",
|
||||
Provider: "anthropic",
|
||||
StepNumber: 2,
|
||||
FinishReason: "tool_calls",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Fatal("expected LLMUsage handler to be called")
|
||||
}
|
||||
if got.InputTokens != 100 || got.OutputTokens != 50 {
|
||||
t.Errorf("token fields not propagated: %+v", got)
|
||||
}
|
||||
if got.Cost != 0.0012 {
|
||||
t.Errorf("cost not propagated, got %v", got.Cost)
|
||||
}
|
||||
if got.Model != "claude-sonnet-4-5-20250929" || got.Provider != "anthropic" {
|
||||
t.Errorf("model/provider not propagated: %+v", got)
|
||||
}
|
||||
if got.StepNumber != 2 || got.FinishReason != "tool_calls" {
|
||||
t.Errorf("step/finish reason not propagated: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_LLMUsageRegisteredViaTestAPI(t *testing.T) {
|
||||
// Verify NewTestAPI wires up onLLMUsage so the extension can call
|
||||
// api.OnLLMUsage during Init.
|
||||
ext := &LoadedExtension{Handlers: make(map[EventType][]HandlerFunc)}
|
||||
api := NewTestAPI(ext)
|
||||
|
||||
var calls int
|
||||
api.OnLLMUsage(func(e LLMUsageEvent, c Context) {
|
||||
calls++
|
||||
})
|
||||
|
||||
if len(ext.Handlers[LLMUsage]) != 1 {
|
||||
t.Fatalf("expected 1 LLMUsage handler registered, got %d", len(ext.Handlers[LLMUsage]))
|
||||
}
|
||||
|
||||
r := makeRunner(*ext)
|
||||
_, _ = r.Emit(LLMUsageEvent{InputTokens: 1})
|
||||
if calls != 1 {
|
||||
t.Errorf("expected handler called once, got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentEndEvent_EnrichedFields(t *testing.T) {
|
||||
// Verify the enriched event carries through Emit without mangling.
|
||||
var got AgentEndEvent
|
||||
ext := makeHandlerExt("end.go", map[EventType][]HandlerFunc{
|
||||
AgentEnd: {
|
||||
func(e Event, c Context) Result {
|
||||
got = e.(AgentEndEvent)
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
_, err := r.Emit(AgentEndEvent{
|
||||
Response: "done",
|
||||
StopReason: "completed",
|
||||
ToolCallCount: 3,
|
||||
ToolNames: []string{"bash", "read", "bash"},
|
||||
LLMCallCount: 4,
|
||||
InputTokensDelta: 1500,
|
||||
OutputTokensDelta: 400,
|
||||
CacheReadTokensDelta: 200,
|
||||
CacheWriteTokensDelta: 100,
|
||||
CostDelta: 0.0123,
|
||||
DurationMs: 2500,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
if got.ToolCallCount != 3 {
|
||||
t.Errorf("ToolCallCount: got %d want 3", got.ToolCallCount)
|
||||
}
|
||||
if len(got.ToolNames) != 3 || got.ToolNames[0] != "bash" || got.ToolNames[2] != "bash" {
|
||||
t.Errorf("ToolNames: %v", got.ToolNames)
|
||||
}
|
||||
if got.LLMCallCount != 4 {
|
||||
t.Errorf("LLMCallCount: got %d want 4", got.LLMCallCount)
|
||||
}
|
||||
if got.InputTokensDelta != 1500 || got.OutputTokensDelta != 400 {
|
||||
t.Errorf("token deltas: %+v", got)
|
||||
}
|
||||
if got.CacheReadTokensDelta != 200 || got.CacheWriteTokensDelta != 100 {
|
||||
t.Errorf("cache deltas: %+v", got)
|
||||
}
|
||||
if got.CostDelta != 0.0123 {
|
||||
t.Errorf("CostDelta: got %v", got.CostDelta)
|
||||
}
|
||||
if got.DurationMs != 2500 {
|
||||
t.Errorf("DurationMs: got %d", got.DurationMs)
|
||||
}
|
||||
}
|
||||
@@ -669,6 +669,12 @@ func loadSingleExtension(path string) (*LoadedExtension, error) {
|
||||
return *r
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// Call Init — the extension registers its handlers, tools, commands.
|
||||
|
||||
@@ -72,30 +72,6 @@ func loadManifestFromPath(path string) (*Manifest, error) {
|
||||
return &manifest, nil
|
||||
}
|
||||
|
||||
// saveManifestToScope saves the manifest to the given scope.
|
||||
func saveManifestToScope(manifest *Manifest, scope InstallScope) error {
|
||||
path := manifestPathForScope(scope)
|
||||
return saveManifestToPath(manifest, path)
|
||||
}
|
||||
|
||||
// saveManifestToPath saves a manifest to a specific file path.
|
||||
func saveManifestToPath(manifest *Manifest, path string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return fmt.Errorf("creating manifest directory: %w", err)
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(manifest, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("encoding manifest: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("writing manifest: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// manifestPathForScope returns the manifest file path for a scope.
|
||||
func manifestPathForScope(scope InstallScope) string {
|
||||
if scope == ScopeProject {
|
||||
@@ -113,55 +89,6 @@ func manifestPathForScope(scope InstallScope) string {
|
||||
return filepath.Join(base, "kit", "git", "packages.json")
|
||||
}
|
||||
|
||||
// GetGlobalManifest returns the global manifest.
|
||||
func GetGlobalManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeGlobal)
|
||||
}
|
||||
|
||||
// GetProjectManifest returns the project manifest.
|
||||
func GetProjectManifest() (*Manifest, error) {
|
||||
return loadManifestFromScope(ScopeProject)
|
||||
}
|
||||
|
||||
// addEntryToManifest adds or replaces an entry in the manifest for a scope.
|
||||
func addEntryToManifest(entry ManifestEntry, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove any existing entry with same identity
|
||||
identity := entry.Identity()
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, entry)
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// removeEntryFromManifest removes an entry by identity from the manifest for a scope.
|
||||
func removeEntryFromManifest(identity string, scope InstallScope) error {
|
||||
manifest, err := loadManifestFromScope(scope)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filtered := make([]ManifestEntry, 0, len(manifest.Packages))
|
||||
for _, p := range manifest.Packages {
|
||||
if p.Identity() != identity {
|
||||
filtered = append(filtered, p)
|
||||
}
|
||||
}
|
||||
manifest.Packages = filtered
|
||||
|
||||
return saveManifestToScope(manifest, scope)
|
||||
}
|
||||
|
||||
// FindInManifest finds an entry by identity in either global or project manifest.
|
||||
// Returns the entry and its scope, or nil if not found.
|
||||
func FindInManifest(identity string) (*ManifestEntry, InstallScope, error) {
|
||||
|
||||
@@ -2,9 +2,12 @@ package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -98,9 +101,24 @@ type Runner struct {
|
||||
disabledTools map[string]bool // nil = all tools enabled
|
||||
customEventSubs map[string][]func(string) // inter-extension event bus
|
||||
optionOverrides map[string]string // runtime option overrides
|
||||
configStore *viper.Viper // per-instance config store (nil = global)
|
||||
state map[string]string // session-scoped extension state (last-write-wins)
|
||||
stateMu sync.RWMutex // guards state independently of mu
|
||||
saverMu sync.Mutex // serializes stateSaver invocations so atomic-rename writes don't interleave
|
||||
stateSaver func() // optional persistence hook invoked after each state mutation
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetConfigStore sets the per-instance configuration store used by GetOption
|
||||
// to resolve "options.<name>" config values. When unset (nil), GetOption falls
|
||||
// back to the process-global viper store. Threading a per-Kit store keeps
|
||||
// extension option resolution isolated between Kit instances.
|
||||
func (r *Runner) SetConfigStore(v *viper.Viper) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.configStore = v
|
||||
}
|
||||
|
||||
// ShortcutEntry pairs a shortcut definition with its handler.
|
||||
type ShortcutEntry struct {
|
||||
Def ShortcutDef
|
||||
@@ -253,6 +271,18 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.SetState == nil {
|
||||
ctx.SetState = func(string, string) {}
|
||||
}
|
||||
if ctx.GetState == nil {
|
||||
ctx.GetState = func(string) (string, bool) { return "", false }
|
||||
}
|
||||
if ctx.DeleteState == nil {
|
||||
ctx.DeleteState = func(string) {}
|
||||
}
|
||||
if ctx.ListState == nil {
|
||||
ctx.ListState = func() []string { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
@@ -734,6 +764,168 @@ func (r *Runner) GetMessageRenderer(name string) *MessageRendererConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Extension state store (session-scoped, last-write-wins)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SetState records a key-value pair in the runner's session-scoped extension
|
||||
// state store. The store is in-memory; callers wire SetStateSaver to persist
|
||||
// changes to a sidecar file. Thread-safe.
|
||||
//
|
||||
// When a saver is installed, concurrent SetState/DeleteState invocations are
|
||||
// serialized through saverMu so that overlapping snapshot-and-rename writes
|
||||
// cannot interleave (which would otherwise race on the shared tmp file and
|
||||
// risk persisting an older snapshot after a newer one).
|
||||
func (r *Runner) SetState(key, value string) {
|
||||
r.stateMu.Lock()
|
||||
if r.state == nil {
|
||||
r.state = make(map[string]string)
|
||||
}
|
||||
r.state[key] = value
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// GetState returns the value previously stored via SetState, plus a bool
|
||||
// indicating whether the key was present. Thread-safe.
|
||||
func (r *Runner) GetState(key string) (string, bool) {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
v, ok := r.state[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// DeleteState removes a key from the state store. No-op if the key is
|
||||
// missing. Thread-safe. Saver invocations are serialized via saverMu — see
|
||||
// SetState for the rationale.
|
||||
func (r *Runner) DeleteState(key string) {
|
||||
r.stateMu.Lock()
|
||||
_, existed := r.state[key]
|
||||
if existed {
|
||||
delete(r.state, key)
|
||||
}
|
||||
saver := r.stateSaver
|
||||
r.stateMu.Unlock()
|
||||
if !existed {
|
||||
return
|
||||
}
|
||||
r.runSaver(saver)
|
||||
}
|
||||
|
||||
// runSaver invokes the optional persistence callback under saverMu so
|
||||
// concurrent SetState/DeleteState writers cannot race on the shared tmp
|
||||
// file used by SaveStateToFile's atomic rename. The deferred Unlock
|
||||
// guarantees saverMu is released even if the saver panics.
|
||||
func (r *Runner) runSaver(saver func()) {
|
||||
if saver == nil {
|
||||
return
|
||||
}
|
||||
r.saverMu.Lock()
|
||||
defer r.saverMu.Unlock()
|
||||
saver()
|
||||
}
|
||||
|
||||
// ListState returns all keys currently in the state store, in unspecified
|
||||
// order. Thread-safe.
|
||||
func (r *Runner) ListState() []string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
keys := make([]string, 0, len(r.state))
|
||||
for k := range r.state {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// SetStateSaver installs an optional persistence hook invoked after each
|
||||
// mutation to the state store (SetState / DeleteState / LoadStateFromFile).
|
||||
// Pass nil to disable persistence. Thread-safe.
|
||||
func (r *Runner) SetStateSaver(saver func()) {
|
||||
r.stateMu.Lock()
|
||||
defer r.stateMu.Unlock()
|
||||
r.stateSaver = saver
|
||||
}
|
||||
|
||||
// SnapshotState returns a copy of the current state store as a
|
||||
// fresh map. Useful for persisting to disk without holding the lock.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SnapshotState() map[string]string {
|
||||
r.stateMu.RLock()
|
||||
defer r.stateMu.RUnlock()
|
||||
if len(r.state) == 0 {
|
||||
return nil
|
||||
}
|
||||
copyMap := make(map[string]string, len(r.state))
|
||||
maps.Copy(copyMap, r.state)
|
||||
return copyMap
|
||||
}
|
||||
|
||||
// LoadStateFromFile reads a JSON map from path and replaces the in-memory
|
||||
// state store with its contents. Missing or empty files are treated as
|
||||
// "no prior state": the in-memory store is replaced with an empty map so
|
||||
// callers can safely switch sessions without leaking keys from a prior
|
||||
// session into a new one. Malformed JSON returns the parse error without
|
||||
// touching the existing store. Thread-safe.
|
||||
func (r *Runner) LoadStateFromFile(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading extension state: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
r.stateMu.Lock()
|
||||
r.state = map[string]string{}
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
var loaded map[string]string
|
||||
if err := json.Unmarshal(data, &loaded); err != nil {
|
||||
return fmt.Errorf("parsing extension state: %w", err)
|
||||
}
|
||||
r.stateMu.Lock()
|
||||
r.state = loaded
|
||||
r.stateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveStateToFile writes the current state store to path as JSON, creating
|
||||
// parent directories as needed. An empty store writes an empty object so
|
||||
// that consumers can distinguish "loaded but empty" from "never saved".
|
||||
// Writes are atomic via a tmp-file-and-rename sequence. Thread-safe.
|
||||
func (r *Runner) SaveStateToFile(path string) error {
|
||||
snap := r.SnapshotState()
|
||||
if snap == nil {
|
||||
snap = map[string]string{}
|
||||
}
|
||||
data, err := json.MarshalIndent(snap, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling extension state: %w", err)
|
||||
}
|
||||
if dir := filepath.Dir(path); dir != "." && dir != "" {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return fmt.Errorf("creating state directory: %w", err)
|
||||
}
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
||||
return fmt.Errorf("writing extension state: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmp, path); err != nil {
|
||||
_ = os.Remove(tmp)
|
||||
return fmt.Errorf("renaming extension state: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hot-reload
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -757,7 +949,9 @@ func (r *Runner) Reload(exts []LoadedExtension) {
|
||||
r.uiVisibility = nil
|
||||
r.disabledTools = nil
|
||||
r.customEventSubs = nil
|
||||
// optionOverrides are intentionally preserved.
|
||||
// optionOverrides and state are intentionally preserved across reloads:
|
||||
// they represent user/session intent (not extension code) and would be
|
||||
// surprising to lose on a hot-reload.
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -872,7 +1066,13 @@ func (r *Runner) GetOption(name string) string {
|
||||
|
||||
// 3. Viper config: options.<name>
|
||||
configKey := "options." + name
|
||||
if v := viper.GetString(configKey); v != "" {
|
||||
r.mu.RLock()
|
||||
store := r.configStore
|
||||
r.mu.RUnlock()
|
||||
if store == nil {
|
||||
store = viper.GetViper()
|
||||
}
|
||||
if v := store.GetString(configKey); v != "" {
|
||||
return v
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,262 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunner_State_BasicSetGetDelete(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
|
||||
if _, ok := r.GetState("missing"); ok {
|
||||
t.Fatal("expected GetState to return ok=false for missing key")
|
||||
}
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("b", "2")
|
||||
r.SetState("a", "3") // last-write-wins
|
||||
|
||||
if v, ok := r.GetState("a"); !ok || v != "3" {
|
||||
t.Errorf("expected GetState(a)=(3,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r.GetState("b"); !ok || v != "2" {
|
||||
t.Errorf("expected GetState(b)=(2,true), got (%q,%v)", v, ok)
|
||||
}
|
||||
|
||||
keys := r.ListState()
|
||||
if len(keys) != 2 {
|
||||
t.Errorf("expected 2 keys, got %d (%v)", len(keys), keys)
|
||||
}
|
||||
|
||||
r.DeleteState("a")
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected key a to be gone after DeleteState")
|
||||
}
|
||||
if len(r.ListState()) != 1 {
|
||||
t.Errorf("expected 1 key after delete, got %v", r.ListState())
|
||||
}
|
||||
|
||||
// Deleting missing key is a no-op.
|
||||
r.DeleteState("never-there")
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverFires(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
var mu sync.Mutex
|
||||
r.SetStateSaver(func() {
|
||||
mu.Lock()
|
||||
calls++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
r.SetState("a", "1")
|
||||
r.SetState("a", "2")
|
||||
r.DeleteState("a")
|
||||
r.DeleteState("a") // missing → no save
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if calls != 3 {
|
||||
t.Errorf("expected saver to fire 3 times (2 sets + 1 delete), got %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaveAndLoadRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r1 := NewRunner(nil)
|
||||
r1.SetState("k1", "v1")
|
||||
r1.SetState("k2", `{"json":"value"}`)
|
||||
if err := r1.SaveStateToFile(path); err != nil {
|
||||
t.Fatalf("SaveStateToFile: %v", err)
|
||||
}
|
||||
|
||||
// Verify file contains JSON map.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["k1"] != "v1" || parsed["k2"] != `{"json":"value"}` {
|
||||
t.Errorf("unexpected file contents: %v", parsed)
|
||||
}
|
||||
|
||||
r2 := NewRunner(nil)
|
||||
if err := r2.LoadStateFromFile(path); err != nil {
|
||||
t.Fatalf("LoadStateFromFile: %v", err)
|
||||
}
|
||||
if v, ok := r2.GetState("k1"); !ok || v != "v1" {
|
||||
t.Errorf("expected k1=v1 after load, got (%q,%v)", v, ok)
|
||||
}
|
||||
if v, ok := r2.GetState("k2"); !ok || v != `{"json":"value"}` {
|
||||
t.Errorf("expected k2 to round-trip, got %q", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMissingFileClearsState(t *testing.T) {
|
||||
// LoadStateFromFile is documented to "replace the in-memory state store
|
||||
// with its contents"; for a missing file that means clearing the store.
|
||||
// This is what makes session-switching safe: a new session that has not
|
||||
// yet written a sidecar must not inherit keys from a prior session.
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(filepath.Join(t.TempDir(), "does-not-exist.json")); err != nil {
|
||||
t.Errorf("expected nil error for missing file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is missing")
|
||||
}
|
||||
if keys := r.ListState(); keys != nil {
|
||||
t.Errorf("expected ListState() to be nil after clearing, got %v", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadEmptyFileClearsState(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.json")
|
||||
if err := os.WriteFile(path, nil, 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
r.SetState("a", "1")
|
||||
if err := r.LoadStateFromFile(path); err != nil {
|
||||
t.Errorf("expected nil error for empty file, got %v", err)
|
||||
}
|
||||
if _, ok := r.GetState("a"); ok {
|
||||
t.Error("expected pre-existing state to be cleared when target file is empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_LoadMalformedFileError(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad.json")
|
||||
if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := NewRunner(nil)
|
||||
if err := r.LoadStateFromFile(path); err == nil {
|
||||
t.Error("expected error loading malformed JSON, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_PersistenceViaSaver(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "ext-state.json")
|
||||
|
||||
r := NewRunner(nil)
|
||||
r.SetStateSaver(func() {
|
||||
_ = r.SaveStateToFile(path)
|
||||
})
|
||||
r.SetState("hello", "world")
|
||||
|
||||
// File should exist with the value already.
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("reading saved file: %v", err)
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("unmarshalling: %v", err)
|
||||
}
|
||||
if parsed["hello"] != "world" {
|
||||
t.Errorf("expected file to contain hello=world, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ConcurrentSet(t *testing.T) {
|
||||
r := NewRunner(nil)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 16
|
||||
const iterations = 100
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
r.SetState("k", "v")
|
||||
_, _ = r.GetState("k")
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if v, ok := r.GetState("k"); !ok || v != "v" {
|
||||
t.Errorf("expected k=v after concurrent writes, got (%q,%v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_ContextNoOpsWhenUnset(t *testing.T) {
|
||||
// Verify normalizeContext installs safe no-ops for SetState/GetState/etc.
|
||||
// when not provided by the caller.
|
||||
ext := makeHandlerExt("state.go", map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// All four state functions should be non-nil and safe to call.
|
||||
c.SetState("a", "b")
|
||||
if v, ok := c.GetState("a"); ok || v != "" {
|
||||
t.Errorf("no-op GetState should return (\"\", false); got (%q,%v)", v, ok)
|
||||
}
|
||||
c.DeleteState("a")
|
||||
if keys := c.ListState(); keys != nil {
|
||||
t.Errorf("no-op ListState should return nil; got %v", keys)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
r := makeRunner(ext)
|
||||
// SetContext with empty Context to exercise normalizeContext defaults.
|
||||
r.SetContext(Context{})
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("emit: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_State_SaverPanicReleasesSaverMu(t *testing.T) {
|
||||
// If the saver callback panics (e.g. disk full mid-write), runSaver
|
||||
// must still release saverMu so subsequent SetState/DeleteState calls
|
||||
// can make progress. Without `defer Unlock()` the lock would be
|
||||
// permanently held and the next write would deadlock.
|
||||
r := NewRunner(nil)
|
||||
var calls int
|
||||
r.SetStateSaver(func() {
|
||||
calls++
|
||||
if calls == 1 {
|
||||
panic("simulated disk-write failure")
|
||||
}
|
||||
})
|
||||
|
||||
// First call panics. Recover, then verify a follow-up call still works
|
||||
// without blocking (proving saverMu was released).
|
||||
func() {
|
||||
defer func() {
|
||||
if rec := recover(); rec == nil {
|
||||
t.Fatal("expected panic from first saver invocation")
|
||||
}
|
||||
}()
|
||||
r.SetState("a", "1")
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.SetState("b", "2") // would deadlock if saverMu were still held
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("SetState after saver panic blocked — saverMu was not released")
|
||||
}
|
||||
if calls != 2 {
|
||||
t.Errorf("expected saver to fire twice (panic + recovery write), got %d", calls)
|
||||
}
|
||||
}
|
||||
@@ -2,22 +2,15 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Subagent types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentConfig configures a subagent spawn.
|
||||
type SubagentConfig struct {
|
||||
// Prompt is the task/instruction for the subagent (required).
|
||||
@@ -157,221 +150,3 @@ func (h *SubagentHandle) Wait() SubagentResult {
|
||||
func (h *SubagentHandle) Done() <-chan struct{} {
|
||||
return h.done
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// subagentJSONOutput matches the JSON envelope produced by `kit --json`.
|
||||
type subagentJSONOutput struct {
|
||||
Response string `json:"response"`
|
||||
StopReason string `json:"stop_reason,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Usage *struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func findKitBinary() string {
|
||||
// Try the current process executable first.
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if _, err := os.Stat(exe); err == nil {
|
||||
return exe
|
||||
}
|
||||
}
|
||||
// Fall back to PATH lookup.
|
||||
if p, err := exec.LookPath("kit"); err == nil {
|
||||
return p
|
||||
}
|
||||
return "kit"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SpawnSubagent implementation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SpawnSubagent spawns a child Kit instance to perform a task.
|
||||
//
|
||||
// When config.Blocking is true, blocks until completion and returns the result
|
||||
// directly (handle is nil). When false, returns immediately with a handle for
|
||||
// monitoring/cancellation.
|
||||
//
|
||||
// The subagent runs with --json --no-session --no-extensions flags by default,
|
||||
// ensuring isolation from the parent's extensions and session state.
|
||||
func SpawnSubagent(cfg SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
if cfg.Prompt == "" {
|
||||
return nil, nil, fmt.Errorf("prompt is required")
|
||||
}
|
||||
|
||||
timeout := cfg.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 5 * time.Minute
|
||||
}
|
||||
|
||||
kitBinary := findKitBinary()
|
||||
|
||||
// Build subprocess arguments.
|
||||
args := []string{
|
||||
"--json",
|
||||
"--no-extensions",
|
||||
}
|
||||
if cfg.NoSession {
|
||||
args = append(args, "--no-session")
|
||||
}
|
||||
if cfg.Model != "" {
|
||||
args = append(args, "--model", cfg.Model)
|
||||
}
|
||||
|
||||
// Handle system prompt - write to temp file if provided.
|
||||
var tmpFile *os.File
|
||||
if cfg.SystemPrompt != "" {
|
||||
var err error
|
||||
tmpFile, err = os.CreateTemp("", "kit-subagent-*.txt")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
if _, err := tmpFile.WriteString(cfg.SystemPrompt); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return nil, nil, fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
args = append(args, "--system-prompt", tmpFile.Name())
|
||||
}
|
||||
|
||||
// Add the prompt as a positional argument.
|
||||
args = append(args, cfg.Prompt)
|
||||
|
||||
// Create command with timeout context.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
|
||||
cmd := exec.CommandContext(ctx, kitBinary, args...)
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
handle := &SubagentHandle{
|
||||
ID: generateSubagentID(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
// Start the subprocess.
|
||||
start := time.Now()
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
if tmpFile != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
}
|
||||
return nil, nil, fmt.Errorf("start subprocess: %w", err)
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.proc = cmd.Process
|
||||
handle.mu.Unlock()
|
||||
|
||||
// Run the subprocess monitoring in a goroutine.
|
||||
go func() {
|
||||
defer close(handle.done)
|
||||
defer cancel()
|
||||
if tmpFile != nil {
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var stdoutBuf strings.Builder
|
||||
|
||||
// Read stderr (live output).
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderr)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if cfg.OnOutput != nil && strings.TrimSpace(line) != "" {
|
||||
cfg.OnOutput(line + "\n")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Read stdout (JSON output).
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
||||
for scanner.Scan() {
|
||||
stdoutBuf.WriteString(scanner.Text() + "\n")
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
waitErr := cmd.Wait()
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Build result.
|
||||
result := SubagentResult{Elapsed: elapsed}
|
||||
if waitErr != nil {
|
||||
result.Error = waitErr
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
result.ExitCode = exitErr.ExitCode()
|
||||
} else {
|
||||
result.ExitCode = 1
|
||||
}
|
||||
}
|
||||
|
||||
// Parse JSON output.
|
||||
raw := strings.TrimSpace(stdoutBuf.String())
|
||||
var parsed subagentJSONOutput
|
||||
if raw != "" && json.Unmarshal([]byte(raw), &parsed) == nil {
|
||||
result.Response = parsed.Response
|
||||
result.SessionID = parsed.SessionID
|
||||
if parsed.Usage != nil {
|
||||
result.Usage = &SubagentUsage{
|
||||
InputTokens: parsed.Usage.InputTokens,
|
||||
OutputTokens: parsed.Usage.OutputTokens,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback: use raw stdout.
|
||||
result.Response = raw
|
||||
}
|
||||
|
||||
handle.mu.Lock()
|
||||
handle.result = &result
|
||||
handle.proc = nil
|
||||
handle.mu.Unlock()
|
||||
|
||||
if cfg.OnComplete != nil {
|
||||
cfg.OnComplete(result)
|
||||
}
|
||||
}()
|
||||
|
||||
if cfg.Blocking {
|
||||
// Wait for completion and return result directly.
|
||||
<-handle.done
|
||||
handle.mu.Lock()
|
||||
r := handle.result
|
||||
handle.mu.Unlock()
|
||||
return nil, r, nil
|
||||
}
|
||||
|
||||
return handle, nil, nil
|
||||
}
|
||||
|
||||
@@ -183,6 +183,7 @@ func Symbols() interp.Exports {
|
||||
"RetryEvent": reflect.ValueOf((*RetryEvent)(nil)),
|
||||
"PrepareStepEvent": reflect.ValueOf((*PrepareStepEvent)(nil)),
|
||||
"PrepareStepResult": reflect.ValueOf((*PrepareStepResult)(nil)),
|
||||
"LLMUsageEvent": reflect.ValueOf((*LLMUsageEvent)(nil)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,5 +189,11 @@ func NewTestAPI(ext *LoadedExtension) API {
|
||||
return nil
|
||||
})
|
||||
},
|
||||
onLLMUsage: func(h func(LLMUsageEvent, Context)) {
|
||||
reg(LLMUsage, func(e Event, c Context) Result {
|
||||
h(e.(LLMUsageEvent), c)
|
||||
return nil
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package extensions
|
||||
|
||||
// ToolKind constants classify what a tool does, enabling UIs to render
|
||||
// appropriate visualizations (e.g. diff view for edit tools, command+output
|
||||
// for execute tools) and file trackers to identify which results contain
|
||||
// modifications.
|
||||
//
|
||||
// This is the single source of truth for tool-kind classification; the
|
||||
// pkg/kit SDK re-exports these constants.
|
||||
const (
|
||||
ToolKindExecute = "execute" // Shell execution (bash)
|
||||
ToolKindEdit = "edit" // File modification (edit, write)
|
||||
ToolKindRead = "read" // File reading (read, ls)
|
||||
ToolKindSearch = "search" // Content/file search (grep, find)
|
||||
ToolKindSubagent = "agent" // Subagent spawning (subagent)
|
||||
)
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
// MCP and extension tools without an entry default to ToolKindExecute.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": ToolKindExecute,
|
||||
"edit": ToolKindEdit,
|
||||
"write": ToolKindEdit,
|
||||
"read": ToolKindRead,
|
||||
"ls": ToolKindRead,
|
||||
"grep": ToolKindSearch,
|
||||
"find": ToolKindSearch,
|
||||
"subagent": ToolKindSubagent,
|
||||
}
|
||||
|
||||
// ToolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// ToolKindExecute for unknown tools (including MCP tools).
|
||||
func ToolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return ToolKindExecute
|
||||
}
|
||||
+24
-157
@@ -1,143 +1,32 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/mark3labs/kit/internal/watcher"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
// Watcher monitors extension directories for .go file changes and triggers
|
||||
// a reload callback when changes are detected. It is implemented in terms
|
||||
// of the general-purpose internal/watcher.ContentWatcher.
|
||||
//
|
||||
// Type-aliasing here lets existing call sites (cmd/root.go and the
|
||||
// watcher_test.go suite) keep using `extensions.NewWatcher` / `*Watcher`
|
||||
// without knowing about the underlying implementation.
|
||||
type Watcher = watcher.ContentWatcher
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
return watcher.New(watcher.Options{
|
||||
Dirs: dirs,
|
||||
Extensions: []string{".go"},
|
||||
OnReload: onReload,
|
||||
Label: "extensions",
|
||||
})
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
@@ -146,47 +35,25 @@ func (w *Watcher) Close() error {
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
standard := []string{
|
||||
globalExtensionsDir(),
|
||||
filepath.Join(".kit", "extensions"),
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
// Filter explicit paths into directories (passed through) and files
|
||||
// (parent dir watched) for CollectDirs to dedupe.
|
||||
var extras []string
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
extras = append(extras, p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
extras = append(extras, filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
return watcher.CollectDirs(standard, extras)
|
||||
}
|
||||
|
||||
@@ -40,27 +40,6 @@ func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentToo
|
||||
return tools
|
||||
}
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
// "execute" for unknown tools (including MCP tools).
|
||||
func toolKindFor(toolName string) string {
|
||||
if kind, ok := coreToolKinds[toolName]; ok {
|
||||
return kind
|
||||
}
|
||||
return "execute"
|
||||
}
|
||||
|
||||
// parseToolArgsJSON attempts to parse JSON-encoded tool args into a map.
|
||||
// Returns nil on failure (non-fatal convenience parsing).
|
||||
func parseToolArgsJSON(input string) map[string]any {
|
||||
@@ -93,7 +72,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
fmt.Sprintf("Error: tool %q is currently disabled", toolName)), nil
|
||||
}
|
||||
|
||||
kind := toolKindFor(toolName)
|
||||
kind := ToolKindFor(toolName)
|
||||
|
||||
// 1. Emit ToolCall — extensions can block execution.
|
||||
if w.runner.HasHandlers(ToolCall) {
|
||||
|
||||
+65
-42
@@ -46,9 +46,9 @@ type AgentSetupOptions struct {
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
// BuildProviderConfig(). Callers (e.g. Kit.New) pre-build this from their
|
||||
// per-instance config store and pass it here, so the slow agent/MCP
|
||||
// initialisation can run without further config reads.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
@@ -75,6 +75,11 @@ type AgentSetupOptions struct {
|
||||
// MCPTaskConfig configures task-augmented tools/call execution. The
|
||||
// zero value preserves historical synchronous-only behaviour.
|
||||
MCPTaskConfig tools.MCPTaskConfig
|
||||
// Viper is the per-instance configuration store. When set, it is used for
|
||||
// any fallback config reads (debug, no-extensions, max-steps, stream,
|
||||
// extension paths) and is attached to the extension runner. When nil, the
|
||||
// process-global viper store is used.
|
||||
Viper *viper.Viper
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -87,57 +92,62 @@ type AgentSetupResult struct {
|
||||
ExtRunner *extensions.Runner
|
||||
}
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the supplied viper
|
||||
// store (or the process-global store when v is nil). All entry points (root,
|
||||
// script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
func BuildProviderConfig(v *viper.Viper) (*models.ProviderConfig, string, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
systemPrompt, err := config.LoadSystemPrompt(v.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
numGPU := int32(v.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(v.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
ModelString: v.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
ProviderAPIKey: v.GetString("provider-api-key"),
|
||||
ProviderURL: v.GetString("provider-url"),
|
||||
MaxTokens: v.GetInt("max-tokens"),
|
||||
StopSequences: v.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
TLSSkipVerify: v.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(v.GetString("thinking-level")),
|
||||
ConfigStore: v,
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
if v.IsSet("temperature") {
|
||||
val := float32(v.GetFloat64("temperature"))
|
||||
cfg.Temperature = &val
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
if v.IsSet("top-p") {
|
||||
val := float32(v.GetFloat64("top-p"))
|
||||
cfg.TopP = &val
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
if v.IsSet("top-k") {
|
||||
val := int32(v.GetInt("top-k"))
|
||||
cfg.TopK = &val
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
if v.IsSet("frequency-penalty") {
|
||||
val := float32(v.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &val
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
if v.IsSet("presence-penalty") {
|
||||
val := float32(v.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &val
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
@@ -149,14 +159,21 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
// Resolve the config store: prefer the per-instance store, falling back to
|
||||
// the process-global store.
|
||||
v := opts.Viper
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after building the
|
||||
// per-instance store). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -164,13 +181,13 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
debugEnabled := opts.Debug || v.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || v.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
maxSteps = v.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
streamingEnabled := opts.StreamingEnabled || v.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
@@ -189,7 +206,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
extRunner, extCreationOpts, extErr = loadExtensions(v)
|
||||
if extErr != nil {
|
||||
fmt.Printf("Warning: Failed to load extensions: %v\n", extErr)
|
||||
}
|
||||
@@ -253,9 +270,14 @@ type extensionCreationOpts struct {
|
||||
}
|
||||
|
||||
// loadExtensions discovers and loads Yaegi extensions, builds the runner,
|
||||
// and returns the tool wrapper/extra tools.
|
||||
func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
extraPaths := viper.GetStringSlice("extension")
|
||||
// and returns the tool wrapper/extra tools. The supplied store is used to
|
||||
// resolve the "extension" config key and is attached to the runner so
|
||||
// extension option lookups stay isolated to this Kit instance.
|
||||
func loadExtensions(v *viper.Viper) (*extensions.Runner, extensionCreationOpts, error) {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
extraPaths := v.GetStringSlice("extension")
|
||||
loaded, err := extensions.LoadExtensions(extraPaths)
|
||||
if err != nil {
|
||||
return nil, extensionCreationOpts{}, err
|
||||
@@ -266,6 +288,7 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
}
|
||||
|
||||
runner := extensions.NewRunner(loaded)
|
||||
runner.SetConfigStore(v)
|
||||
|
||||
wrapper := func(tools []fantasy.AgentTool) []fantasy.AgentTool {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNpmToWireProtocol documents the wire protocols that the auto-router
|
||||
// understands. Provider-specific bundles that need bespoke auth or URL
|
||||
// templating (azure, bedrock, openrouter, google-vertex*, @ai-sdk/gateway)
|
||||
// are intentionally absent — they have native top-level cases in
|
||||
// CreateProvider and never reach the auto-router.
|
||||
func TestNpmToWireProtocol(t *testing.T) {
|
||||
want := map[string]wireProtocol{
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
|
||||
// Thin OpenAI-compatible wrappers — routed via openaicompat using
|
||||
// the SDK's hard-coded default base URL (sdkDefaultBaseURL).
|
||||
"@ai-sdk/groq": wireOpenAI,
|
||||
"@ai-sdk/cerebras": wireOpenAI,
|
||||
"@ai-sdk/perplexity": wireOpenAI,
|
||||
"@ai-sdk/togetherai": wireOpenAI,
|
||||
"@ai-sdk/xai": wireOpenAI,
|
||||
"@ai-sdk/deepinfra": wireOpenAI,
|
||||
"@ai-sdk/mistral": wireOpenAI,
|
||||
"@ai-sdk/cohere": wireOpenAI,
|
||||
"@ai-sdk/vercel": wireOpenAI,
|
||||
"@aihubmix/ai-sdk-provider": wireOpenAI,
|
||||
"venice-ai-sdk-provider": wireOpenAI,
|
||||
"merge-gateway-ai-sdk-provider": wireOpenAI,
|
||||
}
|
||||
for npm, wire := range want {
|
||||
if got := npmToWireProtocol[npm]; got != wire {
|
||||
t.Errorf("npmToWireProtocol[%q] = %d, want %d", npm, got, wire)
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle packages must NOT be in the table — they need bespoke auth or
|
||||
// URL templating that the auto-router cannot satisfy.
|
||||
for _, npm := range []string{
|
||||
"@ai-sdk/google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic",
|
||||
"@ai-sdk/amazon-bedrock",
|
||||
"@ai-sdk/azure",
|
||||
"@openrouter/ai-sdk-provider",
|
||||
"@ai-sdk/gateway",
|
||||
} {
|
||||
if _, ok := npmToWireProtocol[npm]; ok {
|
||||
t.Errorf("npmToWireProtocol unexpectedly contains bundle package %q", npm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestRegistry builds a registry containing a single proxy-style provider
|
||||
// ("testproxy") with the given default npm, plus one model that carries the
|
||||
// given per-model npm override.
|
||||
func newTestRegistry(api, defaultNPM, modelID, modelNPMOverride string) *ModelsRegistry {
|
||||
return &ModelsRegistry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"testproxy": {
|
||||
ID: "testproxy",
|
||||
Name: "Test Proxy",
|
||||
Env: []string{"TESTPROXY_API_KEY"},
|
||||
NPM: defaultNPM,
|
||||
API: api,
|
||||
Models: map[string]ModelInfo{
|
||||
modelID: {
|
||||
ID: modelID,
|
||||
Name: modelID,
|
||||
ProviderNPM: modelNPMOverride,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_WireRouting verifies that autoRouteProvider routes each
|
||||
// npm package to the correct fantasy provider implementation. This is the core
|
||||
// regression test for issue #41: previously any npm that resolved to a
|
||||
// non-openai/anthropic/openaicompat LLM provider (notably @ai-sdk/google) hit a
|
||||
// dead `default` branch and failed with "has no LLM provider mapping".
|
||||
func TestAutoRouteProvider_WireRouting(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
defaultNPM string
|
||||
overrideNPM string
|
||||
// wantType is the concrete fantasy LanguageModel type the model should
|
||||
// be routed to, identified by reflect type string.
|
||||
wantType string
|
||||
}{
|
||||
{
|
||||
name: "openai-compatible default",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
{
|
||||
name: "anthropic override",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/anthropic",
|
||||
wantType: "anthropic.languageModel",
|
||||
},
|
||||
{
|
||||
name: "openai (responses) override",
|
||||
modelID: "gpt-4o",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/openai",
|
||||
wantType: "openai.responsesLanguageModel",
|
||||
},
|
||||
{
|
||||
// The bug: opencode's gemini-* models override the default
|
||||
// openai-compatible npm with @ai-sdk/google.
|
||||
name: "google override (issue #41)",
|
||||
modelID: "gemini-3.5-flash",
|
||||
defaultNPM: "@ai-sdk/openai-compatible",
|
||||
overrideNPM: "@ai-sdk/google",
|
||||
wantType: "*google.languageModel",
|
||||
},
|
||||
{
|
||||
// Unknown npm but provider has an API URL → openai-compatible fallback.
|
||||
name: "unknown npm with API URL falls back to openai-compat",
|
||||
modelID: "test-model",
|
||||
defaultNPM: "@ai-sdk/some-future-thing",
|
||||
wantType: "openai.languageModel",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", tt.defaultNPM, tt.modelID, tt.overrideNPM)
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
result, err := autoRouteProvider(context.Background(), config, "testproxy", tt.modelID, reg)
|
||||
if err != nil {
|
||||
t.Fatalf("autoRouteProvider returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.Model == nil {
|
||||
t.Fatalf("autoRouteProvider returned nil model")
|
||||
}
|
||||
|
||||
gotType := reflect.TypeOf(result.Model).String()
|
||||
if gotType != tt.wantType {
|
||||
t.Errorf("routed to %s, want %s", gotType, tt.wantType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownNpmNoAPI verifies the improved error message for
|
||||
// a provider whose npm has no known wire protocol and that has no API URL to
|
||||
// fall back on.
|
||||
func TestAutoRouteProvider_UnknownNpmNoAPI(t *testing.T) {
|
||||
reg := newTestRegistry("", "@ai-sdk/unmapped", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "testproxy", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown npm with no API URL, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "cannot auto-route provider testproxy") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--provider-url") {
|
||||
t.Errorf("error should suggest --provider-url, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_UnknownProvider verifies the not-in-database error.
|
||||
func TestAutoRouteProvider_UnknownProvider(t *testing.T) {
|
||||
reg := newTestRegistry("https://proxy.example/v1", "@ai-sdk/openai-compatible", "test-model", "")
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
_, err := autoRouteProvider(context.Background(), config, "does-not-exist", "test-model", reg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found in model database") {
|
||||
t.Errorf("unexpected error message: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsProviderLLMSupported_Google verifies that a provider whose npm is
|
||||
// @ai-sdk/google is reported as supported (it now maps to a wire protocol).
|
||||
func TestIsProviderLLMSupported_Google(t *testing.T) {
|
||||
info := &ProviderInfo{ID: "testproxy", NPM: "@ai-sdk/google"}
|
||||
if !isProviderLLMSupported("testproxy", info) {
|
||||
t.Error("expected @ai-sdk/google provider to be LLM-supported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestVersionedBasePath verifies detection of proxy base URLs that already
|
||||
// carry an API version segment (which collides with the genai SDK's injected
|
||||
// version).
|
||||
func TestVersionedBasePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
rawURL string
|
||||
want string
|
||||
}{
|
||||
{"https://opencode.ai/zen/v1", "/zen/v1"},
|
||||
{"https://opencode.ai/zen/v1/", "/zen/v1"},
|
||||
{"https://example.com/api/v1beta", "/api/v1beta"},
|
||||
{"https://example.com/api/v2alpha", "/api/v2alpha"},
|
||||
{"https://generativelanguage.googleapis.com", ""},
|
||||
{"https://proxy.example/openai", ""},
|
||||
{"", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if got := versionedBasePath(tt.rawURL); got != tt.want {
|
||||
t.Errorf("versionedBasePath(%q) = %q, want %q", tt.rawURL, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordingRoundTripper captures the path of the request it receives.
|
||||
type recordingRoundTripper struct{ gotPath string }
|
||||
|
||||
func (r *recordingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
r.gotPath = req.URL.Path
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestGeminiProxyTransport_StripsInjectedVersion verifies that the transport
|
||||
// collapses the genai-injected "/v1beta" segment that follows a proxy base
|
||||
// URL which already carries its own version segment. This is the second-order
|
||||
// fix that makes opencode/gemini-* actually reach the proxy (issue #41).
|
||||
func TestGeminiProxyTransport_StripsInjectedVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
basePath string
|
||||
reqPath string
|
||||
wantPath string
|
||||
}{
|
||||
{
|
||||
name: "strips doubled v1beta after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "strips doubled v1beta1 after /zen/v1",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/zen/v1/v1beta1/models/gemini-3.5-flash:generateContent",
|
||||
wantPath: "/zen/v1/models/gemini-3.5-flash:generateContent",
|
||||
},
|
||||
{
|
||||
name: "leaves non-matching path untouched",
|
||||
basePath: "/zen/v1",
|
||||
reqPath: "/other/v1beta/models/x:generateContent",
|
||||
wantPath: "/other/v1beta/models/x:generateContent",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rec := &recordingRoundTripper{}
|
||||
tr := &geminiProxyTransport{base: rec, basePath: tt.basePath}
|
||||
req, err := http.NewRequest(http.MethodPost, "https://host"+tt.reqPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
if _, err := tr.RoundTrip(req); err != nil {
|
||||
t.Fatalf("RoundTrip: %v", err)
|
||||
}
|
||||
if rec.gotPath != tt.wantPath {
|
||||
t.Errorf("forwarded path = %q, want %q", rec.gotPath, tt.wantPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package models
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -69,19 +68,3 @@ func generateCacheKey(systemPrompt, modelID string) string {
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ package models
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
@@ -192,57 +190,3 @@ func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCopilotProviderAliasUsesCatalog(t *testing.T) {
|
||||
registry := NewModelsRegistry()
|
||||
|
||||
models, err := registry.GetModelsForProvider("copilot")
|
||||
if err != nil {
|
||||
t.Fatalf("GetModelsForProvider(copilot) failed: %v", err)
|
||||
}
|
||||
if len(models) == 0 {
|
||||
t.Fatal("expected copilot alias to return github-copilot catalog models")
|
||||
}
|
||||
if registry.LookupModel("copilot", "gpt-5.5") == nil {
|
||||
t.Fatal("expected copilot/gpt-5.5 to resolve through github-copilot catalog")
|
||||
}
|
||||
if registry.GetProviderInfo("copilot") == nil {
|
||||
t.Fatal("expected copilot alias to return github-copilot provider info")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotRejectsNonGPTModels(t *testing.T) {
|
||||
_, err := CreateProvider(t.Context(), &ProviderConfig{ModelString: "copilot/claude-sonnet-4.6"})
|
||||
if err == nil {
|
||||
t.Fatal("expected non-GPT Copilot model to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotHTTPClientCachesToken(t *testing.T) {
|
||||
client := createCopilotHTTPClient("cached-token", time.Now().Add(time.Hour).Unix(), false)
|
||||
transport, ok := client.Transport.(*copilotTransport)
|
||||
if !ok {
|
||||
t.Fatal("expected *copilotTransport")
|
||||
}
|
||||
|
||||
token := transport.cachedToken(t.Context())
|
||||
if token != "cached-token" {
|
||||
t.Fatalf("expected cached token, got %q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilotTransportHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest failed: %v", err)
|
||||
}
|
||||
|
||||
transport := &copilotTransport{
|
||||
base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("Authorization") != "Bearer cached-token" {
|
||||
t.Fatalf("unexpected Authorization header: %q", req.Header.Get("Authorization"))
|
||||
}
|
||||
if req.Header.Get("Copilot-Integration-Id") != copilotIntegrationID {
|
||||
t.Fatalf("unexpected Copilot-Integration-Id header: %q", req.Header.Get("Copilot-Integration-Id"))
|
||||
}
|
||||
if req.Header.Get("Editor-Version") != copilotEditorVersion {
|
||||
t.Fatalf("unexpected Editor-Version header: %q", req.Header.Get("Editor-Version"))
|
||||
}
|
||||
if req.Header.Get("User-Agent") != copilotUserAgent {
|
||||
t.Fatalf("unexpected User-Agent header: %q", req.Header.Get("User-Agent"))
|
||||
}
|
||||
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
|
||||
}),
|
||||
token: "cached-token",
|
||||
expiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
}
|
||||
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("RoundTrip failed: %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
+46
-20
@@ -10,14 +10,24 @@ import (
|
||||
|
||||
// loadCustomModelsFromConfig loads custom model definitions from the config file
|
||||
// and returns them as a map of model ID -> ModelInfo. Returns nil if no custom
|
||||
// models are configured.
|
||||
// models are configured. Reads from the process-global viper store (the model
|
||||
// registry is a process-global singleton).
|
||||
func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
if !viper.IsSet("customModels") {
|
||||
return loadCustomModelsFrom(viper.GetViper())
|
||||
}
|
||||
|
||||
// loadCustomModelsFrom loads custom model definitions from the supplied store.
|
||||
// When v is nil the process-global store is used.
|
||||
func loadCustomModelsFrom(v *viper.Viper) map[string]ModelInfo {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("customModels") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var customModels map[string]CustomModelConfig
|
||||
if err := viper.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
if err := v.UnmarshalKey("customModels", &customModels); err != nil {
|
||||
log.Printf("Warning: Failed to parse customModels: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -59,16 +69,20 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
// LoadModelSettingsFrom loads per-model generation parameter overrides from the
|
||||
// supplied per-instance store. When v is nil the process-global store is used.
|
||||
// Keys are "provider/model" strings. Returns nil if no model settings are
|
||||
// configured.
|
||||
func LoadModelSettingsFrom(v *viper.Viper) map[string]*GenerationParams {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
if !v.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
if err := v.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
@@ -148,12 +162,17 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve the config store: prefer the per-instance store carried on the
|
||||
// ProviderConfig (set by BuildProviderConfig / Kit.New), falling back to
|
||||
// the process-global store for callers that don't thread one through.
|
||||
store := config.ConfigStore
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
if settings := LoadModelSettingsFrom(store); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
@@ -173,28 +192,28 @@ func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
if params.MaxTokens != nil && !isExplicitlySet(store, "max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
if params.Temperature != nil && !isExplicitlySet(store, "temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
if params.TopP != nil && !isExplicitlySet(store, "top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
if params.TopK != nil && !isExplicitlySet(store, "top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet(store, "frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
if params.PresencePenalty != nil && !isExplicitlySet(store, "presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet(store, "stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet(store, "thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
@@ -228,7 +247,14 @@ func LoadSystemPromptValue(input string) string {
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
//
|
||||
// The check runs against the supplied per-instance store when non-nil,
|
||||
// otherwise the process-global store. This keeps the "explicit vs unset"
|
||||
// precedence contract per-Kit-instance once a store is threaded through.
|
||||
func isExplicitlySet(v *viper.Viper, key string) bool {
|
||||
if v == nil {
|
||||
v = viper.GetViper()
|
||||
}
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
@@ -239,7 +265,7 @@ func isExplicitlySet(key string) bool {
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
return v.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
|
||||
File diff suppressed because one or more lines are too long
+83
-14
@@ -48,18 +48,87 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
"@ai-sdk/google-vertex": "google-vertex",
|
||||
"@ai-sdk/google-vertex/anthropic": "google-vertex-anthropic",
|
||||
"@ai-sdk/amazon-bedrock": "bedrock",
|
||||
"@ai-sdk/azure": "azure",
|
||||
"@openrouter/ai-sdk-provider": "openrouter",
|
||||
"@ai-sdk/vercel": "vercel",
|
||||
"@ai-sdk/openai-compatible": "openaicompat",
|
||||
// wireProtocol identifies which LLM API protocol an npm package speaks.
|
||||
// Fantasy implements three native protocols (openai, anthropic, google);
|
||||
// everything else in its providers/ tree is a thin wrapper around one of
|
||||
// them with a pre-baked default URL or auth scheme.
|
||||
type wireProtocol int
|
||||
|
||||
const (
|
||||
wireUnknown wireProtocol = iota
|
||||
wireOpenAI
|
||||
wireAnthropic
|
||||
wireGoogle
|
||||
)
|
||||
|
||||
// npmToWireProtocol maps npm package names from models.dev to the wire
|
||||
// protocol they speak. Provider-specific bundles that need bespoke auth or
|
||||
// URL templating (azure, bedrock, openrouter, google-vertex, google-vertex-
|
||||
// anthropic, and @ai-sdk/gateway which is the Vercel AI Gateway) are
|
||||
// intentionally absent — they have native top-level cases in CreateProvider
|
||||
// and never reach the auto-router. Providers not in this map but with an
|
||||
// api URL are auto-routed through the OpenAI-compatible wire.
|
||||
//
|
||||
// The thin OpenAI-compatible npm wrappers (groq, cerebras, mistral, …) are
|
||||
// listed explicitly so that auto-routing can recover their hard-coded base
|
||||
// URL from sdkDefaultBaseURL when the registry entry has no api field.
|
||||
var npmToWireProtocol = map[string]wireProtocol{
|
||||
// Native wires.
|
||||
"@ai-sdk/openai": wireOpenAI,
|
||||
"@ai-sdk/openai-compatible": wireOpenAI,
|
||||
"@ai-sdk/anthropic": wireAnthropic,
|
||||
"@ai-sdk/google": wireGoogle,
|
||||
|
||||
// Thin OpenAI-compatible wrappers. Each ships with a hard-coded base URL
|
||||
// in its JS SDK (see sdkDefaultBaseURL) but speaks the plain OpenAI chat
|
||||
// completions wire — so we can route them all through fantasy's
|
||||
// openaicompat provider once we supply the URL.
|
||||
"@ai-sdk/groq": wireOpenAI,
|
||||
"@ai-sdk/cerebras": wireOpenAI,
|
||||
"@ai-sdk/perplexity": wireOpenAI,
|
||||
"@ai-sdk/togetherai": wireOpenAI,
|
||||
"@ai-sdk/xai": wireOpenAI,
|
||||
"@ai-sdk/deepinfra": wireOpenAI,
|
||||
"@ai-sdk/mistral": wireOpenAI,
|
||||
"@ai-sdk/cohere": wireOpenAI,
|
||||
"@ai-sdk/vercel": wireOpenAI, // v0 API (api.v0.dev), distinct from @ai-sdk/gateway
|
||||
"@aihubmix/ai-sdk-provider": wireOpenAI,
|
||||
"venice-ai-sdk-provider": wireOpenAI,
|
||||
"merge-gateway-ai-sdk-provider": wireOpenAI,
|
||||
}
|
||||
|
||||
// sdkDefaultBaseURL maps an npm package name to the base URL its JavaScript
|
||||
// SDK uses by default. This lets us recover a working endpoint for providers
|
||||
// whose models.dev entry omits the `api` field because the JS SDK hard-codes
|
||||
// the URL (e.g. groq, cerebras, mistral, x.ai…).
|
||||
//
|
||||
// Only OpenAI-compatible and native-wire SDKs are listed; providers needing
|
||||
// bespoke auth or URL templating (bedrock SigV4, azure resource URLs,
|
||||
// google-vertex project/location, cloudflare gateway account IDs, gitlab,
|
||||
// sap-ai-core) are handled by native CreateProvider cases or surface a
|
||||
// targeted error that asks the user to supply --provider-url.
|
||||
var sdkDefaultBaseURL = map[string]string{
|
||||
// Native wires.
|
||||
"@ai-sdk/openai": "https://api.openai.com/v1",
|
||||
"@ai-sdk/anthropic": "https://api.anthropic.com/v1",
|
||||
"@ai-sdk/google": "https://generativelanguage.googleapis.com/v1beta",
|
||||
|
||||
// Thin OpenAI-compatible wrappers.
|
||||
"@ai-sdk/groq": "https://api.groq.com/openai/v1",
|
||||
"@ai-sdk/cerebras": "https://api.cerebras.ai/v1",
|
||||
"@ai-sdk/perplexity": "https://api.perplexity.ai",
|
||||
"@ai-sdk/togetherai": "https://api.together.xyz/v1",
|
||||
"@ai-sdk/xai": "https://api.x.ai/v1",
|
||||
"@ai-sdk/deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"@ai-sdk/mistral": "https://api.mistral.ai/v1",
|
||||
"@ai-sdk/cohere": "https://api.cohere.com/compatibility/v1",
|
||||
"@ai-sdk/vercel": "https://api.v0.dev/v1",
|
||||
"@aihubmix/ai-sdk-provider": "https://aihubmix.com/v1",
|
||||
"venice-ai-sdk-provider": "https://api.venice.ai/api/v1",
|
||||
"merge-gateway-ai-sdk-provider": "https://api-gateway.merge.dev/v1/ai-sdk",
|
||||
|
||||
// Native handlers — included for ResolveProviderBaseURL introspection
|
||||
// even though CreateProvider routes these via dedicated cases.
|
||||
"@ai-sdk/gateway": "https://ai-gateway.vercel.sh/v1",
|
||||
"@openrouter/ai-sdk-provider": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// ProviderPool manages reusable LLM provider instances to reduce overhead
|
||||
// when spawning multiple subagents or making repeated completion calls.
|
||||
type ProviderPool struct {
|
||||
mu sync.RWMutex
|
||||
providers map[string]*pooledProvider
|
||||
ttl time.Duration
|
||||
closed bool
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
type pooledProvider struct {
|
||||
model fantasy.LanguageModel
|
||||
closer func() error
|
||||
providerOpts fantasy.ProviderOptions
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
refs int32
|
||||
}
|
||||
|
||||
// DefaultPoolTTL is the default time-to-live for idle pooled providers.
|
||||
const DefaultPoolTTL = 5 * time.Minute
|
||||
|
||||
// globalPool is the singleton provider pool instance.
|
||||
var globalPool *ProviderPool
|
||||
var poolOnce sync.Once
|
||||
|
||||
// GetGlobalPool returns the singleton provider pool instance.
|
||||
func GetGlobalPool() *ProviderPool {
|
||||
poolOnce.Do(func() {
|
||||
globalPool = NewProviderPool(DefaultPoolTTL)
|
||||
})
|
||||
return globalPool
|
||||
}
|
||||
|
||||
// NewProviderPool creates a provider pool with the given TTL for idle providers.
|
||||
func NewProviderPool(ttl time.Duration) *ProviderPool {
|
||||
p := &ProviderPool{
|
||||
providers: make(map[string]*pooledProvider),
|
||||
ttl: ttl,
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
go p.cleanupLoop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Get returns a provider for the model string, creating one if needed.
|
||||
// The returned release function must be called when the provider is no longer
|
||||
// needed. The provider may be reused by subsequent Get calls.
|
||||
func (p *ProviderPool) Get(ctx context.Context, modelString string) (fantasy.LanguageModel, fantasy.ProviderOptions, func(), error) {
|
||||
p.mu.Lock()
|
||||
|
||||
// Check if we have an existing provider.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
p.mu.Unlock()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
// Create a new provider outside the lock.
|
||||
config := &ProviderConfig{ModelString: modelString}
|
||||
result, err := CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Double-check: another goroutine may have created one while we were unlocked.
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
// Close the one we just created and use the existing one.
|
||||
if result.Closer != nil {
|
||||
_ = result.Closer.Close()
|
||||
}
|
||||
pp.refs++
|
||||
pp.lastUsed = time.Now()
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
var closerFn func() error
|
||||
if result.Closer != nil {
|
||||
closerFn = result.Closer.Close
|
||||
}
|
||||
|
||||
pp := &pooledProvider{
|
||||
model: result.Model,
|
||||
closer: closerFn,
|
||||
providerOpts: result.ProviderOptions,
|
||||
created: time.Now(),
|
||||
lastUsed: time.Now(),
|
||||
refs: 1,
|
||||
}
|
||||
p.providers[modelString] = pp
|
||||
|
||||
return pp.model, pp.providerOpts, func() { p.release(modelString) }, nil
|
||||
}
|
||||
|
||||
func (p *ProviderPool) release(modelString string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if pp, ok := p.providers[modelString]; ok {
|
||||
pp.refs--
|
||||
pp.lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanupLoop() {
|
||||
ticker := time.NewTicker(p.ttl / 2)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ProviderPool) cleanup() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, pp := range p.providers {
|
||||
// Only clean up providers with no active references and past TTL.
|
||||
if pp.refs <= 0 && now.Sub(pp.lastUsed) > p.ttl {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the pool and releases all providers.
|
||||
func (p *ProviderPool) Close() {
|
||||
p.mu.Lock()
|
||||
if p.closed {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
close(p.closeCh)
|
||||
|
||||
for key, pp := range p.providers {
|
||||
if pp.closer != nil {
|
||||
_ = pp.closer()
|
||||
}
|
||||
delete(p.providers, key)
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
+444
-74
@@ -9,8 +9,11 @@ import (
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
@@ -25,11 +28,30 @@ import (
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
// ClaudeCodePrompt is the required system prompt for OAuth authentication.
|
||||
ClaudeCodePrompt = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
|
||||
// copilotProviderID is the canonical models.dev provider key. The CLI also
|
||||
// accepts the shorter "copilot" alias for user-facing model strings.
|
||||
copilotProviderID = "github-copilot"
|
||||
// copilotAliasProviderID is the short provider prefix accepted by kit.
|
||||
copilotAliasProviderID = "copilot"
|
||||
// copilotBaseURL is the fallback API URL if the model catalog has no API URL.
|
||||
copilotBaseURL = "https://api.githubcopilot.com"
|
||||
|
||||
// GitHub Copilot currently expects VS Code Copilot Chat client identifiers.
|
||||
// Keep these centralized so they are easy to audit and update when GitHub
|
||||
// changes accepted client metadata.
|
||||
copilotIntegrationID = "vscode-chat"
|
||||
copilotEditorVersion = "vscode/1.104.1"
|
||||
copilotEditorPluginVersion = "copilot-chat/0.31.0"
|
||||
copilotUserAgent = "GitHubCopilotChat/0.31.0"
|
||||
copilotOpenAIIntent = "conversation-agent"
|
||||
copilotGitHubAPIVersion = "2026-01-09"
|
||||
)
|
||||
|
||||
// resolveModelAlias resolves model aliases to their full names using the registry
|
||||
@@ -164,6 +186,13 @@ type ProviderConfig struct {
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
|
||||
// ConfigStore is the per-instance configuration store used to resolve
|
||||
// "explicitly set" precedence checks (isExplicitlySet), per-model
|
||||
// settings, and right-sizing. When nil, the process-global viper store is
|
||||
// used. Threading a per-Kit store here keeps generation-parameter
|
||||
// precedence isolated between Kit instances in the same process.
|
||||
ConfigStore *viper.Viper
|
||||
|
||||
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
|
||||
// for long operations like Ollama model pulls. The returned io.ReadCloser
|
||||
// must be closed when done. When nil, the raw reader is consumed directly
|
||||
@@ -205,6 +234,20 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
return "", "", fmt.Errorf("invalid model format %q: expected provider/model (e.g. anthropic/claude-sonnet-4-5)", modelString)
|
||||
}
|
||||
|
||||
// isCopilotProvider reports whether provider is the canonical catalog key or
|
||||
// the user-facing shorthand alias.
|
||||
func isCopilotProvider(provider string) bool {
|
||||
return provider == copilotAliasProviderID || provider == copilotProviderID
|
||||
}
|
||||
|
||||
// catalogProviderID maps supported provider aliases to their models.dev keys.
|
||||
func catalogProviderID(provider string) string {
|
||||
if isCopilotProvider(provider) {
|
||||
return copilotProviderID
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
// CreateProvider creates a fantasy LanguageModel based on the provider configuration.
|
||||
// Model metadata is looked up from the models.dev database for cost tracking and
|
||||
// capability detection, but unknown models are passed through to the provider
|
||||
@@ -212,8 +255,10 @@ func ParseModelString(modelString string) (provider, model string, err error) {
|
||||
//
|
||||
// Native providers: anthropic, openai, google, ollama, azure, google-vertex-anthropic,
|
||||
// openrouter, bedrock, vercel.
|
||||
// Any provider in models.dev with an api URL or openai-compatible npm package
|
||||
// is auto-routed through fantasy's openaicompat provider.
|
||||
// Any other provider in models.dev is auto-routed by wire protocol: its npm
|
||||
// package (or per-model override) selects the OpenAI, Anthropic, or Google
|
||||
// transport, using the provider's api URL as the base. Providers with an api
|
||||
// URL but an unrecognized npm package fall back to the OpenAI-compatible wire.
|
||||
func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResult, error) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
@@ -226,17 +271,30 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
}
|
||||
|
||||
registry := GetGlobalRegistry()
|
||||
lookupProvider := catalogProviderID(provider)
|
||||
|
||||
// Look up model metadata (advisory, not blocking).
|
||||
// Look up model metadata (advisory for most providers, strict for Copilot).
|
||||
// When the model is known we validate config limits and print
|
||||
// suggestions on likely typos; when unknown we let the provider
|
||||
// API be the authority.
|
||||
modelInfo := registry.LookupModel(provider, modelName)
|
||||
if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
|
||||
// API be the authority except for Copilot, whose non-GPT catalog entries
|
||||
// require unsupported wire protocols.
|
||||
modelInfo := registry.LookupModel(lookupProvider, modelName)
|
||||
if isCopilotProvider(provider) {
|
||||
providerInfo := registry.GetProviderInfo(copilotProviderID)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", copilotProviderID)
|
||||
}
|
||||
if modelInfo == nil {
|
||||
if suggestions := registry.SuggestModels(copilotProviderID, modelName); len(suggestions) > 0 {
|
||||
return nil, fmt.Errorf("model %q not found for provider %s. Did you mean one of: %s", modelName, copilotProviderID, strings.Join(suggestions, ", "))
|
||||
}
|
||||
return nil, fmt.Errorf("model %q not found for provider %s", modelName, copilotProviderID)
|
||||
}
|
||||
} else if modelInfo == nil && provider != "ollama" && config.ProviderURL == "" {
|
||||
// Model not in database — warn with suggestions but don't block.
|
||||
if suggestions := registry.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
if suggestions := registry.SuggestModels(lookupProvider, modelName); len(suggestions) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "Warning: model %q not found in model database for provider %s. Similar models: %s\n",
|
||||
modelName, provider, strings.Join(suggestions, ", "))
|
||||
modelName, lookupProvider, strings.Join(suggestions, ", "))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,17 +328,21 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "copilot", "github-copilot":
|
||||
result, createErr = createCopilotProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
case "azure", "azure-cognitive-services":
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "google-vertex":
|
||||
result, createErr = createGoogleVertexProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
case "bedrock", "amazon-bedrock":
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
@@ -327,44 +389,100 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
// in the models.dev database and routing through the appropriate fantasy provider.
|
||||
// For openai-compatible providers, it uses the api URL from models.dev.
|
||||
// Models may have a provider override that specifies a different npm package than
|
||||
// the provider's default (e.g., opencode's claude-opus-4-6 uses @ai-sdk/anthropic).
|
||||
// It routes on wire protocol (openai, anthropic, google) rather than per-npm
|
||||
// provider name: fantasy implements three native wire protocols, and every other
|
||||
// entry in its providers/ tree is a thin wrapper around one of them. Using the
|
||||
// provider's api URL from models.dev as the base URL, any proxy that re-flavors
|
||||
// one of these protocols (e.g. opencode's Gemini routes) Just Works.
|
||||
//
|
||||
// Models may carry a provider override that specifies a different npm package
|
||||
// than the provider's default (e.g. opencode's claude-* uses @ai-sdk/anthropic
|
||||
// and its gemini-* uses @ai-sdk/google), which is resolved first.
|
||||
func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, modelName string, registry *ModelsRegistry) (*ProviderResult, error) {
|
||||
providerInfo := registry.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
return nil, fmt.Errorf("unsupported provider: %s (not found in model database)", provider)
|
||||
}
|
||||
|
||||
// Check for model-specific provider override
|
||||
// Resolve npm: per-model override > provider default.
|
||||
npmPackage := providerInfo.NPM
|
||||
if modelInfo := registry.LookupModel(provider, modelName); modelInfo != nil && modelInfo.ProviderNPM != "" {
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
llmProvider = "openaicompat"
|
||||
wire, known := npmToWireProtocol[npmPackage]
|
||||
if !known {
|
||||
// Unknown npm but the provider has an API URL → assume OpenAI-compatible.
|
||||
// (Preserves the long-standing "any provider in models.dev with an api URL
|
||||
// is auto-routed through openaicompat" behaviour.)
|
||||
if providerInfo.API == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"cannot auto-route provider %s: npm package %q has no known wire protocol "+
|
||||
"and the registry has no API URL (use --provider-url to override)",
|
||||
provider, npmPackage,
|
||||
)
|
||||
}
|
||||
wire = wireOpenAI
|
||||
}
|
||||
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
// All three wires use the provider's API URL from models.dev as the base.
|
||||
// When the registry has none, fall back to the SDK's hard-coded default for
|
||||
// this npm package (covers groq, cerebras, mistral, x.ai, etc. — providers
|
||||
// whose JS SDK ships a built-in baseURL that models.dev doesn't restate).
|
||||
if config.ProviderURL == "" {
|
||||
if providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
} else if defaultURL, ok := sdkDefaultBaseURL[npmPackage]; ok {
|
||||
config.ProviderURL = defaultURL
|
||||
providerInfo.API = defaultURL // for downstream helpers that read info.API
|
||||
}
|
||||
return createAutoRoutedAnthropicProvider(ctx, config, modelName, providerInfo)
|
||||
case "openai":
|
||||
if config.ProviderURL == "" && providerInfo.API != "" {
|
||||
config.ProviderURL = providerInfo.API
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
|
||||
// Provider templates a runtime account/region/deployment segment into the
|
||||
// URL (cloudflare-ai-gateway, databricks, snowflake-cortex, gitlab,
|
||||
// sap-ai-core). Resolve via environment variables, or surface a targeted
|
||||
// error pointing the user at the right knobs.
|
||||
if resolved, err := resolveTemplatedAPIURL(config.ProviderURL, providerInfo); err != nil {
|
||||
return nil, err
|
||||
} else if resolved != "" {
|
||||
config.ProviderURL = resolved
|
||||
providerInfo.API = resolved
|
||||
}
|
||||
|
||||
switch wire {
|
||||
case wireOpenAI:
|
||||
// The native OpenAI SDK package (@ai-sdk/openai) speaks the Responses
|
||||
// API; openai-compatible proxies (and unknown-npm fallbacks) use the
|
||||
// chat-completions wire via fantasy's openaicompat provider.
|
||||
if npmPackage == "@ai-sdk/openai" {
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
}
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case wireAnthropic:
|
||||
return createAutoRoutedAnthropicProvider(ctx, config, modelName, providerInfo)
|
||||
case wireGoogle:
|
||||
return createAutoRoutedGoogleProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("internal error: unknown wire protocol for provider %s (npm: %s)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveAutoRouteAPIKey looks up the API key for an auto-routed provider,
|
||||
// returning a uniform error message when none can be resolved.
|
||||
func resolveAutoRouteAPIKey(config *ProviderConfig, info *ProviderInfo) (string, error) {
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// wrapProviderErr produces the uniform "failed to create X provider/model: %w"
|
||||
// error wrap used by every createXxxProvider path. kind is typically
|
||||
// "provider" or "model".
|
||||
func wrapProviderErr(name, kind string, err error) error {
|
||||
return fmt.Errorf("failed to create %s %s: %w", name, kind, err)
|
||||
}
|
||||
|
||||
// createAutoRoutedOpenAICompatProvider creates an openaicompat provider using
|
||||
@@ -378,10 +496,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
return nil, fmt.Errorf("provider %s requires --provider-url (no API URL in database)", info.ID)
|
||||
}
|
||||
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
@@ -395,12 +512,12 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -411,10 +528,9 @@ func createAutoRoutedOpenAICompatProvider(ctx context.Context, config *ProviderC
|
||||
func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
clearConflictingAnthropicSamplingParams(config)
|
||||
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []anthropic.Option
|
||||
@@ -433,12 +549,12 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
|
||||
p, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -447,10 +563,9 @@ func createAutoRoutedAnthropicProvider(ctx context.Context, config *ProviderConf
|
||||
// createAutoRoutedOpenAIProvider creates an openai provider for
|
||||
// third-party providers with openai-compatible APIs.
|
||||
func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
apiKey := resolveAPIKey(config.ProviderAPIKey, info.Env)
|
||||
if apiKey == "" {
|
||||
return nil, fmt.Errorf("%s API key not provided. Use --provider-api-key or set %s",
|
||||
info.Name, strings.Join(info.Env, " / "))
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var opts []openai.Option
|
||||
@@ -467,12 +582,12 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s provider: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create %s model: %w", info.Name, err)
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
providerOpts := buildOpenAIProviderOptions(config, modelName)
|
||||
@@ -480,6 +595,114 @@ func createAutoRoutedOpenAIProvider(ctx context.Context, config *ProviderConfig,
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createAutoRoutedGoogleProvider creates a Google (Gemini) provider for
|
||||
// third-party providers that expose a Gemini-compatible API (e.g. opencode's
|
||||
// Gemini routes, which carry an @ai-sdk/google per-model override).
|
||||
//
|
||||
// The underlying genai SDK always injects its own API version segment
|
||||
// ("v1beta") between the base URL and the resource path. When the proxy's
|
||||
// base URL from models.dev already carries a version segment (e.g. opencode's
|
||||
// https://opencode.ai/zen/v1), that produces a doubled ".../v1/v1beta/..."
|
||||
// path that the proxy rejects. In that case we install a transport that
|
||||
// strips the injected segment so the proxy's own version is used.
|
||||
func createAutoRoutedGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string, info *ProviderInfo) (*ProviderResult, error) {
|
||||
apiKey, err := resolveAutoRouteAPIKey(config, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := []google.Option{
|
||||
google.WithGeminiAPIKey(apiKey),
|
||||
google.WithName(info.ID),
|
||||
}
|
||||
|
||||
if config.ProviderURL != "" {
|
||||
opts = append(opts, google.WithBaseURL(config.ProviderURL))
|
||||
}
|
||||
|
||||
// Decide whether the genai-injected version segment needs stripping.
|
||||
var httpClient *http.Client
|
||||
if basePath := versionedBasePath(config.ProviderURL); basePath != "" {
|
||||
httpClient = newGeminiProxyHTTPClient(basePath, config.TLSSkipVerify)
|
||||
} else if config.TLSSkipVerify {
|
||||
httpClient = createHTTPClientWithTLSConfig(true)
|
||||
}
|
||||
if httpClient != nil {
|
||||
opts = append(opts, google.WithHTTPClient(httpClient))
|
||||
}
|
||||
|
||||
p, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr(info.Name, "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr(info.Name, "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// versionSegmentRe matches a trailing API version segment in a URL path,
|
||||
// e.g. "/v1", "/v1beta", "/v1beta1", "/v2alpha".
|
||||
var versionSegmentRe = regexp.MustCompile(`/v\d+(?:beta\d*|alpha\d*)?$`)
|
||||
|
||||
// versionedBasePath returns the path component of rawURL when that path ends
|
||||
// with an API version segment (e.g. opencode's ".../zen/v1" → "/zen/v1").
|
||||
// It returns "" when rawURL is empty, unparseable, or has no version suffix
|
||||
// — in which case the genai SDK's default version injection is correct and
|
||||
// no rewriting is needed.
|
||||
func versionedBasePath(rawURL string) string {
|
||||
if rawURL == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path := strings.TrimSuffix(u.Path, "/")
|
||||
if versionSegmentRe.MatchString(path) {
|
||||
return path
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// newGeminiProxyHTTPClient builds an HTTP client whose transport strips the
|
||||
// genai-injected version segment ("v1beta"/"v1beta1") that directly follows
|
||||
// basePath, collapsing "{basePath}/v1beta/..." back to "{basePath}/...".
|
||||
func newGeminiProxyHTTPClient(basePath string, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: &geminiProxyTransport{base: base, basePath: basePath},
|
||||
}
|
||||
}
|
||||
|
||||
// geminiProxyTransport removes the redundant API version segment that the
|
||||
// genai SDK injects after a proxy base URL that already carries its own
|
||||
// version segment.
|
||||
type geminiProxyTransport struct {
|
||||
base http.RoundTripper
|
||||
basePath string
|
||||
}
|
||||
|
||||
func (t *geminiProxyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
for _, injected := range []string{"/v1beta1", "/v1beta"} {
|
||||
prefix := t.basePath + injected + "/"
|
||||
if strings.HasPrefix(req.URL.Path, prefix) {
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.URL.Path = t.basePath + strings.TrimPrefix(req.URL.Path, t.basePath+injected)
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
||||
|
||||
// resolveAPIKey returns the first non-empty API key from the explicit key
|
||||
// or the environment variables.
|
||||
func resolveAPIKey(explicitKey string, envVars []string) string {
|
||||
@@ -530,7 +753,7 @@ func rightSizeMaxTokens(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
if modelInfo == nil || modelInfo.Limit.Output <= 0 {
|
||||
return
|
||||
}
|
||||
if isExplicitlySet("max-tokens") {
|
||||
if isExplicitlySet(config.ConfigStore, "max-tokens") {
|
||||
return
|
||||
}
|
||||
target := min(modelInfo.Limit.Output, defaultRightSizeCap)
|
||||
@@ -709,7 +932,7 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
|
||||
}
|
||||
|
||||
// Handle OAuth vs API key authentication
|
||||
if strings.HasPrefix(source, "stored OAuth") {
|
||||
if source == auth.CredentialSourceOAuth {
|
||||
httpClient := createOAuthHTTPClient(apiKey, config.TLSSkipVerify)
|
||||
opts = append(opts, anthropic.WithHTTPClient(httpClient))
|
||||
// Note: For OAuth, the API key is set as a placeholder; the transport handles auth
|
||||
@@ -719,12 +942,12 @@ func createAnthropicProvider(ctx context.Context, config *ProviderConfig, modelN
|
||||
|
||||
provider, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Anthropic provider: %w", err)
|
||||
return nil, wrapProviderErr("Anthropic", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Anthropic model: %w", err)
|
||||
return nil, wrapProviderErr("Anthropic", "model", err)
|
||||
}
|
||||
|
||||
// Build provider options for extended thinking (reasoning budget).
|
||||
@@ -761,12 +984,12 @@ func createVertexAnthropicProvider(ctx context.Context, config *ProviderConfig,
|
||||
|
||||
provider, err := anthropic.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vertex Anthropic provider: %w", err)
|
||||
return nil, wrapProviderErr("Vertex Anthropic", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vertex Anthropic model: %w", err)
|
||||
return nil, wrapProviderErr("Vertex Anthropic", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -834,12 +1057,12 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI model: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI", "model", err)
|
||||
}
|
||||
|
||||
// Build provider options for OpenAI Responses API reasoning models.
|
||||
@@ -848,6 +1071,72 @@ func createOpenAIProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// createCopilotProvider builds a GitHub Copilot provider through fantasy's
|
||||
// OpenAI-compatible provider. The catalog key is github-copilot, but the public
|
||||
// model prefix may be either copilot/ or github-copilot/.
|
||||
//
|
||||
// Only gpt-* Copilot models are enabled here. The catalog also lists Claude and
|
||||
// Gemini Copilot models, but those require different wire protocols and must be
|
||||
// routed explicitly before they can be safely accepted.
|
||||
func createCopilotProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if !strings.HasPrefix(modelName, "gpt-") {
|
||||
return nil, fmt.Errorf("GitHub Copilot model %q is not supported yet: only gpt-* models use the OpenAI-compatible protocol", modelName)
|
||||
}
|
||||
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize credential manager: %w", err)
|
||||
}
|
||||
|
||||
token, err := cm.GetValidCopilotAccessTokenContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GitHub Copilot credentials not available. Use 'kit auth login copilot': %w", err)
|
||||
}
|
||||
|
||||
expiresAt := int64(0)
|
||||
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == token {
|
||||
expiresAt = creds.ExpiresAt
|
||||
}
|
||||
|
||||
baseURL := copilotBaseURL
|
||||
if providerInfo := GetGlobalRegistry().GetProviderInfo(copilotProviderID); providerInfo != nil && providerInfo.API != "" {
|
||||
baseURL = providerInfo.API
|
||||
}
|
||||
if config.ProviderURL != "" {
|
||||
baseURL = config.ProviderURL
|
||||
}
|
||||
|
||||
opts := []openai.Option{
|
||||
openai.WithName(copilotAliasProviderID),
|
||||
openai.WithBaseURL(baseURL),
|
||||
openai.WithAPIKey(token),
|
||||
openai.WithHTTPClient(createCopilotHTTPClient(token, expiresAt, config.TLSSkipVerify)),
|
||||
openai.WithUseResponsesAPI(),
|
||||
openai.WithResponsesAPIFunc(copilotUsesResponsesAPI),
|
||||
openai.WithObjectMode(fantasy.ObjectModeTool),
|
||||
}
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GitHub Copilot provider: %w", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GitHub Copilot model: %w", err)
|
||||
}
|
||||
|
||||
providerOpts := buildOpenAIProviderOptions(config, modelName)
|
||||
|
||||
return &ProviderResult{Model: model, ProviderOptions: providerOpts}, nil
|
||||
}
|
||||
|
||||
// copilotUsesResponsesAPI selects the OpenAI Responses API for Copilot models
|
||||
// known to support it. Non-gpt models are rejected before provider creation.
|
||||
func copilotUsesResponsesAPI(modelID string) bool {
|
||||
return strings.HasPrefix(modelID, "gpt-5")
|
||||
}
|
||||
|
||||
// createOpenAICodexProvider creates a provider for ChatGPT/Codex OAuth tokens.
|
||||
// Uses the chatgpt.com/backend-api/codex endpoint with special headers.
|
||||
func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, modelName, token, accountID string) (*ProviderResult, error) {
|
||||
@@ -875,12 +1164,12 @@ func createOpenAICodexProvider(ctx context.Context, config *ProviderConfig, mode
|
||||
|
||||
provider, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI Codex", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenAI Codex model: %w", err)
|
||||
return nil, wrapProviderErr("OpenAI Codex", "model", err)
|
||||
}
|
||||
|
||||
providerOpts := buildCodexProviderOptions(config, modelName)
|
||||
@@ -977,6 +1266,87 @@ func (t *codexTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// createCopilotHTTPClient returns an HTTP client that injects Copilot-specific
|
||||
// authorization and client metadata headers. The token and expiry are cached in
|
||||
// the transport so streaming requests do not hit credentials.json on every
|
||||
// RoundTrip; the credential manager is consulted only near expiry.
|
||||
func createCopilotHTTPClient(token string, expiresAt int64, skipVerify bool) *http.Client {
|
||||
var base http.RoundTripper
|
||||
if skipVerify {
|
||||
base = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &copilotTransport{
|
||||
base: base,
|
||||
token: token,
|
||||
expiresAt: expiresAt,
|
||||
},
|
||||
Timeout: 120 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// copilotTransport decorates requests for api.githubcopilot.com.
|
||||
//
|
||||
// It owns a cached Copilot access token. When the token is still valid, the hot
|
||||
// path is in-memory only. Near expiry it refreshes through CredentialManager,
|
||||
// which updates both the cache here and credentials.json.
|
||||
type copilotTransport struct {
|
||||
base http.RoundTripper
|
||||
token string
|
||||
expiresAt int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *copilotTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := t.cachedToken(req.Context())
|
||||
|
||||
newReq := req.Clone(req.Context())
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
newReq.Header.Set("Copilot-Integration-Id", copilotIntegrationID)
|
||||
newReq.Header.Set("Editor-Version", copilotEditorVersion)
|
||||
newReq.Header.Set("Editor-Plugin-Version", copilotEditorPluginVersion)
|
||||
newReq.Header.Set("Openai-Intent", copilotOpenAIIntent)
|
||||
newReq.Header.Set("User-Agent", copilotUserAgent)
|
||||
newReq.Header.Set("X-GitHub-Api-Version", copilotGitHubAPIVersion)
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// cachedToken returns the cached token unless it is within the five-minute
|
||||
// refresh window. Refresh errors fall back to the last token so the request can
|
||||
// surface any authoritative auth failure from the Copilot API.
|
||||
func (t *copilotTransport) cachedToken(ctx context.Context) string {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.expiresAt == 0 || time.Now().Unix() < t.expiresAt-300 {
|
||||
return t.token
|
||||
}
|
||||
|
||||
cm, err := auth.NewCredentialManager()
|
||||
if err != nil {
|
||||
return t.token
|
||||
}
|
||||
|
||||
fresh, err := cm.GetValidCopilotAccessTokenContext(ctx)
|
||||
if err != nil || fresh == "" {
|
||||
return t.token
|
||||
}
|
||||
|
||||
t.token = fresh
|
||||
if creds, err := cm.GetCopilotCredentials(); err == nil && creds != nil && creds.CopilotAccessToken == fresh {
|
||||
t.expiresAt = creds.ExpiresAt
|
||||
}
|
||||
return t.token
|
||||
}
|
||||
|
||||
func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
apiKey := firstNonEmpty(
|
||||
config.ProviderAPIKey,
|
||||
@@ -993,12 +1363,12 @@ func createGoogleProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Google provider: %w", err)
|
||||
return nil, wrapProviderErr("Google", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Google model: %w", err)
|
||||
return nil, wrapProviderErr("Google", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1031,12 +1401,12 @@ func createAzureProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := azure.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure OpenAI provider: %w", err)
|
||||
return nil, wrapProviderErr("Azure OpenAI", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure OpenAI model: %w", err)
|
||||
return nil, wrapProviderErr("Azure OpenAI", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1056,12 +1426,12 @@ func createOpenRouterProvider(ctx context.Context, config *ProviderConfig, model
|
||||
|
||||
provider, err := openrouter.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenRouter provider: %w", err)
|
||||
return nil, wrapProviderErr("OpenRouter", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OpenRouter model: %w", err)
|
||||
return nil, wrapProviderErr("OpenRouter", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1073,12 +1443,12 @@ func createBedrockProvider(ctx context.Context, config *ProviderConfig, modelNam
|
||||
// Bedrock uses AWS SDK default credential chain (env vars, shared config, etc.)
|
||||
provider, err := bedrock.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Bedrock provider: %w", err)
|
||||
return nil, wrapProviderErr("Bedrock", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Bedrock model: %w", err)
|
||||
return nil, wrapProviderErr("Bedrock", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1102,12 +1472,12 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := vercel.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vercel provider: %w", err)
|
||||
return nil, wrapProviderErr("Vercel", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Vercel model: %w", err)
|
||||
return nil, wrapProviderErr("Vercel", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1160,12 +1530,12 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
return nil, wrapProviderErr("custom", "provider", err)
|
||||
}
|
||||
|
||||
model, err := p.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom model: %w", err)
|
||||
return nil, wrapProviderErr("custom", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
@@ -1209,12 +1579,12 @@ func createOllamaProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
|
||||
provider, err := openaicompat.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Ollama provider: %w", err)
|
||||
return nil, wrapProviderErr("Ollama", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Ollama model: %w", err)
|
||||
return nil, wrapProviderErr("Ollama", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{
|
||||
|
||||
@@ -246,6 +246,7 @@ func loadEmbeddedProviders() map[string]modelsDBProvider {
|
||||
// doesn't track yet. Callers should treat a nil return as "unknown model"
|
||||
// and continue with sensible defaults.
|
||||
func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
@@ -273,6 +274,7 @@ func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
@@ -287,6 +289,7 @@ func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
// variables. Returns nil for providers not in the registry (unknown
|
||||
// providers are assumed to handle auth themselves or via --provider-api-key).
|
||||
func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) error {
|
||||
provider = catalogProviderID(provider)
|
||||
if apiKey != "" {
|
||||
return nil
|
||||
}
|
||||
@@ -311,6 +314,15 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
}
|
||||
}
|
||||
|
||||
// For GitHub Copilot, check stored GitHub OAuth credentials.
|
||||
if provider == copilotProviderID {
|
||||
if cm, err := auth.NewCredentialManager(); err == nil {
|
||||
if has, _ := cm.HasCopilotCredentials(); has {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
envVars, err := r.getRequiredEnvVars(provider)
|
||||
if err != nil {
|
||||
// Unknown provider — nothing to validate
|
||||
@@ -350,6 +362,7 @@ func (r *ModelsRegistry) ValidateEnvironment(provider string, apiKey string) err
|
||||
|
||||
// SuggestModels returns similar model names when an invalid model is provided.
|
||||
func (r *ModelsRegistry) SuggestModels(provider, invalidModel string) []string {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
@@ -404,8 +417,8 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
// Check if npm maps to a known wire protocol
|
||||
if _, ok := npmToWireProtocol[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -415,6 +428,7 @@ func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
|
||||
// GetModelsForProvider returns all models for a specific provider.
|
||||
func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]ModelInfo, error) {
|
||||
provider = catalogProviderID(provider)
|
||||
providerInfo, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unsupported provider: %s", provider)
|
||||
@@ -425,6 +439,7 @@ func (r *ModelsRegistry) GetModelsForProvider(provider string) (map[string]Model
|
||||
|
||||
// GetProviderInfo returns the full provider info, or nil if not found.
|
||||
func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
provider = catalogProviderID(provider)
|
||||
info, exists := r.providers[provider]
|
||||
if !exists {
|
||||
return nil
|
||||
|
||||
@@ -0,0 +1,170 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy/providers/google"
|
||||
)
|
||||
|
||||
// templatePlaceholderRe matches "${NAME}" placeholders in URL templates from
|
||||
// models.dev (e.g. "https://${DATABRICKS_HOST}/ai-gateway/mlflow/v1").
|
||||
var templatePlaceholderRe = regexp.MustCompile(`\$\{([A-Z0-9_]+)\}`)
|
||||
|
||||
// templateEnvVarOverrides supplies fallback environment variable names for
|
||||
// placeholders that providers commonly use under non-obvious env names.
|
||||
// The placeholder name itself is always tried first; this map adds extra
|
||||
// names to try when the placeholder doesn't match the canonical env var.
|
||||
var templateEnvVarOverrides = map[string][]string{
|
||||
"CLOUDFLARE_ACCOUNT_ID": {"CF_ACCOUNT_ID"},
|
||||
"CLOUDFLARE_GATEWAY_NAME": {"CF_GATEWAY", "CLOUDFLARE_GATEWAY"},
|
||||
"DATABRICKS_HOST": {"DATABRICKS_WORKSPACE_URL"},
|
||||
"SNOWFLAKE_ACCOUNT": {"SNOWFLAKE_ACCOUNT_ID"},
|
||||
}
|
||||
|
||||
// resolveTemplatedAPIURL substitutes "${VAR}" placeholders in apiURL with the
|
||||
// values of the named environment variables. Returns:
|
||||
// - ("", nil) when apiURL contains no placeholders (caller keeps current URL),
|
||||
// - (resolved, nil) when every placeholder was resolved,
|
||||
// - ("", error) when one or more placeholders are unset, with a message that
|
||||
// names the missing env vars and points at the relevant provider.
|
||||
//
|
||||
// The info parameter is used purely for error messaging (provider name).
|
||||
func resolveTemplatedAPIURL(apiURL string, info *ProviderInfo) (string, error) {
|
||||
if apiURL == "" || !strings.Contains(apiURL, "${") {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var missing []string
|
||||
resolved := templatePlaceholderRe.ReplaceAllStringFunc(apiURL, func(match string) string {
|
||||
// match is "${NAME}". Extract NAME.
|
||||
name := match[2 : len(match)-1]
|
||||
if v := os.Getenv(name); v != "" {
|
||||
return v
|
||||
}
|
||||
for _, alt := range templateEnvVarOverrides[name] {
|
||||
if v := os.Getenv(alt); v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
missing = append(missing, name)
|
||||
return match
|
||||
})
|
||||
|
||||
if len(missing) > 0 {
|
||||
providerName := info.ID
|
||||
if info.Name != "" {
|
||||
providerName = info.Name
|
||||
}
|
||||
return "", fmt.Errorf(
|
||||
"provider %s requires environment variable(s) %s to construct its API URL (%s); "+
|
||||
"set them or pass --provider-url to override",
|
||||
providerName, strings.Join(missing, ", "), apiURL,
|
||||
)
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// ResolveProviderBaseURL returns the base API URL kit will use when talking to
|
||||
// the given provider, applying the same resolution order as CreateProvider:
|
||||
//
|
||||
// 1. The provider's `api` field from the models.dev registry.
|
||||
// 2. The hard-coded default base URL of its npm SDK package (e.g.
|
||||
// @ai-sdk/groq → https://api.groq.com/openai/v1).
|
||||
// 3. Template substitution against the current process environment when the
|
||||
// URL contains "${VAR}" placeholders (e.g. cloudflare-workers-ai needs
|
||||
// CLOUDFLARE_ACCOUNT_ID).
|
||||
//
|
||||
// It returns an error when the provider is unknown, when no URL can be derived,
|
||||
// or when a templated URL has unset placeholders. The error message is suitable
|
||||
// for direct display to end users.
|
||||
//
|
||||
// Note: providers handled by bespoke auth schemes (amazon-bedrock SigV4,
|
||||
// azure resource URLs, google-vertex project/location, sap-ai-core customer
|
||||
// deployments) may return either an empty URL or a regional/templated URL —
|
||||
// the actual endpoint is finalised inside their native handlers and depends on
|
||||
// runtime credentials.
|
||||
func ResolveProviderBaseURL(providerID string) (string, error) {
|
||||
registry := GetGlobalRegistry()
|
||||
info := registry.GetProviderInfo(providerID)
|
||||
if info == nil {
|
||||
return "", fmt.Errorf("unknown provider: %s", providerID)
|
||||
}
|
||||
|
||||
apiURL := info.API
|
||||
if apiURL == "" {
|
||||
if defaultURL, ok := sdkDefaultBaseURL[info.NPM]; ok {
|
||||
apiURL = defaultURL
|
||||
}
|
||||
}
|
||||
|
||||
if apiURL == "" {
|
||||
return "", fmt.Errorf(
|
||||
"provider %s has no default API URL: its npm package %q does not "+
|
||||
"ship a built-in baseURL (likely Bedrock SigV4, Azure deployment, "+
|
||||
"Vertex project/location, or a customer-hosted endpoint). "+
|
||||
"Pass --provider-url or set the provider's URL env var",
|
||||
providerID, info.NPM,
|
||||
)
|
||||
}
|
||||
|
||||
if strings.Contains(apiURL, "${") {
|
||||
resolved, err := resolveTemplatedAPIURL(apiURL, info)
|
||||
if err != nil {
|
||||
return apiURL, err
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
return apiURL, nil
|
||||
}
|
||||
|
||||
// createGoogleVertexProvider creates a Google Gemini provider that targets the
|
||||
// Vertex AI backend (rather than the public generativelanguage.googleapis.com
|
||||
// endpoint). It requires the same project/region environment variables as
|
||||
// google-vertex-anthropic.
|
||||
func createGoogleVertexProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
projectID := firstNonEmpty(
|
||||
os.Getenv("GOOGLE_VERTEX_PROJECT"),
|
||||
os.Getenv("GOOGLE_CLOUD_PROJECT"),
|
||||
os.Getenv("GCLOUD_PROJECT"),
|
||||
os.Getenv("CLOUDSDK_CORE_PROJECT"),
|
||||
)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf(
|
||||
"google Vertex project ID not provided, set GOOGLE_VERTEX_PROJECT, " +
|
||||
"GOOGLE_CLOUD_PROJECT, or GCLOUD_PROJECT environment variable",
|
||||
)
|
||||
}
|
||||
|
||||
region := firstNonEmpty(
|
||||
os.Getenv("GOOGLE_VERTEX_LOCATION"),
|
||||
os.Getenv("CLOUD_ML_REGION"),
|
||||
)
|
||||
if region == "" {
|
||||
region = "global"
|
||||
}
|
||||
|
||||
opts := []google.Option{
|
||||
google.WithVertex(projectID, region),
|
||||
google.WithName("google-vertex"),
|
||||
}
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, google.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
provider, err := google.New(opts...)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr("Google Vertex", "provider", err)
|
||||
}
|
||||
|
||||
model, err := provider.LanguageModel(ctx, modelName)
|
||||
if err != nil {
|
||||
return nil, wrapProviderErr("Google Vertex", "model", err)
|
||||
}
|
||||
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
@@ -0,0 +1,214 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSDKDefaultBaseURL_CoversAllWireMappedPackages enforces the invariant
|
||||
// that every npm package recognised by the auto-router has a corresponding
|
||||
// default base URL — otherwise a provider that omits its `api` field in the
|
||||
// registry would silently fail to route at runtime.
|
||||
func TestSDKDefaultBaseURL_CoversAllWireMappedPackages(t *testing.T) {
|
||||
for npm := range npmToWireProtocol {
|
||||
// @ai-sdk/openai-compatible is a wire family, not a single SDK with
|
||||
// a default URL — providers using it always supply their own `api`.
|
||||
if npm == "@ai-sdk/openai-compatible" {
|
||||
continue
|
||||
}
|
||||
if _, ok := sdkDefaultBaseURL[npm]; !ok {
|
||||
t.Errorf("npm %q is in npmToWireProtocol but has no sdkDefaultBaseURL entry — "+
|
||||
"providers using this npm with no `api` field cannot be routed", npm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestSDKDefaultBaseURL_AllURLsAreAbsolute sanity-checks that every default
|
||||
// URL is a well-formed absolute https endpoint (catches typos in the table).
|
||||
func TestSDKDefaultBaseURL_AllURLsAreAbsolute(t *testing.T) {
|
||||
for npm, url := range sdkDefaultBaseURL {
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
t.Errorf("sdkDefaultBaseURL[%q] = %q is not an absolute https URL", npm, url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_RegistryFirst verifies that the registry's `api`
|
||||
// field wins over any SDK default.
|
||||
func TestResolveProviderBaseURL_RegistryFirst(t *testing.T) {
|
||||
// xai is in the registry with no `api` field — its URL comes from the
|
||||
// SDK default. Use a synthetic registry-backed provider to test the
|
||||
// priority via the public registry instead.
|
||||
url, err := ResolveProviderBaseURL("openai")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL(openai): %v", err)
|
||||
}
|
||||
if url != "https://api.openai.com/v1" {
|
||||
t.Errorf("openai URL = %q, want https://api.openai.com/v1", url)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_SDKDefaultFallback verifies that providers
|
||||
// without an `api` field (groq, cerebras, xai, …) resolve to their SDK
|
||||
// hard-coded default URL.
|
||||
func TestResolveProviderBaseURL_SDKDefaultFallback(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"cerebras": "https://api.cerebras.ai/v1",
|
||||
"xai": "https://api.x.ai/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"perplexity": "https://api.perplexity.ai",
|
||||
"togetherai": "https://api.together.xyz/v1",
|
||||
"deepinfra": "https://api.deepinfra.com/v1/openai",
|
||||
"cohere": "https://api.cohere.com/compatibility/v1",
|
||||
"v0": "https://api.v0.dev/v1",
|
||||
"aihubmix": "https://aihubmix.com/v1",
|
||||
"venice": "https://api.venice.ai/api/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
}
|
||||
for providerID, wantURL := range tests {
|
||||
t.Run(providerID, func(t *testing.T) {
|
||||
got, err := ResolveProviderBaseURL(providerID)
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL(%s): %v", providerID, err)
|
||||
}
|
||||
if got != wantURL {
|
||||
t.Errorf("%s URL = %q, want %q", providerID, got, wantURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_TemplatedURL_MissingEnv verifies that providers
|
||||
// whose URL contains "${VAR}" placeholders surface a targeted error when the
|
||||
// environment variables are unset.
|
||||
func TestResolveProviderBaseURL_TemplatedURL_MissingEnv(t *testing.T) {
|
||||
// cloudflare-workers-ai's api URL contains ${CLOUDFLARE_ACCOUNT_ID}.
|
||||
// Ensure the variable is unset for this test.
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
|
||||
t.Setenv("CF_ACCOUNT_ID", "")
|
||||
|
||||
_, err := ResolveProviderBaseURL("cloudflare-workers-ai")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unset CLOUDFLARE_ACCOUNT_ID, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "CLOUDFLARE_ACCOUNT_ID") {
|
||||
t.Errorf("error should name the missing env var, got: %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "--provider-url") {
|
||||
t.Errorf("error should suggest --provider-url override, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_TemplatedURL_Resolved verifies env-var
|
||||
// substitution succeeds when the placeholder is set.
|
||||
func TestResolveProviderBaseURL_TemplatedURL_Resolved(t *testing.T) {
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "test-acct-123")
|
||||
got, err := ResolveProviderBaseURL("cloudflare-workers-ai")
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveProviderBaseURL: %v", err)
|
||||
}
|
||||
if !strings.Contains(got, "test-acct-123") {
|
||||
t.Errorf("resolved URL %q should contain test-acct-123", got)
|
||||
}
|
||||
if strings.Contains(got, "${") {
|
||||
t.Errorf("resolved URL %q still contains template placeholder", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveProviderBaseURL_UnknownProvider verifies the not-in-registry error.
|
||||
func TestResolveProviderBaseURL_UnknownProvider(t *testing.T) {
|
||||
_, err := ResolveProviderBaseURL("does-not-exist")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown provider, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown provider") {
|
||||
t.Errorf("error should say 'unknown provider', got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutoRouteProvider_SDKDefaultURLFallback verifies that providers whose
|
||||
// registry entry omits the `api` field (groq, mistral, xai, etc.) are still
|
||||
// auto-routed by falling back to the SDK's hard-coded default URL.
|
||||
func TestAutoRouteProvider_SDKDefaultURLFallback(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
npmPackage string
|
||||
wantInURL string
|
||||
}{
|
||||
{"groq", "@ai-sdk/groq", "groq.com"},
|
||||
{"cerebras", "@ai-sdk/cerebras", "cerebras.ai"},
|
||||
{"xai", "@ai-sdk/xai", "x.ai"},
|
||||
{"mistral", "@ai-sdk/mistral", "mistral.ai"},
|
||||
{"v0", "@ai-sdk/vercel", "v0.dev"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reg := &ModelsRegistry{
|
||||
providers: map[string]ProviderInfo{
|
||||
"testfallback": {
|
||||
ID: "testfallback",
|
||||
Name: "Test Fallback",
|
||||
Env: []string{"TESTFALLBACK_API_KEY"},
|
||||
NPM: tt.npmPackage,
|
||||
// API intentionally omitted — must fall back to SDK default.
|
||||
Models: map[string]ModelInfo{
|
||||
"any-model": {ID: "any-model", Name: "any-model"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
config := &ProviderConfig{ProviderAPIKey: "test-key"}
|
||||
|
||||
result, err := autoRouteProvider(context.Background(), config, "testfallback", "any-model", reg)
|
||||
if err != nil {
|
||||
t.Fatalf("autoRouteProvider returned error: %v", err)
|
||||
}
|
||||
if result == nil || result.Model == nil {
|
||||
t.Fatal("autoRouteProvider returned nil model")
|
||||
}
|
||||
// Verify the SDK default URL was picked up.
|
||||
if !strings.Contains(config.ProviderURL, tt.wantInURL) {
|
||||
t.Errorf("config.ProviderURL = %q, want substring %q (SDK default)",
|
||||
config.ProviderURL, tt.wantInURL)
|
||||
}
|
||||
// All these wrappers route through the openai-compat wire.
|
||||
gotType := reflect.TypeOf(result.Model).String()
|
||||
if gotType != "openai.languageModel" {
|
||||
t.Errorf("model type = %q, want openai.languageModel", gotType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveTemplatedAPIURL_NoPlaceholders verifies that URLs without
|
||||
// placeholders are returned as-is (the caller keeps using the original).
|
||||
func TestResolveTemplatedAPIURL_NoPlaceholders(t *testing.T) {
|
||||
got, err := resolveTemplatedAPIURL("https://api.example.com/v1", &ProviderInfo{ID: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Errorf("got %q, want empty string for URL with no placeholders", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestResolveTemplatedAPIURL_AltEnvVar verifies that the alternative env-var
|
||||
// names (e.g. CF_ACCOUNT_ID for CLOUDFLARE_ACCOUNT_ID) are honoured.
|
||||
func TestResolveTemplatedAPIURL_AltEnvVar(t *testing.T) {
|
||||
t.Setenv("CLOUDFLARE_ACCOUNT_ID", "")
|
||||
t.Setenv("CF_ACCOUNT_ID", "alt-name-123")
|
||||
|
||||
got, err := resolveTemplatedAPIURL(
|
||||
"https://api.cloudflare.com/client/v4/accounts/${CLOUDFLARE_ACCOUNT_ID}/ai/v1",
|
||||
&ProviderInfo{ID: "cloudflare-workers-ai"},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !strings.Contains(got, "alt-name-123") {
|
||||
t.Errorf("resolved URL %q should have picked up CF_ACCOUNT_ID alternative", got)
|
||||
}
|
||||
}
|
||||
+39
-35
@@ -36,15 +36,17 @@ type Diagnostic struct {
|
||||
}
|
||||
|
||||
// LoadAll discovers and loads all prompt templates from standard locations
|
||||
// and any extra paths. Templates are loaded in order of precedence (lowest
|
||||
// to highest), with later templates overriding earlier ones of the same name.
|
||||
// and any extra paths. Templates are loaded in order of precedence (highest
|
||||
// to lowest); the first source to define a given name wins, later definitions
|
||||
// of the same name are dropped with a diagnostic.
|
||||
//
|
||||
// Discovery paths searched in order:
|
||||
// 1. Default templates (if IncludeDefaults)
|
||||
// 2. ~/.kit/prompts/ (global user templates)
|
||||
// 3. .kit/prompts/ (project-local templates)
|
||||
// 4. ConfigPaths (from configuration)
|
||||
// 5. ExtraPaths (explicit paths, highest precedence)
|
||||
// 2. ~/.kit/prompts/ (legacy global)
|
||||
// 3. $XDG_CONFIG_HOME/kit/prompts/ (XDG global, default ~/.config/kit/prompts/)
|
||||
// 4. <cwd>/.kit/prompts/ (project-local templates)
|
||||
// 5. ConfigPaths (from configuration)
|
||||
// 6. ExtraPaths (explicit paths, lowest precedence)
|
||||
func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
if opts.Cwd == "" {
|
||||
opts.Cwd, _ = os.Getwd()
|
||||
@@ -88,13 +90,21 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
addTemplates(defaults, "default")
|
||||
}
|
||||
|
||||
// 2. Global user templates: ~/.kit/prompts/
|
||||
globalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(globalDir); err == nil {
|
||||
// 2. Legacy global user templates: ~/.kit/prompts/
|
||||
legacyGlobalDir := filepath.Join(opts.HomeDir, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(legacyGlobalDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
|
||||
// 3. Project-local templates: .kit/prompts/
|
||||
// 3. XDG global user templates: $XDG_CONFIG_HOME/kit/prompts/
|
||||
// Default: ~/.config/kit/prompts/. Aligns with extensions and skills.
|
||||
if xdgDir := GlobalDir(); xdgDir != "" && xdgDir != legacyGlobalDir {
|
||||
if templates, err := LoadFromDir(xdgDir); err == nil {
|
||||
addTemplates(templates, "global")
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Project-local templates: .kit/prompts/
|
||||
localDir := filepath.Join(opts.Cwd, ".kit", "prompts")
|
||||
if templates, err := LoadFromDir(localDir); err == nil {
|
||||
addTemplates(templates, "local")
|
||||
@@ -179,31 +189,6 @@ func LoadFromDir(dir string) ([]*PromptTemplate, error) {
|
||||
return templates, nil
|
||||
}
|
||||
|
||||
// Deduplicate removes duplicate templates by name, keeping the first occurrence.
|
||||
// It returns the deduplicated list and diagnostics for any collisions.
|
||||
// This is a standalone function for when you need to deduplicate an existing list.
|
||||
func Deduplicate(templates []*PromptTemplate) ([]*PromptTemplate, []Diagnostic) {
|
||||
seen := make(map[string]*PromptTemplate)
|
||||
var result []*PromptTemplate
|
||||
var diagnostics []Diagnostic
|
||||
|
||||
for _, tpl := range templates {
|
||||
if existing, ok := seen[tpl.Name]; ok {
|
||||
diagnostics = append(diagnostics, Diagnostic{
|
||||
Name: tpl.Name,
|
||||
KeptPath: existing.FilePath,
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: "duplicate template name (first-match-wins)",
|
||||
})
|
||||
} else {
|
||||
seen[tpl.Name] = tpl
|
||||
result = append(result, tpl)
|
||||
}
|
||||
}
|
||||
|
||||
return result, diagnostics
|
||||
}
|
||||
|
||||
// loadDefaultTemplates returns the built-in default templates.
|
||||
// These are embedded templates that ship with Kit.
|
||||
func loadDefaultTemplates() []*PromptTemplate {
|
||||
@@ -211,3 +196,22 @@ func loadDefaultTemplates() []*PromptTemplate {
|
||||
// For now, return an empty slice - users can define their own templates
|
||||
return nil
|
||||
}
|
||||
|
||||
// GlobalDir returns the XDG-aligned global prompts directory, respecting
|
||||
// $XDG_CONFIG_HOME. Defaults to ~/.config/kit/prompts/. Returns an empty
|
||||
// string if the user's home directory cannot be resolved.
|
||||
//
|
||||
// This is the canonical location for user-wide prompt templates and aligns
|
||||
// with the discovery paths used for extensions ($XDG_CONFIG_HOME/kit/extensions/)
|
||||
// and skills ($XDG_CONFIG_HOME/kit/skills/).
|
||||
func GlobalDir() string {
|
||||
base := os.Getenv("XDG_CONFIG_HOME")
|
||||
if base == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
base = filepath.Join(home, ".config")
|
||||
}
|
||||
return filepath.Join(base, "kit", "prompts")
|
||||
}
|
||||
|
||||
@@ -70,7 +70,8 @@ func ParseTemplate(path string) (*PromptTemplate, error) {
|
||||
}
|
||||
|
||||
// ParseCommandArgs splits a command line into arguments respecting quotes.
|
||||
// It handles single quotes, double quotes, and backslash escaping.
|
||||
// It handles single quotes, double quotes, backslash escaping, and splits on
|
||||
// spaces and tabs.
|
||||
func ParseCommandArgs(input string) []string {
|
||||
var args []string
|
||||
var current strings.Builder
|
||||
@@ -78,7 +79,7 @@ func ParseCommandArgs(input string) []string {
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
|
||||
for i, r := range input {
|
||||
for _, r := range input {
|
||||
if escaped {
|
||||
current.WriteRune(r)
|
||||
escaped = false
|
||||
@@ -101,7 +102,7 @@ func ParseCommandArgs(input string) []string {
|
||||
continue
|
||||
}
|
||||
|
||||
if r == ' ' && !inSingleQuote && !inDoubleQuote {
|
||||
if (r == ' ' || r == '\t') && !inSingleQuote && !inDoubleQuote {
|
||||
if current.Len() > 0 {
|
||||
args = append(args, current.String())
|
||||
current.Reset()
|
||||
@@ -110,7 +111,6 @@ func ParseCommandArgs(input string) []string {
|
||||
}
|
||||
|
||||
current.WriteRune(r)
|
||||
_ = i // silence unused warning when we need position later
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
@@ -325,8 +325,3 @@ func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
args := ParseCommandArgs(argsInput)
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
// ExpandWithArgs substitutes the provided arguments into the template content.
|
||||
func (t *PromptTemplate) ExpandWithArgs(args []string) string {
|
||||
return SubstituteArgs(t.Content, args)
|
||||
}
|
||||
|
||||
@@ -129,26 +129,35 @@ func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
// BuildContext should return: [summary] + [M3 (kept)] + [M4 (new after compaction)]
|
||||
// Kept messages must appear BEFORE post-compaction messages so the LLM
|
||||
// sees the conversation in chronological order. Otherwise the latest
|
||||
// post-compaction user message would be followed by an older kept user
|
||||
// message, breaking user/assistant alternation and causing the model to
|
||||
// respond as if the post-compaction turn never happened.
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
t.Fatalf("expected 3 messages (summary + M3 + M4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
// Verify order: summary, M3 (kept), M4 (new)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3 kept), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4 post-compact), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEncodeCwdForDir verifies the working-directory → session-directory
|
||||
// name encoding strips characters that are illegal on Windows (notably the
|
||||
// drive-letter colon, see issue #18) while preserving the previous output
|
||||
// for the typical Unix paths.
|
||||
func TestEncodeCwdForDir(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cwd string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "unix absolute path",
|
||||
cwd: "/home/user/proj",
|
||||
want: "home--user--proj",
|
||||
},
|
||||
{
|
||||
name: "unix relative path",
|
||||
cwd: "proj/sub",
|
||||
want: "proj--sub",
|
||||
},
|
||||
{
|
||||
name: "windows drive root",
|
||||
cwd: `C:\test`,
|
||||
want: "C--test",
|
||||
},
|
||||
{
|
||||
name: "windows nested path",
|
||||
cwd: `C:\Users\User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows secondary drive",
|
||||
cwd: `S:\work\repo`,
|
||||
want: "S--work--repo",
|
||||
},
|
||||
{
|
||||
name: "windows mixed separators",
|
||||
cwd: `C:\Users/User\code`,
|
||||
want: "C--Users--User--code",
|
||||
},
|
||||
{
|
||||
name: "windows other illegal chars stripped",
|
||||
cwd: `C:\a<b>c|d?e*f"g`,
|
||||
want: "C--abcdefg",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := encodeCwdForDir(tc.cwd)
|
||||
if got != tc.want {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q, want %q", tc.cwd, got, tc.want)
|
||||
}
|
||||
// Encoded directory must never contain characters that are
|
||||
// illegal in Windows directory names.
|
||||
for _, bad := range []string{":", "<", ">", "\"", "|", "?", "*", "\\", "/"} {
|
||||
if strings.Contains(got, bad) {
|
||||
t.Errorf("encodeCwdForDir(%q) = %q contains illegal char %q", tc.cwd, got, bad)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -97,6 +99,11 @@ func ListAllSessions() ([]SessionInfo, error) {
|
||||
|
||||
// listSessionsInDir reads all .jsonl files in a directory and extracts session info.
|
||||
// Empty sessions (no messages) are automatically cleaned up and not returned.
|
||||
//
|
||||
// Per-file extraction is parallelized across a small worker pool because each
|
||||
// file requires a full JSONL scan to compute MessageCount and FirstMessage —
|
||||
// for users with many sessions this is the dominant cost of opening the
|
||||
// session picker.
|
||||
func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
@@ -107,20 +114,47 @@ func listSessionsInDir(dir string) ([]SessionInfo, error) {
|
||||
return nil, fmt.Errorf("failed to read directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
var sessions []SessionInfo
|
||||
// Collect candidate paths first so we can parallelize the heavy work.
|
||||
paths := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".jsonl") {
|
||||
continue
|
||||
}
|
||||
paths = append(paths, filepath.Join(dir, entry.Name()))
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, entry.Name())
|
||||
info, err := extractSessionInfo(path)
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
results := make([]*SessionInfo, len(paths))
|
||||
|
||||
// Worker pool sized to GOMAXPROCS, capped to avoid thrashing for tiny lists.
|
||||
workers := max(min(runtime.GOMAXPROCS(0), len(paths)), 1)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan int, len(paths))
|
||||
for range workers {
|
||||
wg.Go(func() {
|
||||
for i := range jobs {
|
||||
info, err := extractSessionInfo(paths[i])
|
||||
if err != nil {
|
||||
continue // skip malformed session files
|
||||
}
|
||||
results[i] = info
|
||||
}
|
||||
})
|
||||
}
|
||||
for i := range paths {
|
||||
jobs <- i
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
|
||||
sessions := make([]SessionInfo, 0, len(results))
|
||||
for i, info := range results {
|
||||
if info == nil {
|
||||
continue
|
||||
}
|
||||
// Clean up and skip empty sessions (no messages)
|
||||
// Clean up and skip empty sessions (no messages).
|
||||
if info.MessageCount == 0 {
|
||||
_ = os.Remove(path)
|
||||
_ = os.Remove(paths[i])
|
||||
continue
|
||||
}
|
||||
sessions = append(sessions, *info)
|
||||
|
||||
@@ -458,11 +458,6 @@ func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
func (tm *TreeManager) AppendModelChange(provider, modelID string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
@@ -755,9 +750,17 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
// If there is a compaction, inject the summary first, then the
|
||||
// preserved "kept" messages (chronologically before the compaction),
|
||||
// then the post-compaction messages (chronologically after).
|
||||
//
|
||||
// Order matters: the kept messages must come BEFORE the post-compaction
|
||||
// branch so the LLM sees the conversation in chronological order. If the
|
||||
// kept messages were appended last, the latest user message in the
|
||||
// current branch would be followed by an older kept user message,
|
||||
// breaking the strict user/assistant alternation that providers expect
|
||||
// and causing the model to respond as if the previous turn never
|
||||
// happened.
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -768,49 +771,10 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
})
|
||||
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
// Step 1: collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not on the current branch (the compaction entry is a
|
||||
// new root with no parent), so we iterate tm.entries in append order
|
||||
// and stop when we reach the compaction entry itself.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
@@ -825,13 +789,12 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
// Stop when we reach the compaction entry itself; messages
|
||||
// after it are collected from the branch walk below.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
@@ -860,6 +823,42 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: collect entries on the current branch after the compaction
|
||||
// entry (these are post-compaction messages). The compaction entry
|
||||
// itself is skipped — its summary was already injected above.
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Summary already injected above.
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
@@ -1030,44 +1029,22 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
// If there's a compaction, we collect IDs in the same order as
|
||||
// BuildContext: [summary placeholder, kept messages, post-compaction
|
||||
// messages]. This ordering must stay in sync with BuildContext so a
|
||||
// cut-point index can be mapped back to the correct entry ID.
|
||||
if lastCompaction != nil {
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
// Step 1: IDs of the kept messages starting at FirstKeptEntryID.
|
||||
// Iterate tm.entries in append order and stop at the compaction
|
||||
// entry to avoid double-counting post-compaction messages.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
@@ -1076,7 +1053,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
@@ -1100,6 +1076,28 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: IDs of entries after the compaction entry on the current
|
||||
// branch (post-compaction messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
@@ -1167,11 +1165,6 @@ func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
return tm.flushLocked()
|
||||
}
|
||||
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
@@ -1179,11 +1172,6 @@ func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
@@ -1350,15 +1338,44 @@ func (tm *TreeManager) buildTreeNodeDepth(id string, depth int, visited map[stri
|
||||
// --- Path conventions ---
|
||||
|
||||
// DefaultSessionDir returns the default session storage directory for a cwd.
|
||||
// Convention: ~/.kit/sessions/--<cwd-path>--/
|
||||
// Convention: ~/.kit/sessions/<encoded-cwd>, where path separators are
|
||||
// encoded as "--" with no leading or trailing dashes — e.g.
|
||||
// /home/user/proj becomes home--user--proj. See encodeCwdForDir for the
|
||||
// full encoding rules (including Windows path handling).
|
||||
func DefaultSessionDir(cwd string) string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
home = "."
|
||||
}
|
||||
// Convert path separators to double dashes.
|
||||
safeCwd := strings.ReplaceAll(cwd, string(filepath.Separator), "--")
|
||||
return filepath.Join(home, ".kit", "sessions", encodeCwdForDir(cwd))
|
||||
}
|
||||
|
||||
// encodeCwdForDir converts a working-directory path into a single, filesystem-
|
||||
// safe directory name. Path separators are replaced with double dashes and
|
||||
// characters that are illegal in Windows directory names — most importantly
|
||||
// the colon that follows the drive letter (e.g. `C:\foo` → `C--foo`) — are
|
||||
// stripped. The result is identical to the previous Unix-only encoding for
|
||||
// paths that do not contain such characters, so existing session directories
|
||||
// are preserved.
|
||||
func encodeCwdForDir(cwd string) string {
|
||||
// Convert both `/` and `\` to double dashes so encoding is stable across
|
||||
// platforms and remains correct on Windows where `filepath.Separator`
|
||||
// would otherwise miss forward-slash style paths.
|
||||
safeCwd := strings.ReplaceAll(cwd, "\\", "--")
|
||||
safeCwd = strings.ReplaceAll(safeCwd, "/", "--")
|
||||
// Remove leading separator replacement.
|
||||
safeCwd = strings.TrimPrefix(safeCwd, "--")
|
||||
return filepath.Join(home, ".kit", "sessions", safeCwd)
|
||||
// Strip characters that are illegal in directory names on Windows
|
||||
// (`< > : " | ? *`). On Unix these characters are legal but rare in
|
||||
// practice; stripping them keeps the encoding portable.
|
||||
replacer := strings.NewReplacer(
|
||||
":", "",
|
||||
"<", "",
|
||||
">", "",
|
||||
"\"", "",
|
||||
"|", "",
|
||||
"?", "",
|
||||
"*", "",
|
||||
)
|
||||
return replacer.Replace(safeCwd)
|
||||
}
|
||||
|
||||
@@ -18,8 +18,11 @@ type PromptTemplate struct {
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// variableRe matches {{variable_name}} placeholders.
|
||||
var variableRe = regexp.MustCompile(`\{\{(\w+)\}\}`)
|
||||
// variableRe matches {{variable_name}} placeholders, tolerating surrounding
|
||||
// whitespace inside the braces (e.g. {{ name }}). This is the canonical
|
||||
// template grammar shared by skill prompts and the extension template API
|
||||
// (pkg/kit ParseTemplate/RenderTemplate delegate here).
|
||||
var variableRe = regexp.MustCompile(`\{\{\s*(\w+)\s*\}\}`)
|
||||
|
||||
// NewPromptTemplate creates a PromptTemplate, automatically extracting
|
||||
// variable names from {{...}} placeholders in content.
|
||||
@@ -50,11 +53,13 @@ func LoadPromptTemplate(path string) (*PromptTemplate, error) {
|
||||
// Expand replaces all {{variable}} placeholders with values from the
|
||||
// provided map. Missing variables are left as-is (no error).
|
||||
func (t *PromptTemplate) Expand(values map[string]string) string {
|
||||
result := t.Content
|
||||
for k, v := range values {
|
||||
result = strings.ReplaceAll(result, "{{"+k+"}}", v)
|
||||
}
|
||||
return result
|
||||
return variableRe.ReplaceAllStringFunc(t.Content, func(m string) string {
|
||||
name := variableRe.FindStringSubmatch(m)[1]
|
||||
if v, ok := values[name]; ok {
|
||||
return v
|
||||
}
|
||||
return m
|
||||
})
|
||||
}
|
||||
|
||||
// ExpandStrict replaces all {{variable}} placeholders and returns an error
|
||||
|
||||
@@ -345,49 +345,70 @@ func (p *MCPConnectionPool) createStdioClient(ctx context.Context, serverConfig
|
||||
return stdioClient, nil
|
||||
}
|
||||
|
||||
// createSSEClient creates an SSE client
|
||||
// parseHeaders parses "Key: Value" header strings into a map.
|
||||
func parseHeaders(raw []string) map[string]string {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
headers := make(map[string]string)
|
||||
for _, header := range raw {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) == 0 {
|
||||
return nil
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
// buildOAuthConfig constructs a transport.OAuthConfig from the server config
|
||||
// and the pool's OAuth flow. Returns nil if OAuth is not applicable.
|
||||
func (p *MCPConnectionPool) buildOAuthConfig(serverConfig config.MCPServerConfig) (*transport.OAuthConfig, error) {
|
||||
if p.oauthFlow == nil || serverConfig.NoOAuth {
|
||||
return nil, nil
|
||||
}
|
||||
tokenStore, err := p.createTokenStore(serverConfig.URL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", err)
|
||||
}
|
||||
cfg := &transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
cfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
cfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
cfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.ClientOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth. Public MCP servers (e.g.
|
||||
// PubMed) set NoOAuth to skip dynamic client registration and token
|
||||
// exchange, which would otherwise fail with a 404.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
@@ -406,43 +427,18 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
var options []transport.StreamableHTTPCOption
|
||||
|
||||
if len(serverConfig.Headers) > 0 {
|
||||
headers := make(map[string]string)
|
||||
for _, header := range serverConfig.Headers {
|
||||
parts := strings.SplitN(header, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := strings.TrimSpace(parts[1])
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
if headers := parseHeaders(serverConfig.Headers); headers != nil {
|
||||
options = append(options, transport.WithHTTPHeaders(headers))
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured
|
||||
// and the server hasn't opted out via NoOAuth.
|
||||
if p.oauthFlow != nil && !serverConfig.NoOAuth {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
oauthCfg, err := p.buildOAuthConfig(serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if oauthCfg != nil {
|
||||
options = append(options, transport.WithHTTPOAuth(*oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
|
||||
+60
-57
@@ -641,30 +641,16 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: callParams,
|
||||
}
|
||||
result, callErr := conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var result *mcp.CallToolResult
|
||||
err := m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
result, callErr = conn.client.CallTool(ctx, callRequest)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
// Task-augmented path. Bypass the upstream CallTool helper because its
|
||||
@@ -683,40 +669,25 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(result)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{Content: string(marshaledResult), IsError: result.IsError}, nil
|
||||
return marshalToolResult(result)
|
||||
}
|
||||
|
||||
callResult, taskResult, callErr := callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, mapping.serverName, callErr); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth re-authorization failed for tool %s: %w", mapping.originalName, flowErr)
|
||||
}
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
if callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
} else {
|
||||
m.connectionPool.HandleConnectionError(mapping.serverName, callErr)
|
||||
return nil, fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
var (
|
||||
callResult *mcp.CallToolResult
|
||||
taskResult *mcp.CreateTaskResult
|
||||
)
|
||||
err = m.withOAuthRetry(ctx, mapping.serverName, mapping.originalName, func() error {
|
||||
var callErr error
|
||||
callResult, taskResult, callErr = callToolWithTask(ctx, rawClient, callParams)
|
||||
return callErr
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Server chose to answer synchronously — same shape as the no-task path.
|
||||
if callResult != nil {
|
||||
marshaledResult, mErr := json.Marshal(callResult)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: callResult.IsError,
|
||||
}, nil
|
||||
return marshalToolResult(callResult)
|
||||
}
|
||||
|
||||
// Asynchronous task path: poll until terminal, then return the result.
|
||||
@@ -732,18 +703,50 @@ func (m *MCPToolManager) ExecuteTool(ctx context.Context, prefixedName, inputJSO
|
||||
}
|
||||
|
||||
// Adapt TaskResultResult → CallToolResult for downstream JSON shape parity.
|
||||
adapted := &mcp.CallToolResult{
|
||||
return marshalToolResult(&mcp.CallToolResult{
|
||||
Content: final.Content,
|
||||
StructuredContent: final.StructuredContent,
|
||||
IsError: final.IsError,
|
||||
})
|
||||
}
|
||||
|
||||
// withOAuthRetry runs call once; when it fails with an OAuth error and an
|
||||
// OAuth flow is configured, it re-authorizes the server and retries once.
|
||||
// Connection failures are reported to the pool and wrapped uniformly. This
|
||||
// consolidates the retry/error chain shared by the synchronous and
|
||||
// task-augmented tool-call paths.
|
||||
func (m *MCPToolManager) withOAuthRetry(ctx context.Context, serverName, toolName string, call func() error) error {
|
||||
callErr := call()
|
||||
if callErr == nil {
|
||||
return nil
|
||||
}
|
||||
marshaledResult, mErr := json.Marshal(adapted)
|
||||
if mErr != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", mErr)
|
||||
if m.connectionPool.oauthFlow != nil && IsOAuthError(callErr) {
|
||||
if flowErr := m.connectionPool.oauthFlow.RunAuthFlow(ctx, serverName, callErr); flowErr != nil {
|
||||
return fmt.Errorf("OAuth re-authorization failed for tool %s: %w", toolName, flowErr)
|
||||
}
|
||||
if callErr = call(); callErr != nil {
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool after re-auth: %w", callErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
m.connectionPool.HandleConnectionError(serverName, callErr)
|
||||
return fmt.Errorf("failed to call mcp tool: %w", callErr)
|
||||
}
|
||||
|
||||
// marshalToolResult converts an MCP CallToolResult into the JSON-encoded
|
||||
// MCPToolResult shape returned to the agent.
|
||||
func marshalToolResult(result *mcp.CallToolResult) (*MCPToolResult, error) {
|
||||
if result == nil {
|
||||
return nil, errors.New("mcp tool call returned nil result")
|
||||
}
|
||||
marshaled, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
return &MCPToolResult{
|
||||
Content: string(marshaledResult),
|
||||
IsError: final.IsError,
|
||||
Content: string(marshaled),
|
||||
IsError: result.IsError,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -28,15 +28,6 @@ type blockRenderer struct {
|
||||
// renderingOption configures block rendering
|
||||
type renderingOption func(*blockRenderer)
|
||||
|
||||
// WithFullWidth returns a renderingOption that configures the block renderer
|
||||
// to expand to the full available width of its container. When enabled, the
|
||||
// block will fill the entire horizontal space rather than sizing to its content.
|
||||
func WithFullWidth() renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.fullWidth = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithNoBorder returns a renderingOption that disables all borders on the
|
||||
// block, rendering content with only padding.
|
||||
func WithNoBorder() renderingOption {
|
||||
@@ -63,15 +54,6 @@ func WithBorderColor(c color.Color) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginTop returns a renderingOption that sets the top margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space above the block.
|
||||
func WithMarginTop(margin int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.marginTop = margin
|
||||
}
|
||||
}
|
||||
|
||||
// WithMarginBottom returns a renderingOption that sets the bottom margin
|
||||
// for the block. The margin is specified in number of lines and adds
|
||||
// vertical space below the block.
|
||||
@@ -81,24 +63,6 @@ func WithMarginBottom(margin int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingLeft returns a renderingOption that sets the left padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the left border and the content.
|
||||
func WithPaddingLeft(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingLeft = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingRight returns a renderingOption that sets the right padding
|
||||
// for the block content. The padding is specified in number of characters
|
||||
// and adds horizontal space between the content and the right border.
|
||||
func WithPaddingRight(padding int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.paddingRight = padding
|
||||
}
|
||||
}
|
||||
|
||||
// WithPaddingTop returns a renderingOption that sets the top padding
|
||||
// for the block content. The padding is specified in number of lines
|
||||
// and adds vertical space between the top border and the content.
|
||||
@@ -117,33 +81,6 @@ func WithPaddingBottom(padding int) renderingOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithBackground returns a renderingOption that sets the background color
|
||||
// for the entire block. The color parameter accepts any color.Color value,
|
||||
// typically a lipgloss hex color (e.g. lipgloss.Color("#1e1e2e")).
|
||||
func WithBackground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.background = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithForeground returns a renderingOption that overrides the default text
|
||||
// foreground color (theme.Text) for the block. Useful for muted or
|
||||
// de-emphasized content blocks.
|
||||
func WithForeground(c color.Color) renderingOption {
|
||||
return func(br *blockRenderer) {
|
||||
br.foreground = &c
|
||||
}
|
||||
}
|
||||
|
||||
// WithWidth returns a renderingOption that sets a specific width for the block
|
||||
// in characters. This overrides the default container width and allows precise
|
||||
// control over the block's horizontal dimensions.
|
||||
func WithWidth(width int) renderingOption {
|
||||
return func(c *blockRenderer) {
|
||||
c.width = width
|
||||
}
|
||||
}
|
||||
|
||||
// renderContentBlock renders content with configurable styling options
|
||||
func renderContentBlock(content string, containerWidth int, options ...renderingOption) string {
|
||||
renderer := &blockRenderer{
|
||||
|
||||
@@ -54,12 +54,6 @@ func (c *CLI) GetUsageTracker() *UsageTracker {
|
||||
return c.usageTracker
|
||||
}
|
||||
|
||||
// GetDebugLogger returns a CLIDebugLogger instance that routes debug output
|
||||
// through the CLI's rendering system for consistent message formatting and display.
|
||||
func (c *CLI) GetDebugLogger() *CLIDebugLogger {
|
||||
return NewCLIDebugLogger(c)
|
||||
}
|
||||
|
||||
// SetModelName updates the current AI model name being used in the conversation.
|
||||
// This name is displayed in message headers to indicate which model is responding.
|
||||
func (c *CLI) SetModelName(modelName string) {
|
||||
@@ -87,13 +81,6 @@ func (c *CLI) DisplayUserMessage(message string) {
|
||||
fmt.Println(c.renderer.RenderUserMessage(message, time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayAssistantMessage renders and displays an AI assistant's response message
|
||||
// with appropriate formatting. This method delegates to DisplayAssistantMessageWithModel
|
||||
// with an empty model name for backward compatibility.
|
||||
func (c *CLI) DisplayAssistantMessage(message string) error {
|
||||
return c.DisplayAssistantMessageWithModel(message, "")
|
||||
}
|
||||
|
||||
// DisplayAssistantMessageWithModel renders and displays an AI assistant's response
|
||||
// with the specified model name shown in the message header. The message is
|
||||
// formatted according to the current display mode and includes timestamp information.
|
||||
@@ -149,12 +136,6 @@ func (c *CLI) DisplayExtensionBlock(text, borderColor, subtitle string) {
|
||||
fmt.Println(rendered)
|
||||
}
|
||||
|
||||
// DisplayCancellation displays a system message indicating that the current
|
||||
// AI generation has been cancelled by the user (typically via ESC key).
|
||||
func (c *CLI) DisplayCancellation() {
|
||||
fmt.Println(c.renderer.RenderSystemMessage("Generation cancelled by user (ESC pressed)", time.Now()).Content)
|
||||
}
|
||||
|
||||
// DisplayDebugMessage renders and displays a debug message if debug mode is enabled.
|
||||
// Debug messages are formatted distinctively and only shown when the CLI is
|
||||
// initialized with debug=true.
|
||||
|
||||
@@ -161,6 +161,27 @@ var SlashCommands = []SlashCommand{
|
||||
Category: "Navigation",
|
||||
Aliases: []string{"/r"},
|
||||
},
|
||||
{
|
||||
Name: "/copy",
|
||||
Description: "Copy the last message to the system clipboard",
|
||||
Category: "System",
|
||||
Aliases: []string{"/cp"},
|
||||
},
|
||||
{
|
||||
Name: "/retry",
|
||||
Description: "Resubmit the last user message (e.g. after a provider error)",
|
||||
Category: "System",
|
||||
Aliases: []string{"/rt"},
|
||||
},
|
||||
{
|
||||
Name: "/edit",
|
||||
Description: "Open a file in $EDITOR (fuzzy-find a path, then edit)",
|
||||
Category: "System",
|
||||
Aliases: []string{"/ed"},
|
||||
HasArgs: true,
|
||||
// Note: no Complete callback — file fuzzy-finding is driven directly
|
||||
// by InputComponent (mirroring the @file popup with directory drill).
|
||||
},
|
||||
{
|
||||
Name: "/export",
|
||||
Description: "Export session (JSONL by default, or /export path.jsonl)",
|
||||
@@ -199,18 +220,6 @@ func GetCommandByName(name string) *SlashCommand {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllCommandNames returns a complete list of all command names and their aliases.
|
||||
// This is useful for command completion, validation, and help display. The returned
|
||||
// slice contains both primary command names and all alternative aliases.
|
||||
func GetAllCommandNames() []string {
|
||||
var names []string
|
||||
for _, cmd := range SlashCommands {
|
||||
names = append(names, cmd.Name)
|
||||
names = append(names, cmd.Aliases...)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// ExtensionCommand is a slash command registered by an extension. Unlike
|
||||
// built-in SlashCommands whose execution is hardcoded in handleSlashCommand,
|
||||
// extension commands carry their own Execute callback.
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CLIDebugLogger implements the tools.DebugLogger interface using CLI rendering.
|
||||
// It provides debug logging functionality that integrates with the CLI's display
|
||||
// system, ensuring debug messages are properly formatted and displayed alongside
|
||||
// other conversation content.
|
||||
type CLIDebugLogger struct {
|
||||
cli *CLI
|
||||
}
|
||||
|
||||
// NewCLIDebugLogger creates and returns a new CLIDebugLogger instance that routes
|
||||
// debug output through the provided CLI instance. The logger will respect the CLI's
|
||||
// debug mode setting and display format preferences.
|
||||
func NewCLIDebugLogger(cli *CLI) *CLIDebugLogger {
|
||||
return &CLIDebugLogger{cli: cli}
|
||||
}
|
||||
|
||||
// LogDebug processes and displays a debug message through the CLI's rendering system.
|
||||
// Messages are formatted with appropriate emojis and tags based on their content type
|
||||
// (DEBUG, POOL, etc.) and only displayed when debug mode is enabled. The method handles
|
||||
// multi-line debug output and connection pool status messages with context-aware formatting.
|
||||
func (l *CLIDebugLogger) LogDebug(message string) {
|
||||
if l.cli == nil || !l.cli.debug {
|
||||
return
|
||||
}
|
||||
|
||||
// Format the message to include all the debug info in a structured way
|
||||
var formattedMessage string
|
||||
|
||||
// Check if this is a multi-line debug output (like connection info)
|
||||
if strings.Contains(message, "[DEBUG]") || strings.Contains(message, "[POOL]") {
|
||||
// Extract the tag and content
|
||||
if after, ok := strings.CutPrefix(message, "[DEBUG]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
formattedMessage = fmt.Sprintf("🔍 DEBUG: %s", content)
|
||||
} else if after, ok := strings.CutPrefix(message, "[POOL]"); ok {
|
||||
content := after
|
||||
content = strings.TrimSpace(content)
|
||||
|
||||
// Add appropriate emoji based on the message content
|
||||
if strings.Contains(content, "Creating new connection") {
|
||||
formattedMessage = fmt.Sprintf("🆕 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Created connection") || strings.Contains(content, "Initialized") {
|
||||
formattedMessage = fmt.Sprintf("✅ POOL: %s", content)
|
||||
} else if strings.Contains(content, "Reusing") {
|
||||
formattedMessage = fmt.Sprintf("🔄 POOL: %s", content)
|
||||
} else if strings.Contains(content, "unhealthy") || strings.Contains(content, "failed") {
|
||||
formattedMessage = fmt.Sprintf("❌ POOL: %s", content)
|
||||
} else if strings.Contains(content, "closed") {
|
||||
formattedMessage = fmt.Sprintf("🛑 POOL: %s", content)
|
||||
} else if strings.Contains(content, "Failed to close") {
|
||||
formattedMessage = fmt.Sprintf("⚠️ POOL: %s", content)
|
||||
} else {
|
||||
formattedMessage = fmt.Sprintf("🔍 POOL: %s", content)
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
} else {
|
||||
formattedMessage = message
|
||||
}
|
||||
|
||||
// Use the CLI's debug message rendering
|
||||
fmt.Println(l.cli.renderer.RenderDebugMessage(formattedMessage, time.Now()).Content)
|
||||
}
|
||||
|
||||
// IsDebugEnabled checks whether debug logging is currently active. Returns true
|
||||
// if the CLI instance exists and has debug mode enabled, allowing callers to
|
||||
// conditionally perform expensive debug operations only when necessary.
|
||||
func (l *CLIDebugLogger) IsDebugEnabled() bool {
|
||||
return l.cli != nil && l.cli.debug
|
||||
}
|
||||
+29
-35
@@ -2,7 +2,6 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
@@ -44,28 +43,39 @@ func parseModelName(modelString string) (provider, model string) {
|
||||
// ollama or unrecognised models). This is used by the interactive TUI path
|
||||
// which doesn't go through SetupCLI.
|
||||
func CreateUsageTracker(modelString, providerAPIKey string) *UsageTracker {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry := models.GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel(provider, model)
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(providerAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
return NewUsageTracker(modelInfo, provider, 80, isOAuth)
|
||||
}
|
||||
|
||||
// UpdateUsageTrackerForModel refreshes an existing tracker after a model
|
||||
// switch so token counting and cost reporting use the new model's metadata.
|
||||
// No-op for a nil tracker or untrackable models (unknown/ollama).
|
||||
func UpdateUsageTrackerForModel(t *UsageTracker, modelString, providerAPIKey string) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
modelInfo, provider := lookupTrackableModel(modelString)
|
||||
if modelInfo == nil {
|
||||
return
|
||||
}
|
||||
isOAuth := provider == "anthropic" && auth.IsAnthropicOAuth(providerAPIKey)
|
||||
t.UpdateModelInfo(modelInfo, provider, isOAuth)
|
||||
}
|
||||
|
||||
// lookupTrackableModel resolves a model string to registry metadata, returning
|
||||
// nil for models without usage tracking support (unknown or ollama models).
|
||||
func lookupTrackableModel(modelString string) (*models.ModelInfo, string) {
|
||||
provider, model := parseModelName(modelString)
|
||||
if provider == "unknown" || model == "unknown" || provider == "ollama" {
|
||||
return nil, provider
|
||||
}
|
||||
return models.GetGlobalRegistry().LookupModel(provider, model), provider
|
||||
}
|
||||
|
||||
// SetupCLI creates, configures, and initializes a CLI instance with the provided
|
||||
// options. It sets up model display, usage tracking for supported providers, and
|
||||
// shows initial loading information. Returns nil in quiet mode or an initialized
|
||||
@@ -89,24 +99,8 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
|
||||
// Set up usage tracking for supported providers
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
// Skip usage tracking for ollama as it's not in models.dev
|
||||
if provider != "ollama" {
|
||||
registry := models.GetGlobalRegistry()
|
||||
if modelInfo := registry.LookupModel(provider, model); modelInfo != nil {
|
||||
// Check if OAuth credentials are being used for Anthropic models
|
||||
isOAuth := false
|
||||
if provider == "anthropic" {
|
||||
_, source, err := auth.GetAnthropicAPIKey(opts.ProviderAPIKey)
|
||||
if err == nil && strings.HasPrefix(source, "stored OAuth") {
|
||||
isOAuth = true
|
||||
}
|
||||
}
|
||||
|
||||
usageTracker := NewUsageTracker(modelInfo, provider, 80, isOAuth) // Will be updated with actual width
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
}
|
||||
if usageTracker := CreateUsageTracker(opts.ModelString, opts.ProviderAPIKey); usageTracker != nil {
|
||||
cli.SetUsageTracker(usageTracker)
|
||||
}
|
||||
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
|
||||
@@ -125,6 +125,33 @@ func ExtractAtPrefix(line string, cursorCol int) (hasAt bool, prefix string, sta
|
||||
return true, raw, atIdx
|
||||
}
|
||||
|
||||
// editTriggerPrefixes lists the command tokens (including trailing space)
|
||||
// that activate the /edit fuzzy-file picker. Aliases come first so the
|
||||
// longer alias "/edit " is matched before a hypothetical superset.
|
||||
var editTriggerPrefixes = []string{"/edit ", "/ed "}
|
||||
|
||||
// ExtractEditPrefix detects when the input value is a single-line /edit (or
|
||||
// alias) invocation and returns the path-portion the user has typed so far.
|
||||
//
|
||||
// Returns:
|
||||
// - cmdLen: byte offset where the path argument begins (i.e. length of
|
||||
// the matched command token, including its trailing space)
|
||||
// - pathPrefix: text the user has typed after the command token
|
||||
// - ok: true when the value matches one of the /edit triggers
|
||||
//
|
||||
// Multi-line values never match — /edit only makes sense as a single line.
|
||||
func ExtractEditPrefix(value string) (cmdLen int, pathPrefix string, ok bool) {
|
||||
if strings.Contains(value, "\n") {
|
||||
return 0, "", false
|
||||
}
|
||||
for _, p := range editTriggerPrefixes {
|
||||
if strings.HasPrefix(value, p) {
|
||||
return len(p), value[len(p):], true
|
||||
}
|
||||
}
|
||||
return 0, "", false
|
||||
}
|
||||
|
||||
// GetFileSuggestions returns file/directory suggestions matching the given
|
||||
// prefix. It tries `git ls-files` first (fast, respects .gitignore), then
|
||||
// falls back to a simple directory walk.
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
// Package imagepreview renders low-resolution, in-terminal thumbnails of
|
||||
// images using Unicode upper half-block characters (U+2580, "▀") combined
|
||||
// with SGR foreground/background color codes.
|
||||
//
|
||||
// The technique stacks two vertical pixels into a single character cell: the
|
||||
// foreground color paints the top pixel and the background color paints the
|
||||
// bottom pixel. This produces pure styled text — no graphics escape sequences
|
||||
// — so the output survives terminal multiplexers (tmux, zellij) untouched.
|
||||
//
|
||||
// The Kitty graphics protocol, Sixel, and iTerm2 inline images are
|
||||
// deliberately NOT used: those are graphics escape-sequence protocols that
|
||||
// tmux and zellij strip or mangle by default.
|
||||
package imagepreview
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
// Register the standard image decoders so image.Decode can handle the
|
||||
// common clipboard / attachment formats.
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
|
||||
"github.com/charmbracelet/colorprofile"
|
||||
"github.com/charmbracelet/x/ansi"
|
||||
xdraw "golang.org/x/image/draw"
|
||||
)
|
||||
|
||||
// upperHalfBlock is U+2580 ("▀"). The glyph fills the top half of a cell,
|
||||
// letting the foreground color render the top pixel and the cell's background
|
||||
// color render the bottom pixel.
|
||||
const upperHalfBlock = "▀"
|
||||
|
||||
// reset is the SGR reset sequence appended after each rendered row.
|
||||
const reset = "\x1b[0m"
|
||||
|
||||
// maxImageDimension is the largest width or height, in pixels, that Render will
|
||||
// fully decode. Images larger than this in either axis are rejected before the
|
||||
// expensive image.Decode call to guard against decompression bombs (small
|
||||
// encoded payloads that expand to enormous pixel buffers).
|
||||
const maxImageDimension = 20000
|
||||
|
||||
// Render returns a half-block ANSI thumbnail of the image, scaled to fit
|
||||
// within maxCols x maxRows terminal cells while preserving aspect ratio.
|
||||
//
|
||||
// Each terminal cell encodes two vertically-stacked pixels, so the effective
|
||||
// pixel resolution of the thumbnail is up to maxCols x (maxRows*2).
|
||||
//
|
||||
// Colors are emitted at the fidelity of the detected terminal color profile:
|
||||
// truecolor (24-bit) when available, degrading to 256-color. When the
|
||||
// terminal supports neither (no truecolor and no 256-color), Render returns
|
||||
// an empty string and a nil error so the caller can fall back to a text
|
||||
// indicator. A non-nil error is only returned when the image data cannot be
|
||||
// decoded.
|
||||
//
|
||||
// bg is the color used to composite transparent pixels (typically the
|
||||
// terminal background). A nil bg defaults to black.
|
||||
func Render(data []byte, mediaType string, maxCols, maxRows int, bg color.Color) (string, error) {
|
||||
profile := colorprofile.Env(os.Environ())
|
||||
return renderWithProfile(data, maxCols, maxRows, bg, profile)
|
||||
}
|
||||
|
||||
// renderWithProfile is the testable core of Render. It accepts an explicit
|
||||
// color profile instead of detecting one from the environment.
|
||||
func renderWithProfile(data []byte, maxCols, maxRows int, bg color.Color, profile colorprofile.Profile) (string, error) {
|
||||
// Half-block fidelity needs at least 256-color support. Anything less
|
||||
// degrades to the caller's text fallback.
|
||||
if profile < colorprofile.ANSI256 {
|
||||
return "", nil
|
||||
}
|
||||
if maxCols < 1 || maxRows < 1 {
|
||||
return "", nil
|
||||
}
|
||||
if bg == nil {
|
||||
bg = color.Black
|
||||
}
|
||||
|
||||
// Guard against decompression bombs: inspect the header dimensions before
|
||||
// fully decoding, so a small malicious payload cannot expand into an
|
||||
// enormous pixel buffer.
|
||||
cfg, _, err := image.DecodeConfig(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode image config: %w", err)
|
||||
}
|
||||
if cfg.Width > maxImageDimension || cfg.Height > maxImageDimension {
|
||||
return "", fmt.Errorf("decode image: dimensions %dx%d exceed limit %d", cfg.Width, cfg.Height, maxImageDimension)
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode image: %w", err)
|
||||
}
|
||||
|
||||
// Target pixel dimensions: one pixel per column horizontally and two
|
||||
// pixels per row vertically (the half-block trick).
|
||||
cols, rows := fitDimensions(img.Bounds().Dx(), img.Bounds().Dy(), maxCols, maxRows)
|
||||
if cols < 1 || rows < 1 {
|
||||
return "", nil
|
||||
}
|
||||
pxW, pxH := cols, rows*2
|
||||
|
||||
scaled := image.NewRGBA(image.Rect(0, 0, pxW, pxH))
|
||||
xdraw.CatmullRom.Scale(scaled, scaled.Bounds(), img, img.Bounds(), xdraw.Over, nil)
|
||||
|
||||
var b strings.Builder
|
||||
for y := 0; y < pxH; y += 2 {
|
||||
for x := range pxW {
|
||||
top := composite(scaled.At(x, y), bg)
|
||||
bottom := composite(scaled.At(x, y+1), bg)
|
||||
b.WriteString(sgr(top, bottom, profile))
|
||||
b.WriteString(upperHalfBlock)
|
||||
}
|
||||
b.WriteString(reset)
|
||||
if y+2 < pxH {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
// fitDimensions returns the largest cell dimensions (cols, rows) that fit a
|
||||
// srcW x srcH image inside a maxCols x maxRows box while preserving aspect
|
||||
// ratio. Because each cell stacks two vertical pixels, a terminal cell is
|
||||
// treated as roughly twice as tall as it is wide, which keeps the thumbnail's
|
||||
// aspect ratio visually correct.
|
||||
func fitDimensions(srcW, srcH, maxCols, maxRows int) (cols, rows int) {
|
||||
if srcW <= 0 || srcH <= 0 {
|
||||
return 0, 0
|
||||
}
|
||||
// Work in pixel space: the box is maxCols wide and maxRows*2 tall.
|
||||
maxPxW := float64(maxCols)
|
||||
maxPxH := float64(maxRows * 2)
|
||||
scale := maxPxW / float64(srcW)
|
||||
if h := maxPxH / float64(srcH); h < scale {
|
||||
scale = h
|
||||
}
|
||||
if scale > 1 {
|
||||
scale = 1 // never upscale; keep the low-res look
|
||||
}
|
||||
pxW := int(float64(srcW) * scale)
|
||||
pxH := int(float64(srcH) * scale)
|
||||
if pxW < 1 {
|
||||
pxW = 1
|
||||
}
|
||||
if pxH < 2 {
|
||||
pxH = 2
|
||||
}
|
||||
// Convert back to cells; round the row count up to an even pixel height.
|
||||
cols = pxW
|
||||
rows = (pxH + 1) / 2
|
||||
if cols > maxCols {
|
||||
cols = maxCols
|
||||
}
|
||||
if rows > maxRows {
|
||||
rows = maxRows
|
||||
}
|
||||
return cols, rows
|
||||
}
|
||||
|
||||
// composite blends a (possibly translucent) pixel over the background color,
|
||||
// returning an opaque color. Fully opaque pixels are returned unchanged.
|
||||
func composite(c, bg color.Color) color.Color {
|
||||
r, g, b, a := c.RGBA()
|
||||
if a == 0xffff {
|
||||
return c
|
||||
}
|
||||
br, bgc, bb, _ := bg.RGBA()
|
||||
// Standard "over" alpha compositing in 16-bit space.
|
||||
inv := 0xffff - a
|
||||
out := color.RGBA64{
|
||||
R: uint16(r + br*inv/0xffff),
|
||||
G: uint16(g + bgc*inv/0xffff),
|
||||
B: uint16(b + bb*inv/0xffff),
|
||||
A: 0xffff,
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// sgr builds the SGR escape sequence that sets the foreground (top pixel) and
|
||||
// background (bottom pixel) colors at the fidelity of the given profile.
|
||||
func sgr(fg, bg color.Color, profile colorprofile.Profile) string {
|
||||
if profile >= colorprofile.TrueColor {
|
||||
fr, fgc, fb := rgb8(fg)
|
||||
br, bgc, bb := rgb8(bg)
|
||||
return fmt.Sprintf("\x1b[38;2;%d;%d;%d;48;2;%d;%d;%dm", fr, fgc, fb, br, bgc, bb)
|
||||
}
|
||||
return fmt.Sprintf("\x1b[38;5;%d;48;5;%dm", index256(fg, profile), index256(bg, profile))
|
||||
}
|
||||
|
||||
// rgb8 reduces a color to 8-bit RGB components.
|
||||
func rgb8(c color.Color) (r, g, b uint8) {
|
||||
cr, cg, cb, _ := c.RGBA()
|
||||
return uint8(cr >> 8), uint8(cg >> 8), uint8(cb >> 8)
|
||||
}
|
||||
|
||||
// index256 converts a color to its nearest 256-color palette index using the
|
||||
// supplied profile.
|
||||
func index256(c color.Color, profile colorprofile.Profile) uint8 {
|
||||
cc := profile.Convert(c)
|
||||
if idx, ok := cc.(ansi.IndexedColor); ok {
|
||||
return uint8(idx)
|
||||
}
|
||||
if idx, ok := cc.(ansi.BasicColor); ok {
|
||||
return uint8(idx)
|
||||
}
|
||||
// Fallback: derive an index directly if conversion produced an
|
||||
// unexpected type.
|
||||
r, g, b := rgb8(c)
|
||||
return ansi256FromRGB(r, g, b)
|
||||
}
|
||||
|
||||
// ansi256FromRGB maps an 8-bit RGB color to the xterm 256-color cube. It is a
|
||||
// best-effort fallback used only when profile.Convert does not yield a known
|
||||
// indexed color type.
|
||||
func ansi256FromRGB(r, g, b uint8) uint8 {
|
||||
q := func(v uint8) int {
|
||||
switch {
|
||||
case v < 48:
|
||||
return 0
|
||||
case v < 115:
|
||||
return 1
|
||||
default:
|
||||
return int((v - 35) / 40)
|
||||
}
|
||||
}
|
||||
ri, gi, bi := q(r), q(g), q(b)
|
||||
return uint8(16 + 36*ri + 6*gi + bi)
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package imagepreview
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/charmbracelet/colorprofile"
|
||||
)
|
||||
|
||||
// makePNG builds a simple w x h PNG filled with the given color and returns
|
||||
// its encoded bytes.
|
||||
func makePNG(t *testing.T, w, h int, c color.Color) []byte {
|
||||
t.Helper()
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := range h {
|
||||
for x := range w {
|
||||
img.Set(x, y, c)
|
||||
}
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
t.Fatalf("encode png: %v", err)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestRenderTrueColor(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{R: 255, A: 255})
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected non-empty thumbnail for truecolor profile")
|
||||
}
|
||||
if !strings.Contains(out, upperHalfBlock) {
|
||||
t.Error("output should contain upper half block glyphs")
|
||||
}
|
||||
if !strings.Contains(out, "\x1b[38;2;") || !strings.Contains(out, "48;2;") {
|
||||
t.Errorf("expected truecolor SGR sequences, got %q", out)
|
||||
}
|
||||
// Red fill should appear as 255;0;0 somewhere.
|
||||
if !strings.Contains(out, "255;0;0") {
|
||||
t.Errorf("expected red color in output, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderANSI256(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{G: 255, A: 255})
|
||||
out, err := renderWithProfile(data, 8, 4, color.Black, colorprofile.ANSI256)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected non-empty thumbnail for ANSI256 profile")
|
||||
}
|
||||
if !strings.Contains(out, "\x1b[38;5;") || !strings.Contains(out, "48;5;") {
|
||||
t.Errorf("expected 256-color SGR sequences, got %q", out)
|
||||
}
|
||||
if strings.Contains(out, "38;2;") {
|
||||
t.Errorf("ANSI256 output should not contain truecolor sequences, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderDegradesBelowANSI256(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.RGBA{B: 255, A: 255})
|
||||
for _, p := range []colorprofile.Profile{colorprofile.ANSI, colorprofile.ASCII, colorprofile.NoTTY} {
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, p)
|
||||
if err != nil {
|
||||
t.Fatalf("profile %v: unexpected error: %v", p, err)
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("profile %v: expected empty fallback, got %q", p, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderInvalidImage(t *testing.T) {
|
||||
out, err := renderWithProfile([]byte("not an image"), 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid image data")
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output on decode error, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderRejectsOversizedImage(t *testing.T) {
|
||||
// A header advertising dimensions beyond maxImageDimension must be
|
||||
// rejected before full decode (decompression-bomb guard). image.RGBA
|
||||
// allocation is avoided by only checking the config path here.
|
||||
w := maxImageDimension + 1
|
||||
data := makePNG(t, w, 1, color.White)
|
||||
out, err := renderWithProfile(data, 10, 5, color.Black, colorprofile.TrueColor)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized image dimensions")
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output for oversized image, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderZeroBox(t *testing.T) {
|
||||
data := makePNG(t, 20, 20, color.White)
|
||||
out, err := renderWithProfile(data, 0, 0, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out != "" {
|
||||
t.Errorf("expected empty output for zero-sized box, got %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderNilBackgroundDefaults(t *testing.T) {
|
||||
data := makePNG(t, 10, 10, color.RGBA{R: 10, G: 20, B: 30, A: 255})
|
||||
out, err := renderWithProfile(data, 6, 3, nil, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == "" {
|
||||
t.Fatal("expected output with nil background (defaults to black)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowCountWithinBounds(t *testing.T) {
|
||||
// A tall image should be capped at maxRows cells.
|
||||
data := makePNG(t, 10, 100, color.White)
|
||||
out, err := renderWithProfile(data, 20, 6, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
rows := strings.Count(out, "\n") + 1
|
||||
if rows > 6 {
|
||||
t.Errorf("expected at most 6 rows, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumnCountWithinBounds(t *testing.T) {
|
||||
// A wide image should be capped at maxCols cells per row.
|
||||
data := makePNG(t, 100, 10, color.White)
|
||||
out, err := renderWithProfile(data, 8, 20, color.Black, colorprofile.TrueColor)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
firstRow := strings.SplitN(out, "\n", 2)[0]
|
||||
cols := strings.Count(firstRow, upperHalfBlock)
|
||||
if cols > 8 {
|
||||
t.Errorf("expected at most 8 columns, got %d", cols)
|
||||
}
|
||||
if cols == 0 {
|
||||
t.Error("expected at least one column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitDimensionsPreservesAspect(t *testing.T) {
|
||||
// 2:1 (wide) image into a 40x20 box. Pixel box is 40x40; width-bound.
|
||||
cols, rows := fitDimensions(200, 100, 40, 20)
|
||||
if cols != 40 {
|
||||
t.Errorf("expected 40 cols, got %d", cols)
|
||||
}
|
||||
// pxH = 100 * (40/200) = 20 → 10 rows.
|
||||
if rows != 10 {
|
||||
t.Errorf("expected 10 rows, got %d", rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitDimensionsNeverUpscales(t *testing.T) {
|
||||
cols, rows := fitDimensions(4, 4, 40, 20)
|
||||
if cols != 4 || rows != 2 {
|
||||
t.Errorf("expected 4x2 (no upscale), got %dx%d", cols, rows)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeOpaquePassthrough(t *testing.T) {
|
||||
c := color.RGBA{R: 1, G: 2, B: 3, A: 255}
|
||||
got := composite(c, color.White)
|
||||
if got != color.Color(c) {
|
||||
t.Errorf("opaque color should pass through unchanged, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompositeTransparentOverBackground(t *testing.T) {
|
||||
// Fully transparent pixel over red background should yield red.
|
||||
got := composite(color.RGBA{}, color.RGBA{R: 255, A: 255})
|
||||
r, g, b, a := got.RGBA()
|
||||
if r>>8 != 255 || g>>8 != 0 || b>>8 != 0 || a != 0xffff {
|
||||
t.Errorf("expected opaque red, got r=%d g=%d b=%d a=%d", r>>8, g>>8, b>>8, a)
|
||||
}
|
||||
}
|
||||
+224
-187
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image/color"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/clipboard"
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
"github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/imagepreview"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
@@ -42,6 +44,12 @@ type InputComponent struct {
|
||||
popupHeight int
|
||||
submitNext bool // defer submit one tick so popup dismisses cleanly
|
||||
|
||||
// popup is the shared PopupList used to render the / and @ autocomplete
|
||||
// dropdowns. State (items, cursor, visible search-driven filter) is
|
||||
// driven externally by InputComponent — we only use PopupList for the
|
||||
// rendering chrome so all popups in the app look identical.
|
||||
popup *PopupList
|
||||
|
||||
// Argument completion state. When the user types "/cmd " followed by
|
||||
// a partial argument and the command has a Complete function, the popup
|
||||
// switches to argument-completion mode showing suggestions from Complete.
|
||||
@@ -53,10 +61,16 @@ type InputComponent struct {
|
||||
// file path, the popup shows file/directory suggestions from the cwd.
|
||||
fileMode bool // true when showing @file completions
|
||||
filePrefix string // current text after @ being matched
|
||||
fileAtStartIdx int // byte offset of @ in the textarea value
|
||||
fileAtStartIdx int // byte offset of @ (or path start in /edit mode) in the textarea value
|
||||
fileSuggestions []FileSuggestion // backing storage for file entries
|
||||
fileSynthCmds []commands.SlashCommand // synthetic commands.SlashCommands wrapping file entries
|
||||
|
||||
// fileEditMode is true when fileMode was activated by the /edit slash
|
||||
// command rather than an @ trigger. Selecting a file submits the line
|
||||
// (running $EDITOR on it); selecting a directory drills further like @
|
||||
// does. MCP resources are excluded in this mode.
|
||||
fileEditMode bool
|
||||
|
||||
// cwd is the working directory used for @file path resolution and
|
||||
// autocomplete suggestions. Set by the parent via SetCwd.
|
||||
cwd string
|
||||
@@ -80,6 +94,23 @@ type InputComponent struct {
|
||||
// Images are added via Ctrl+V and cleared on submit or Ctrl+U.
|
||||
pendingImages []core.ImageAttachment
|
||||
|
||||
// imageThumbs caches the rendered half-block thumbnail for each entry in
|
||||
// pendingImages (1:1 index correspondence). Thumbnails are rendered
|
||||
// asynchronously off the Bubble Tea event loop (decode + resample is too
|
||||
// slow to run inside Update), so an entry starts as the empty string
|
||||
// placeholder and is filled in when the matching thumbnailReadyMsg
|
||||
// arrives. An entry stays empty when the terminal cannot display a
|
||||
// half-block preview, in which case the text pill is shown alone.
|
||||
// See internal/ui/imagepreview.
|
||||
imageThumbs []string
|
||||
|
||||
// imageGen is a monotonic generation counter incremented whenever the
|
||||
// pending image set is cleared. Async thumbnail results carry the
|
||||
// generation they were enqueued under and are discarded if it no longer
|
||||
// matches, preventing a stale thumbnail from landing on the wrong slot
|
||||
// after a clear + re-attach.
|
||||
imageGen int
|
||||
|
||||
// history stores previously submitted prompts (most recent last).
|
||||
// Limited to maxHistory entries; duplicates of the previous entry are
|
||||
// skipped. Empty strings are never stored.
|
||||
@@ -105,6 +136,16 @@ type clipboardImageMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// thumbnailReadyMsg carries the result of an async thumbnail render back to
|
||||
// the Update loop. gen and index identify the pendingImages slot the
|
||||
// thumbnail belongs to; the result is dropped if the generation no longer
|
||||
// matches (the pending set was cleared) or the index is out of range.
|
||||
type thumbnailReadyMsg struct {
|
||||
gen int
|
||||
index int
|
||||
thumb string
|
||||
}
|
||||
|
||||
// NewInputComponent creates a new InputComponent with the given width and
|
||||
// optional AppController. If appCtrl is nil the component still works but
|
||||
// /clear and /clear-queue are no-ops.
|
||||
@@ -135,7 +176,7 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
|
||||
styles.Focused.CursorLine = lipgloss.NewStyle()
|
||||
ta.SetStyles(styles)
|
||||
|
||||
return &InputComponent{
|
||||
ic := &InputComponent{
|
||||
textarea: ta,
|
||||
commands: commands.SlashCommands,
|
||||
width: width,
|
||||
@@ -143,6 +184,12 @@ func NewInputComponent(width int, appCtrl AppController) *InputComponent {
|
||||
appCtrl: appCtrl,
|
||||
hideHint: true,
|
||||
}
|
||||
ic.popup = NewPopupList("", nil, width, 0)
|
||||
ic.popup.ShowSearch = false
|
||||
ic.popup.HideCount = true
|
||||
ic.popup.MaxVisible = ic.popupHeight
|
||||
ic.popup.FooterHint = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
return ic
|
||||
}
|
||||
|
||||
// SetCwd sets the working directory used for @file autocomplete suggestions
|
||||
@@ -193,7 +240,23 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return s, nil
|
||||
}
|
||||
if msg.image != nil {
|
||||
s.pendingImages = append(s.pendingImages, *msg.image)
|
||||
img := *msg.image
|
||||
index := len(s.pendingImages)
|
||||
s.pendingImages = append(s.pendingImages, img)
|
||||
// Reserve a placeholder; the async render fills it in via
|
||||
// thumbnailReadyMsg so Update never blocks on decode/resample.
|
||||
s.imageThumbs = append(s.imageThumbs, "")
|
||||
cols := s.thumbCols()
|
||||
if cols < 1 {
|
||||
return s, nil
|
||||
}
|
||||
return s, renderThumbnailCmd(img, cols, thumbMaxRows, style.GetTheme().Background, s.imageGen, index)
|
||||
}
|
||||
return s, nil
|
||||
|
||||
case thumbnailReadyMsg:
|
||||
if msg.gen == s.imageGen && msg.index >= 0 && msg.index < len(s.imageThumbs) {
|
||||
s.imageThumbs[msg.index] = msg.thumb
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -250,6 +313,8 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Clear all pending image attachments.
|
||||
if len(s.pendingImages) > 0 {
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return s, nil
|
||||
}
|
||||
}
|
||||
@@ -405,10 +470,17 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
} else {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
}
|
||||
} else if len(lines) == 1 && strings.HasPrefix(lines[0], "/") {
|
||||
s.fileMode = false
|
||||
if !strings.Contains(lines[0], " ") {
|
||||
s.fileEditMode = false
|
||||
if cmdLen, pathPrefix, isEdit := ExtractEditPrefix(lines[0]); isEdit {
|
||||
// /edit fuzzy-file picker. Behaves like @ except
|
||||
// MCP resources are excluded and selecting a file
|
||||
// submits the line (running $EDITOR).
|
||||
s.updateEditFilePopup(cmdLen, pathPrefix)
|
||||
} else if !strings.Contains(lines[0], " ") {
|
||||
// Command name completion.
|
||||
s.showPopup = true
|
||||
s.argMode = false
|
||||
@@ -428,6 +500,7 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.showPopup = false
|
||||
s.argMode = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
}
|
||||
}
|
||||
return s, cmd
|
||||
@@ -486,6 +559,8 @@ func (s *InputComponent) handleSubmit(value string) tea.Cmd {
|
||||
// images and clear them.
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return func() tea.Msg {
|
||||
return core.SubmitMsg{Text: trimmed, Images: images}
|
||||
}
|
||||
@@ -519,6 +594,42 @@ func (s *InputComponent) resetHistoryBrowsing() {
|
||||
s.savedInput = ""
|
||||
}
|
||||
|
||||
// thumbMaxCols and thumbMaxRows cap the size, in terminal cells, of pending
|
||||
// image previews. Kept small for the low-res look and to keep scrollback
|
||||
// light.
|
||||
const (
|
||||
thumbMaxCols = 40
|
||||
thumbMaxRows = 12
|
||||
)
|
||||
|
||||
// thumbCols returns the thumbnail width in terminal cells given the current
|
||||
// input width, or 0 when there is no room to render a preview.
|
||||
func (s *InputComponent) thumbCols() int {
|
||||
if s.width <= 6 {
|
||||
return 0
|
||||
}
|
||||
cols := min(thumbMaxCols, s.width-6)
|
||||
if cols < 1 {
|
||||
return 0
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
// renderThumbnailCmd returns a tea.Cmd that renders a half-block ANSI preview
|
||||
// off the Bubble Tea event loop. The decode + resample work runs in the Cmd
|
||||
// goroutine, and the result is delivered as a thumbnailReadyMsg tagged with
|
||||
// the generation and slot index it was enqueued for. An empty thumbnail
|
||||
// (terminal unsupported or render error) leaves the text pill in place.
|
||||
func renderThumbnailCmd(img core.ImageAttachment, cols, rows int, bg color.Color, gen, index int) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, rows, bg)
|
||||
if err != nil {
|
||||
thumb = ""
|
||||
}
|
||||
return thumbnailReadyMsg{gen: gen, index: index, thumb: thumb}
|
||||
}
|
||||
}
|
||||
|
||||
// View implements tea.Model. Renders the textarea, autocomplete popup
|
||||
// (if visible), and help text.
|
||||
func (s *InputComponent) View() tea.View {
|
||||
@@ -544,7 +655,9 @@ func (s *InputComponent) View() tea.View {
|
||||
// Popup is now rendered as a centered overlay in AppModel.View()
|
||||
// instead of inline here to prevent bottom overflow
|
||||
|
||||
// Show image attachment indicator when images are pending.
|
||||
// Show image attachment previews when images are pending. A cached
|
||||
// half-block thumbnail is rendered when the terminal supports it;
|
||||
// otherwise the text pill alone is shown.
|
||||
if len(s.pendingImages) > 0 {
|
||||
imgStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
@@ -553,6 +666,14 @@ func (s *InputComponent) View() tea.View {
|
||||
label := fmt.Sprintf("[%d image(s) attached] ctrl+u to clear", len(s.pendingImages))
|
||||
view.WriteString("\n")
|
||||
view.WriteString(imgStyle.Render(label))
|
||||
|
||||
thumbStyle := lipgloss.NewStyle().PaddingLeft(3)
|
||||
for i := range s.pendingImages {
|
||||
if i < len(s.imageThumbs) && s.imageThumbs[i] != "" {
|
||||
view.WriteString("\n")
|
||||
view.WriteString(thumbStyle.Render(s.imageThumbs[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !s.hideHint {
|
||||
@@ -591,191 +712,37 @@ func (s *InputComponent) View() tea.View {
|
||||
return tea.NewView(containerStyle.Render(view.String()))
|
||||
}
|
||||
|
||||
// renderPopup renders the autocomplete popup for slash command suggestions.
|
||||
// When rendered inline (not centered), returns the styled popup content.
|
||||
// RenderPopupCentered renders the popup as a centered overlay.
|
||||
// RenderPopupCentered renders the autocomplete popup for / or @ as a
|
||||
// centered overlay. Returns "" when the popup is not currently shown.
|
||||
// The actual filtering / selection state lives on InputComponent — this
|
||||
// method merely converts the filtered FuzzyMatch list into PopupItems
|
||||
// and asks the shared PopupList to draw it. As a result the / popup, the
|
||||
// @ popup, the model picker, the tree selector and the session selector
|
||||
// all share identical chrome.
|
||||
func (s *InputComponent) RenderPopupCentered(termWidth, termHeight int) string {
|
||||
if !s.showPopup || len(s.filtered) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
popupContent := s.renderPopupWithOptions(true)
|
||||
|
||||
// Center popup using lipgloss.Place
|
||||
positioned := lipgloss.Place(
|
||||
termWidth,
|
||||
termHeight,
|
||||
lipgloss.Center,
|
||||
lipgloss.Center,
|
||||
popupContent,
|
||||
)
|
||||
|
||||
return positioned
|
||||
}
|
||||
|
||||
// renderPopupWithOptions renders the popup content with optional center styling.
|
||||
func (s *InputComponent) renderPopupWithOptions(centered bool) string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(s.width-4, 20)
|
||||
|
||||
// Use the theme background for the popup - the full-width item backgrounds
|
||||
// and primary-colored selection will provide sufficient contrast
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginLeft(0).
|
||||
MarginBottom(1) // Visual depth/shadow effect
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
|
||||
// Item background styles for high contrast
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
var items []string
|
||||
|
||||
visibleItems := min(len(s.filtered), s.popupHeight)
|
||||
startIdx := 0
|
||||
if s.selected >= s.popupHeight {
|
||||
startIdx = s.selected - s.popupHeight + 1
|
||||
}
|
||||
endIdx := min(startIdx+visibleItems, len(s.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
match := s.filtered[i]
|
||||
sc := match.Command
|
||||
|
||||
// Choose the appropriate background style
|
||||
itemStyle := normalItemBg
|
||||
if i == s.selected {
|
||||
itemStyle = selectedItemBg
|
||||
items := make([]PopupItem, len(s.filtered))
|
||||
for i, m := range s.filtered {
|
||||
desc := ""
|
||||
if m.Command != nil {
|
||||
desc = m.Command.Description
|
||||
}
|
||||
|
||||
// Build indicator with proper coloring
|
||||
var indicator string
|
||||
if i == s.selected {
|
||||
indicator = "> "
|
||||
} else {
|
||||
indicator = " "
|
||||
name := ""
|
||||
if m.Command != nil {
|
||||
name = m.Command.Name
|
||||
}
|
||||
|
||||
// Build content with name and description
|
||||
var content string
|
||||
if s.fileMode {
|
||||
// File mode: use full width for the path, show description inline
|
||||
maxNameLen := max(innerWidth-16, 8)
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameLen && maxNameLen > 3 {
|
||||
displayName = displayName[:maxNameLen-3] + "..."
|
||||
}
|
||||
|
||||
if sc.Description != "" && innerWidth > 30 {
|
||||
content = indicator + displayName + " " + sc.Description
|
||||
} else {
|
||||
content = indicator + displayName
|
||||
}
|
||||
} else {
|
||||
// Line layout: indicator(2) + name(nameWidth-2 visual) + desc
|
||||
if innerWidth < 20 {
|
||||
// Very narrow: show truncated name only
|
||||
displayName := sc.Name
|
||||
maxName := max(innerWidth-2, 3)
|
||||
if len(displayName) > maxName {
|
||||
displayName = displayName[:maxName-1] + "…"
|
||||
}
|
||||
content = indicator + displayName
|
||||
} else {
|
||||
// Compute nameWidth from the longest command name in the
|
||||
// visible slice so we never truncate unnecessarily.
|
||||
nameWidth := 0
|
||||
for _, fm := range s.filtered {
|
||||
if n := len([]rune(fm.Command.Name)); n > nameWidth {
|
||||
nameWidth = n
|
||||
}
|
||||
}
|
||||
nameWidth += 3 // account for indicator prefix (2) + gap before description (1)
|
||||
// Ensure descriptions still get at least 20 chars when possible.
|
||||
maxForName := innerWidth - 20
|
||||
if maxForName < 8 {
|
||||
maxForName = innerWidth * 2 / 3
|
||||
}
|
||||
if nameWidth > maxForName {
|
||||
nameWidth = maxForName
|
||||
}
|
||||
if nameWidth < 8 {
|
||||
nameWidth = 8
|
||||
}
|
||||
maxNameChars := nameWidth - 2
|
||||
displayName := sc.Name
|
||||
if len(displayName) > maxNameChars {
|
||||
displayName = displayName[:maxNameChars-1] + "…"
|
||||
}
|
||||
|
||||
// Description gets remaining space
|
||||
maxDescLen := max(innerWidth-nameWidth, 0)
|
||||
desc := sc.Description
|
||||
if maxDescLen >= 4 && desc != "" {
|
||||
if len(desc) > maxDescLen {
|
||||
desc = desc[:maxDescLen-3] + "..."
|
||||
}
|
||||
content = indicator + lipgloss.NewStyle().Width(maxNameChars).Render(displayName) + desc
|
||||
} else {
|
||||
content = indicator + displayName
|
||||
}
|
||||
}
|
||||
items[i] = PopupItem{
|
||||
Label: name,
|
||||
Description: desc,
|
||||
}
|
||||
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
// Add scroll indicators with background
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Padding(0, 1)
|
||||
|
||||
if startIdx > 0 {
|
||||
items = append([]string{scrollStyle.Render(" ↑ more above")}, items...)
|
||||
}
|
||||
if endIdx < len(s.filtered) {
|
||||
items = append(items, scrollStyle.Render(" ↓ more below"))
|
||||
}
|
||||
|
||||
content := strings.Join(items, "\n")
|
||||
|
||||
// Adapt footer text to available width with background
|
||||
var footerText string
|
||||
if innerWidth >= 50 {
|
||||
footerText = "↑↓ navigate • tab complete • ↵ select • esc dismiss"
|
||||
} else if innerWidth >= 30 {
|
||||
footerText = "↑↓ nav • tab • ↵ select • esc"
|
||||
} else {
|
||||
footerText = "↑↓ tab ↵ esc"
|
||||
}
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Italic(true).
|
||||
Render(footerText)
|
||||
|
||||
return popupStyle.Render(content + "\n\n" + footer)
|
||||
s.popup.SetSize(termWidth, termHeight)
|
||||
s.popup.SetItems(items)
|
||||
s.popup.SetCursor(s.selected)
|
||||
return s.popup.RenderCentered(termWidth, termHeight)
|
||||
}
|
||||
|
||||
// completeArgs checks whether the input line matches a command with a Complete
|
||||
@@ -844,6 +811,8 @@ func readClipboardImageCmd() tea.Cmd {
|
||||
func (s *InputComponent) ClearPendingImages() []core.ImageAttachment {
|
||||
images := s.pendingImages
|
||||
s.pendingImages = nil
|
||||
s.imageThumbs = nil
|
||||
s.imageGen++
|
||||
return images
|
||||
}
|
||||
|
||||
@@ -862,6 +831,7 @@ func (s *InputComponent) Clear() bool {
|
||||
s.showPopup = false
|
||||
s.argMode = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
s.browsingHistory = false
|
||||
s.savedInput = ""
|
||||
return hadContent
|
||||
@@ -871,6 +841,11 @@ func (s *InputComponent) Clear() bool {
|
||||
// file or MCP resource suggestion. For directories, it keeps the popup open
|
||||
// for further drilling. For files and resources, it closes the popup and adds
|
||||
// a trailing space.
|
||||
//
|
||||
// When fileEditMode is active the same path-replacement happens against the
|
||||
// /edit (or alias) command prefix instead of an @ trigger. Selecting a file
|
||||
// also arms submitNext so the next tick runs $EDITOR on it; selecting a
|
||||
// directory keeps the popup open for drill-down.
|
||||
func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
if idx >= len(s.fileSuggestions) {
|
||||
return
|
||||
@@ -889,7 +864,17 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
beforeAt := lastLine[:s.fileAtStartIdx]
|
||||
|
||||
var replacement string
|
||||
if suggestion.IsMCPResource {
|
||||
switch {
|
||||
case s.fileEditMode:
|
||||
// /edit path mode — no @ prefix; the path is the bare argument.
|
||||
// MCP resources are excluded upstream, so only file/dir entries reach here.
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
if needsQuote {
|
||||
replacement = `"` + suggestion.RelPath + `"`
|
||||
} else {
|
||||
replacement = suggestion.RelPath
|
||||
}
|
||||
case suggestion.IsMCPResource:
|
||||
// MCP resources use @mcp:server:uri format.
|
||||
// Quote if the URI contains spaces.
|
||||
ref := "mcp:" + suggestion.MCPServerName + ":" + suggestion.MCPResourceURI
|
||||
@@ -899,7 +884,7 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
replacement = "@" + ref
|
||||
}
|
||||
replacement += " "
|
||||
} else {
|
||||
default:
|
||||
needsQuote := strings.Contains(suggestion.RelPath, " ")
|
||||
if needsQuote {
|
||||
replacement = `@"` + suggestion.RelPath + `"`
|
||||
@@ -925,9 +910,61 @@ func (s *InputComponent) applyFileCompletion(idx int) {
|
||||
if suggestion.IsDir && !suggestion.IsMCPResource {
|
||||
// Keep popup open — trigger a refresh for the new directory.
|
||||
s.lastValue = "" // force re-evaluation on next update tick
|
||||
} else {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.selected = 0
|
||||
return
|
||||
}
|
||||
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.selected = 0
|
||||
|
||||
if s.fileEditMode {
|
||||
// A file was selected via /edit — submit on the next tick so the
|
||||
// popup dismisses cleanly before $EDITOR takes the terminal.
|
||||
s.fileEditMode = false
|
||||
s.submitNext = true
|
||||
}
|
||||
}
|
||||
|
||||
// updateEditFilePopup queries the file-suggestion engine for the /edit path
|
||||
// prefix and populates the popup state. cmdLen is the byte offset of the path
|
||||
// argument within the current line (i.e. length of "/edit " or "/ed ").
|
||||
// Directories are kept so the user can drill down; MCP resources are skipped.
|
||||
func (s *InputComponent) updateEditFilePopup(cmdLen int, pathPrefix string) {
|
||||
var suggestions []FileSuggestion
|
||||
if s.cwd != "" {
|
||||
suggestions = GetFileSuggestions(pathPrefix, s.cwd)
|
||||
}
|
||||
if len(suggestions) == 0 {
|
||||
s.showPopup = false
|
||||
s.fileMode = false
|
||||
s.fileEditMode = false
|
||||
return
|
||||
}
|
||||
|
||||
sort.Slice(suggestions, func(i, j int) bool {
|
||||
return suggestions[i].Score > suggestions[j].Score
|
||||
})
|
||||
if len(suggestions) > maxFileSuggestions {
|
||||
suggestions = suggestions[:maxFileSuggestions]
|
||||
}
|
||||
|
||||
s.showPopup = true
|
||||
s.fileMode = true
|
||||
s.fileEditMode = true
|
||||
s.argMode = false
|
||||
s.filePrefix = pathPrefix
|
||||
s.fileAtStartIdx = cmdLen
|
||||
s.fileSuggestions = suggestions
|
||||
s.fileSynthCmds = make([]commands.SlashCommand, len(suggestions))
|
||||
s.filtered = make([]FuzzyMatch, len(suggestions))
|
||||
for i, fs := range suggestions {
|
||||
name := fs.RelPath
|
||||
desc := ""
|
||||
if fs.IsDir {
|
||||
desc = "directory"
|
||||
}
|
||||
s.fileSynthCmds[i] = commands.SlashCommand{Name: name, Description: desc}
|
||||
s.filtered[i] = FuzzyMatch{Command: &s.fileSynthCmds[i], Score: fs.Score}
|
||||
}
|
||||
s.selected = 0
|
||||
}
|
||||
|
||||
@@ -25,17 +25,6 @@ type TextMessageItem struct {
|
||||
timestamp time.Time
|
||||
}
|
||||
|
||||
// NewTextMessageItem creates a new text message for the scrollback.
|
||||
// The content should be pre-rendered using MessageRenderer for proper styling.
|
||||
func NewTextMessageItem(id string, role string, content string) *TextMessageItem {
|
||||
return &TextMessageItem{
|
||||
id: id,
|
||||
role: role,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewStyledMessageItem creates a message item with pre-rendered styled content.
|
||||
// This is the preferred way to create messages when you have styled content from MessageRenderer.
|
||||
func NewStyledMessageItem(id string, role string, rawContent string, preRendered string) *TextMessageItem {
|
||||
@@ -316,57 +305,6 @@ func (m *StreamingBashOutputItem) MarkComplete() {
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// SystemMessageItem - System messages (commands, info, errors)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SystemMessageItem represents a system message (commands, info, errors).
|
||||
type SystemMessageItem struct {
|
||||
id string
|
||||
content string
|
||||
timestamp time.Time
|
||||
cachedRender string
|
||||
cachedWidth int
|
||||
}
|
||||
|
||||
// NewSystemMessageItem creates a new system message for the scrollback.
|
||||
func NewSystemMessageItem(id, content string) *SystemMessageItem {
|
||||
return &SystemMessageItem{
|
||||
id: id,
|
||||
content: content,
|
||||
timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) ID() string {
|
||||
return m.id
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Render(width int) string {
|
||||
// Return cached render if width matches
|
||||
if m.cachedWidth == width && m.cachedRender != "" {
|
||||
return m.cachedRender
|
||||
}
|
||||
|
||||
// Simple system message formatting
|
||||
rendered := "│ " + strings.ReplaceAll(m.content, "\n", "\n│ ")
|
||||
|
||||
// Cache and return
|
||||
m.cachedRender = rendered
|
||||
m.cachedWidth = width
|
||||
return rendered
|
||||
}
|
||||
|
||||
func (m *SystemMessageItem) Height() int {
|
||||
if m.cachedRender != "" {
|
||||
return strings.Count(m.cachedRender, "\n") + 1
|
||||
}
|
||||
// Estimate
|
||||
if m.cachedWidth > 0 {
|
||||
return (len(m.content) / max(m.cachedWidth-10, 40)) + 3
|
||||
}
|
||||
return 3
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helper: generateMessageID
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
+583
-131
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/ui/commands"
|
||||
uicore "github.com/mark3labs/kit/internal/ui/core"
|
||||
"github.com/mark3labs/kit/internal/ui/fileutil"
|
||||
"github.com/mark3labs/kit/internal/ui/imagepreview"
|
||||
"github.com/mark3labs/kit/internal/ui/prefs"
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
@@ -124,13 +126,31 @@ type AppController interface {
|
||||
// attachments (e.g. pasted images) into the currently running agent
|
||||
// turn. Behaves like Steer but includes file parts alongside the text.
|
||||
SteerWithFiles(prompt string, files []kit.LLMFilePart) int
|
||||
// PopLastUserMessage truncates the tree session at the parent of the
|
||||
// most recent user message on the current branch, syncs the in-memory
|
||||
// message store, and returns that user prompt (plus any image file
|
||||
// parts) so the caller can resubmit it. Used by /retry to recover from
|
||||
// provider errors (overloaded, timeout) without duplicating the user
|
||||
// message in context. Returns an error if the agent is busy, no tree
|
||||
// session is active, or no user message exists on the current branch.
|
||||
PopLastUserMessage() (string, []kit.LLMFilePart, error)
|
||||
}
|
||||
|
||||
// SkillItem holds display metadata about a loaded skill for the startup
|
||||
// [Skills] section. Built by the CLI layer from the SDK's []*kit.Skill.
|
||||
type SkillItem struct {
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Name string // Skill name (e.g. "btca-cli").
|
||||
Path string // Absolute path to the skill file.
|
||||
Source string // "project" or "user" (global).
|
||||
Description string // Short summary used in autocomplete and help.
|
||||
}
|
||||
|
||||
// ExtensionItem holds display metadata about a loaded extension for the
|
||||
// startup [Extensions] section. Built by the CLI layer from the SDK's
|
||||
// []kit.ExtensionInfo.
|
||||
type ExtensionItem struct {
|
||||
Name string // Extension display name (filename without .go extension).
|
||||
Path string // Absolute path to the extension's .go file.
|
||||
Source string // "project" or "user" (global).
|
||||
}
|
||||
|
||||
@@ -363,6 +383,16 @@ type AppModelOptions struct {
|
||||
// watcher detects changes. May be nil if skill hot-reload is not needed.
|
||||
GetSkillItems func() []SkillItem
|
||||
|
||||
// ExtensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section. Each entry shows the filename of an extension that was
|
||||
// discovered and loaded (global, project-local, or explicit).
|
||||
ExtensionItems []ExtensionItem
|
||||
|
||||
// GetExtensionItems, if non-nil, returns the current extension items.
|
||||
// Called on extension hot-reload to refresh the list. May be nil if no
|
||||
// extensions are loaded.
|
||||
GetExtensionItems func() []ExtensionItem
|
||||
|
||||
// MCPToolCount is the number of tools loaded from external MCP servers.
|
||||
MCPToolCount int
|
||||
|
||||
@@ -607,6 +637,14 @@ type AppModel struct {
|
||||
// skill list after content hot-reload. May be nil.
|
||||
getSkillItems func() []SkillItem
|
||||
|
||||
// extensionItems lists loaded extensions for the [Extensions] startup
|
||||
// section (filenames only).
|
||||
extensionItems []ExtensionItem
|
||||
|
||||
// getExtensionItems returns the current extension items. Used to refresh
|
||||
// the list after extension hot-reload. May be nil.
|
||||
getExtensionItems func() []ExtensionItem
|
||||
|
||||
// mcpToolCount and extensionToolCount track tool counts by source for
|
||||
// the startup info display.
|
||||
mcpToolCount int
|
||||
@@ -860,6 +898,8 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
m.contextPaths = opts.ContextPaths
|
||||
m.skillItems = opts.SkillItems
|
||||
m.getSkillItems = opts.GetSkillItems
|
||||
m.extensionItems = opts.ExtensionItems
|
||||
m.getExtensionItems = opts.GetExtensionItems
|
||||
m.mcpToolCount = opts.MCPToolCount
|
||||
m.extensionToolCount = opts.ExtensionToolCount
|
||||
m.startupExtensionMessages = opts.StartupExtensionMessages
|
||||
@@ -912,6 +952,20 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
}
|
||||
}
|
||||
|
||||
// Merge skills into autocomplete as /skill:<name> commands. Skills accept
|
||||
// optional trailing args, so HasArgs is true — Enter populates the input
|
||||
// with "/skill:name " rather than auto-submitting.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.SkillItems) > 0 {
|
||||
for _, s := range opts.SkillItems {
|
||||
ic.commands = append(ic.commands, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Merge MCP prompts into autocomplete as /<server>:<prompt> commands.
|
||||
if ic, ok := m.input.(*InputComponent); ok && len(opts.MCPPrompts) > 0 {
|
||||
for _, p := range opts.MCPPrompts {
|
||||
@@ -1014,8 +1068,21 @@ func (m *AppModel) AddStartupMessageToScrollList() {
|
||||
pairs = append(pairs, [2]string{"Skills", strings.Join(names, ", ")})
|
||||
}
|
||||
|
||||
// Extension tool count (only shown when > 0).
|
||||
if m.extensionToolCount > 0 {
|
||||
// Extensions — listed by filename. Each extension shows its basename
|
||||
// without the .go suffix, matching the [Skills] section's style.
|
||||
if len(m.extensionItems) > 0 {
|
||||
names := make([]string, len(m.extensionItems))
|
||||
for i, ei := range m.extensionItems {
|
||||
names[i] = ei.Name
|
||||
}
|
||||
value := strings.Join(names, ", ")
|
||||
if m.extensionToolCount > 0 {
|
||||
value += fmt.Sprintf(" (%d tools)", m.extensionToolCount)
|
||||
}
|
||||
pairs = append(pairs, [2]string{"Extensions", value})
|
||||
} else if m.extensionToolCount > 0 {
|
||||
// Fallback: tool count only (extensions registered tools but the CLI
|
||||
// did not provide ExtensionItems for some reason).
|
||||
pairs = append(pairs, [2]string{"Extensions", fmt.Sprintf("%d tools", m.extensionToolCount)})
|
||||
}
|
||||
|
||||
@@ -1141,53 +1208,7 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.modelSelector = nil
|
||||
m.state = stateInput
|
||||
if m.setModel != nil {
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
fallback := models.SuggestThinkingLevelFallback(currentLevel, modelName)
|
||||
if fallback != models.ThinkingOff {
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Note: Model %s doesn't support '%s' thinking level. Adjusted to '%s'.",
|
||||
modelName, currentLevel, fallback,
|
||||
))
|
||||
m.thinkingLevel = string(fallback)
|
||||
if m.setThinkingLevel != nil {
|
||||
_ = m.setThinkingLevel(string(fallback))
|
||||
}
|
||||
go func() { _ = prefs.SaveThinkingLevelPreference(string(fallback)) }()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.setModel(msg.ModelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
} else {
|
||||
// Update display state directly — we cannot use
|
||||
// NotifyModelChanged (prog.Send) from inside Update()
|
||||
// without deadlocking BubbleTea.
|
||||
parts := strings.SplitN(msg.ModelString, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", msg.ModelString))
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(msg.ModelString) }()
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
newModel := msg.ModelString
|
||||
prev := previousModel
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
}
|
||||
m.switchModel(msg.ModelString)
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
|
||||
@@ -1251,7 +1272,11 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.scrollList.autoScroll = false
|
||||
case tea.MouseWheelDown:
|
||||
m.scrollList.ScrollBy(scrollLines)
|
||||
if m.scrollList.AtBottom() {
|
||||
// Only re-enable auto-scroll when the user is not actively
|
||||
// selecting text. Otherwise a wheel-down during a drag-select
|
||||
// would re-arm GotoBottom on the next stream chunk, shifting
|
||||
// the highlighted row out from under the cursor.
|
||||
if m.scrollList.AtBottom() && !m.scrollList.IsMouseDown() {
|
||||
m.scrollList.autoScroll = true
|
||||
}
|
||||
}
|
||||
@@ -1259,9 +1284,14 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// ── Mouse click selection (crush-style character-level) ──────────────────
|
||||
case tea.MouseClickMsg:
|
||||
if msg.Button == tea.MouseLeft {
|
||||
// Calculate viewport-relative coordinates.
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
// Compute the scrollback origin from the current frame's layout
|
||||
// rather than the stale cached value from the previous View().
|
||||
// scrollbackYOffset/scrollList.height are only refreshed inside
|
||||
// View() and lag behind any state change that resized the header
|
||||
// (extension widgets, warning rows, etc.) since the last render.
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
// Clear any previous selection on a new click.
|
||||
// HandleMouseDown will set up new selection state.
|
||||
if m.scrollList.HandleMouseDown(msg.X, viewY) {
|
||||
@@ -1272,8 +1302,9 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Mouse motion/drag for character-level selection ──────────────────────
|
||||
case tea.MouseMotionMsg:
|
||||
viewY := msg.Y - m.scrollbackYOffset
|
||||
if viewY >= 0 && viewY < m.scrollList.height {
|
||||
yOff, vpHeight := m.currentScrollbackBounds()
|
||||
viewY := msg.Y - yOff
|
||||
if viewY >= 0 && viewY < vpHeight {
|
||||
m.scrollList.HandleMouseDrag(msg.X, viewY)
|
||||
}
|
||||
|
||||
@@ -1603,10 +1634,16 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// ── Cancel timer expired ─────────────────────────────────────────────────
|
||||
case uicore.CancelTimerExpiredMsg:
|
||||
if m.canceling {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.canceling = false
|
||||
|
||||
// ── Ctrl+C reset timer expired ────────────────────────────────────────────
|
||||
case uicore.CtrlCResetMsg:
|
||||
if m.ctrlCPressedOnce {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
m.ctrlCPressedOnce = false
|
||||
|
||||
// ── Input submitted ──────────────────────────────────────────────────────
|
||||
@@ -1721,14 +1758,27 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// messages stay in chronological order.
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
// Insert inline thumbnail previews after the user message.
|
||||
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
|
||||
}
|
||||
} else {
|
||||
m.printUserMessage(displayText)
|
||||
// Insert inline thumbnail previews after the user message.
|
||||
cmds = append(cmds, m.transcriptPreviewCmd(msg.Images, m.lastMessageID()))
|
||||
}
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
|
||||
// ── Async transcript image preview ───────────────────────────────────────
|
||||
case imagePreviewReadyMsg:
|
||||
if msg.block != "" {
|
||||
item := NewStyledMessageItem(generateMessageID(), "user", "", msg.block)
|
||||
m.insertMessageAfter(msg.anchorID, item)
|
||||
m.refreshContent()
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
// ── Shell command (! / !!) ───────────────────────────────────────────────
|
||||
case uicore.ShellCommandMsg:
|
||||
// Show spinner while the shell command runs.
|
||||
@@ -2324,10 +2374,21 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
case editFileMsg:
|
||||
// User returned from $EDITOR after `/edit <path>`. The file was
|
||||
// edited directly on disk — no textarea changes. Report the result.
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Editor exited with error: %v", msg.err))
|
||||
} else {
|
||||
m.printSystemMessage(fmt.Sprintf("Edited `%s`", msg.path))
|
||||
}
|
||||
m.layoutDirty = true
|
||||
|
||||
case extReloadResultMsg:
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
} else {
|
||||
m.refreshExtensionItems()
|
||||
m.printSystemMessage("Extensions reloaded.")
|
||||
}
|
||||
|
||||
@@ -2373,6 +2434,19 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.printSystemMessage(msg.Text)
|
||||
}
|
||||
|
||||
// ── Clipboard image attached / thumbnail rendered ────────────────────────
|
||||
// Both messages change the input region's rendered height (the pill and
|
||||
// the async half-block preview), so forward them to the input and mark the
|
||||
// layout dirty — otherwise distributeHeight keeps a stale, too-short input
|
||||
// height and the preview is clipped off the bottom of the screen.
|
||||
case clipboardImageMsg, thumbnailReadyMsg:
|
||||
if m.input != nil {
|
||||
updated, cmd := m.input.Update(msg)
|
||||
m.input, _ = updated.(inputComponentIface)
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.layoutDirty = true
|
||||
|
||||
default:
|
||||
// Pass unrecognised messages to all children.
|
||||
if m.input != nil {
|
||||
@@ -2972,6 +3046,85 @@ func truncateMessageForBlock(msg string, maxLines, width int) string {
|
||||
// Print helpers — add content to ScrollList
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// imagePreviewReadyMsg carries an asynchronously rendered transcript image
|
||||
// preview block back to the Update loop, where it is inserted into the
|
||||
// ScrollList directly after the originating user message (identified by
|
||||
// anchorID). Inserting by anchor — rather than appending — keeps the preview
|
||||
// next to its message even when the agent's streamed reply has already been
|
||||
// appended while the thumbnail was being decoded off the event loop.
|
||||
type imagePreviewReadyMsg struct {
|
||||
block string
|
||||
anchorID string
|
||||
}
|
||||
|
||||
// transcriptPreviewCmd returns a tea.Cmd that renders half-block thumbnail
|
||||
// previews for the given clipboard images off the Bubble Tea event loop
|
||||
// (decode + resample must not block Update). The rendered block is delivered
|
||||
// via imagePreviewReadyMsg, tagged with anchorID so the consumer can place it
|
||||
// directly after the originating user message. Returns nil when there is
|
||||
// nothing to render or no room for a preview; an empty result (terminal lacks
|
||||
// color support) yields a nil message that Bubble Tea ignores.
|
||||
func (m *AppModel) transcriptPreviewCmd(images []uicore.ImageAttachment, anchorID string) tea.Cmd {
|
||||
if len(images) == 0 {
|
||||
return nil
|
||||
}
|
||||
cols := thumbMaxCols
|
||||
if m.width > 6 && m.width-6 < cols {
|
||||
cols = m.width - 6
|
||||
}
|
||||
if cols < 1 {
|
||||
return nil
|
||||
}
|
||||
bg := style.GetTheme().Background
|
||||
imgs := images
|
||||
return func() tea.Msg {
|
||||
pad := lipgloss.NewStyle().PaddingLeft(2)
|
||||
var blocks []string
|
||||
for _, img := range imgs {
|
||||
thumb, err := imagepreview.Render(img.Data, img.MediaType, cols, thumbMaxRows, bg)
|
||||
if err != nil || thumb == "" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, pad.Render(thumb))
|
||||
}
|
||||
if len(blocks) == 0 {
|
||||
return nil
|
||||
}
|
||||
return imagePreviewReadyMsg{block: strings.Join(blocks, "\n"), anchorID: anchorID}
|
||||
}
|
||||
}
|
||||
|
||||
// lastMessageID returns the ID of the most recently added ScrollList message,
|
||||
// or "" when there are none. Used to anchor an async transcript preview to the
|
||||
// user message that was just printed.
|
||||
func (m *AppModel) lastMessageID() string {
|
||||
if len(m.messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
return m.messages[len(m.messages)-1].ID()
|
||||
}
|
||||
|
||||
// insertMessageAfter inserts item immediately after the message whose ID
|
||||
// matches anchorID. If anchorID is empty or not found, item is appended.
|
||||
func (m *AppModel) insertMessageAfter(anchorID string, item MessageItem) {
|
||||
idx := -1
|
||||
if anchorID != "" {
|
||||
for i, msgItem := range m.messages {
|
||||
if msgItem.ID() == anchorID {
|
||||
idx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if idx < 0 {
|
||||
m.messages = append(m.messages, item)
|
||||
return
|
||||
}
|
||||
m.messages = append(m.messages, nil)
|
||||
copy(m.messages[idx+2:], m.messages[idx+1:])
|
||||
m.messages[idx+1] = item
|
||||
}
|
||||
|
||||
// printUserMessage renders a user message into the ScrollList.
|
||||
func (m *AppModel) printUserMessage(text string) {
|
||||
// Check if this exact message was just added (prevents duplicates)
|
||||
@@ -3095,6 +3248,12 @@ func (m *AppModel) handleSlashCommand(sc *commands.SlashCommand, args string) te
|
||||
return m.handleResumeCommand()
|
||||
case "/export":
|
||||
return m.handleExportCommand(args)
|
||||
case "/copy":
|
||||
return m.handleCopyCommand()
|
||||
case "/retry":
|
||||
return m.handleRetryCommand()
|
||||
case "/edit":
|
||||
return m.handleEditCommand(args)
|
||||
case "/share":
|
||||
return m.handleShareCommand()
|
||||
case "/import":
|
||||
@@ -3395,13 +3554,56 @@ func (m *AppModel) refreshPromptTemplates() {
|
||||
}
|
||||
}
|
||||
|
||||
// refreshSkillItems reloads skill items from the provider callback.
|
||||
// Called on ContentReloadEvent.
|
||||
// refreshSkillItems reloads skill items from the provider callback and
|
||||
// updates the autocomplete entries. Called on ContentReloadEvent.
|
||||
func (m *AppModel) refreshSkillItems() {
|
||||
if m.getSkillItems == nil {
|
||||
return
|
||||
}
|
||||
m.skillItems = m.getSkillItems()
|
||||
newItems := m.getSkillItems()
|
||||
m.skillItems = newItems
|
||||
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
// Remove old Skills commands and add fresh ones.
|
||||
var kept []commands.SlashCommand
|
||||
for _, sc := range ic.commands {
|
||||
if sc.Category != "Skills" {
|
||||
kept = append(kept, sc)
|
||||
}
|
||||
}
|
||||
for _, s := range newItems {
|
||||
kept = append(kept, commands.SlashCommand{
|
||||
Name: "/skill:" + s.Name,
|
||||
Description: formatSkillDescription(s),
|
||||
Category: "Skills",
|
||||
HasArgs: true,
|
||||
})
|
||||
}
|
||||
ic.commands = kept
|
||||
}
|
||||
}
|
||||
|
||||
// refreshExtensionItems reloads extension items from the provider callback
|
||||
// so the [Extensions] startup section reflects the current set after a
|
||||
// hot-reload. Called from the extReloadResultMsg handler.
|
||||
func (m *AppModel) refreshExtensionItems() {
|
||||
if m.getExtensionItems == nil {
|
||||
return
|
||||
}
|
||||
m.extensionItems = m.getExtensionItems()
|
||||
}
|
||||
|
||||
// formatSkillDescription returns the autocomplete description for a skill,
|
||||
// prefixed with [project] or [user] so users can tell colliding names apart.
|
||||
func formatSkillDescription(s SkillItem) string {
|
||||
prefix := "[user]"
|
||||
if s.Source == "project" {
|
||||
prefix = "[project]"
|
||||
}
|
||||
if s.Description == "" {
|
||||
return prefix
|
||||
}
|
||||
return prefix + " " + s.Description
|
||||
}
|
||||
|
||||
// refreshMCPPrompts reloads MCP prompts from the provider callback and
|
||||
@@ -3476,6 +3678,9 @@ func (m *AppModel) printHelpMessage() {
|
||||
"**System:**\n" +
|
||||
"- `/compact [instructions]`: Summarise older messages to free context space\n" +
|
||||
"- `/clear`: Clear message history\n" +
|
||||
"- `/copy`: Copy the last message to the system clipboard\n" +
|
||||
"- `/retry`: Resubmit the last user message (e.g. after a provider error)\n" +
|
||||
"- `/edit [path]`: Open a file in `$EDITOR` (fuzzy-find from cwd)\n" +
|
||||
"- `/export [path]`: Export session as JSONL\n" +
|
||||
"- `/import <path.jsonl>`: Import session from JSONL file\n" +
|
||||
"- `/reset-usage`: Reset usage statistics\n" +
|
||||
@@ -3712,7 +3917,12 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
}
|
||||
// Auto-scroll to bottom if enabled (iteratr pattern)
|
||||
// Don't call SetItems() - the slice reference hasn't changed
|
||||
if m.scrollList != nil {
|
||||
//
|
||||
// CRITICAL: never scroll the viewport while the user is actively
|
||||
// selecting text (mouse button held). Doing so shifts the
|
||||
// highlighted content out from under the cursor and produces the
|
||||
// off-by-N-row drift users see when copy-selecting during streaming.
|
||||
if m.scrollList != nil && !m.scrollList.IsMouseDown() {
|
||||
if m.scrollList.autoScroll {
|
||||
m.scrollList.GotoBottom()
|
||||
} else if m.scrollList.AtBottom() {
|
||||
@@ -3740,6 +3950,36 @@ func (m *AppModel) appendStreamingChunk(role, content string) {
|
||||
m.refreshContent()
|
||||
}
|
||||
|
||||
// currentScrollbackBounds returns the live (yOffset, viewportHeight) for the
|
||||
// scrollback region, computed from the current state — not from the cached
|
||||
// values populated inside View().
|
||||
//
|
||||
// scrollbackYOffset and scrollList.height are refreshed once per render, so
|
||||
// any state change that resizes the header (extension widget toggles,
|
||||
// warning rows, queued messages, etc.) leaves the cached values one frame
|
||||
// stale. Mouse click handlers in Update() can then place the cursor on the
|
||||
// wrong line, producing the off-by-N-row drift seen during copy-selection.
|
||||
//
|
||||
// This recomputes the header height by rendering it (cheap — the renderer
|
||||
// returns "" when no extension header is set) and recomputes the viewport
|
||||
// height the same way distributeHeight() does, so both inputs to the
|
||||
// y → (item, line) mapping are always current.
|
||||
func (m *AppModel) currentScrollbackBounds() (yOffset, viewportHeight int) {
|
||||
// Force a fresh layout if anything in Update() marked the state dirty;
|
||||
// otherwise scrollList.height still reflects the previous frame.
|
||||
if m.layoutDirty {
|
||||
m.distributeHeight()
|
||||
m.layoutDirty = false
|
||||
}
|
||||
if headerView := m.renderHeaderFooter(m.getHeader); headerView != "" {
|
||||
yOffset = lipgloss.Height(headerView)
|
||||
}
|
||||
if m.scrollList != nil {
|
||||
viewportHeight = m.scrollList.height
|
||||
}
|
||||
return yOffset, viewportHeight
|
||||
}
|
||||
|
||||
// distributeHeight recalculates child component heights after a window resize,
|
||||
// queue change, widget update, or state transition, and propagates the computed
|
||||
// stream height to the StreamComponent.
|
||||
@@ -3812,7 +4052,20 @@ func (m *AppModel) distributeHeight() {
|
||||
headerFooterLines += lipgloss.Height(footerView)
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines, 0)
|
||||
// Account for transient warning rows that View() injects between the
|
||||
// scrollback and the separator. These flags are toggled by ESC/Ctrl+C
|
||||
// handlers; without subtracting them here the joined view exceeds
|
||||
// m.height by one line per active warning and the bottom of the screen
|
||||
// gets silently clipped — which in turn invalidates scrollbackYOffset.
|
||||
var warningLines int
|
||||
if m.canceling {
|
||||
warningLines++
|
||||
}
|
||||
if m.ctrlCPressedOnce {
|
||||
warningLines++
|
||||
}
|
||||
|
||||
streamHeight := max(m.height-separatorLines-widgetLines-headerFooterLines-queuedLines-inputLines-statusBarLines-warningLines, 0)
|
||||
|
||||
// In alt screen mode, give the calculated height to ScrollList instead of stream.
|
||||
// The stream component still exists but is embedded as the last item in scrollList.
|
||||
@@ -3912,11 +4165,31 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
m.switchModel(args)
|
||||
return nil
|
||||
}
|
||||
|
||||
// switchModel performs a direct model switch, shared by the model selector
|
||||
// overlay and the /model slash command: it adjusts the thinking level when
|
||||
// the new model doesn't support the current one, calls the setModel
|
||||
// callback, updates display state, persists preferences, and emits the
|
||||
// ModelChange extension event.
|
||||
//
|
||||
// Display state is updated directly — we cannot use NotifyModelChanged
|
||||
// (prog.Send) from inside Update() without deadlocking BubbleTea.
|
||||
func (m *AppModel) switchModel(modelString string) {
|
||||
if m.setModel == nil {
|
||||
m.printSystemMessage("Model switching is not available.")
|
||||
return
|
||||
}
|
||||
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
|
||||
// Check if thinking level needs adjustment for the new model.
|
||||
// Some models (e.g., OpenAI gpt-5.4) don't support "minimal" and require "none".
|
||||
if m.thinkingLevel != "" && m.thinkingLevel != "off" {
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
modelName := parts[1]
|
||||
currentLevel := models.ParseThinkingLevel(m.thinkingLevel)
|
||||
if !models.IsValidThinkingLevelForModel(currentLevel, modelName) {
|
||||
@@ -3936,32 +4209,26 @@ func (m *AppModel) handleModelCommand(args string) tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// Direct model switch with the provided model string.
|
||||
previousModel := m.providerName + "/" + m.modelName
|
||||
if err := m.setModel(args); err != nil {
|
||||
if err := m.setModel(modelString); err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to switch model: %v", err))
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
// Update display state directly (cannot use prog.Send from Update).
|
||||
parts := strings.SplitN(args, "/", 2)
|
||||
if len(parts) == 2 {
|
||||
if parts := strings.SplitN(modelString, "/", 2); len(parts) == 2 {
|
||||
m.providerName = parts[0]
|
||||
m.modelName = parts[1]
|
||||
}
|
||||
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
prev := previousModel
|
||||
newModel := args
|
||||
go emit(newModel, prev, "user")
|
||||
}
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", modelString))
|
||||
|
||||
// Persist model selection for next launch.
|
||||
go func() { _ = prefs.SaveModelPreference(args) }()
|
||||
go func() { _ = prefs.SaveModelPreference(modelString) }()
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf("Switched to %s", args))
|
||||
return nil
|
||||
if m.emitModelChange != nil {
|
||||
emit := m.emitModelChange
|
||||
go emit(modelString, previousModel, "user")
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -4236,6 +4503,183 @@ func (m *AppModel) handleNameCommand(args string) tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleCopyCommand copies the last user or assistant message to the system
|
||||
// clipboard. Skips transient system messages (e.g. /help output) so the user
|
||||
// gets the actual last conversational message.
|
||||
func (m *AppModel) handleCopyCommand() tea.Cmd {
|
||||
if len(m.messages) == 0 {
|
||||
m.printSystemMessage("No messages to copy.")
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
text string
|
||||
role string
|
||||
)
|
||||
for i := len(m.messages) - 1; i >= 0; i-- {
|
||||
switch msg := m.messages[i].(type) {
|
||||
case *TextMessageItem:
|
||||
if msg.role == "user" || msg.role == "assistant" {
|
||||
text = msg.content
|
||||
role = msg.role
|
||||
}
|
||||
case *StreamingMessageItem:
|
||||
if msg.role == "assistant" || msg.role == "reasoning" {
|
||||
text = msg.content.String()
|
||||
role = msg.role
|
||||
}
|
||||
}
|
||||
if text != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(text) == "" {
|
||||
m.printSystemMessage("No copyable message found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.printSystemMessage(fmt.Sprintf(
|
||||
"Copied last %s message to clipboard (%d chars).", role, len(text),
|
||||
))
|
||||
return clipboard.CopyToClipboard(text)
|
||||
}
|
||||
|
||||
// handleRetryCommand resubmits the most recent user message on the current
|
||||
// branch. Used to recover from transient provider errors (overloaded,
|
||||
// timeout) without users having to retype — and without the duplicate-user-
|
||||
// message bloat that retyping creates.
|
||||
//
|
||||
// Flow:
|
||||
// 1. App.PopLastUserMessage() truncates the tree at the parent of the last
|
||||
// user message and returns its text + any image parts. The failed turn's
|
||||
// entries become orphaned (still on disk, off-branch) so they will not
|
||||
// be re-sent to the LLM.
|
||||
// 2. The visible message list is rebuilt from the truncated branch so the
|
||||
// prior user message + any partial assistant + error rendering vanish.
|
||||
// 3. The prompt is resubmitted via Run/RunWithFiles, mirroring the normal
|
||||
// SubmitMsg display path (badge formatting, pending-prints flush,
|
||||
// stateWorking transition).
|
||||
func (m *AppModel) handleRetryCommand() tea.Cmd {
|
||||
if m.appCtrl == nil {
|
||||
m.printSystemMessage("App controller unavailable.")
|
||||
return nil
|
||||
}
|
||||
|
||||
prompt, files, err := m.appCtrl.PopLastUserMessage()
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Cannot retry: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rebuild the visible ScrollList from the truncated branch so the failed
|
||||
// turn's user message and any partial assistant/error rendering disappear
|
||||
// before the resubmit prints a fresh user message.
|
||||
m.messages = []MessageItem{}
|
||||
m.renderSessionHistory()
|
||||
|
||||
// Mirror SubmitMsg's badge formatting for the display text.
|
||||
var imageCount, fileOnlyCount int
|
||||
for _, f := range files {
|
||||
if strings.HasPrefix(f.MediaType, "image/") {
|
||||
imageCount++
|
||||
} else {
|
||||
fileOnlyCount++
|
||||
}
|
||||
}
|
||||
displayText := prompt
|
||||
if imageCount > 0 || fileOnlyCount > 0 {
|
||||
var badges []string
|
||||
if imageCount > 0 {
|
||||
badges = append(badges, fmt.Sprintf("%d image(s) pasted", imageCount))
|
||||
}
|
||||
if fileOnlyCount > 0 {
|
||||
badges = append(badges, fmt.Sprintf("%d file(s) attached", fileOnlyCount))
|
||||
}
|
||||
displayText = fmt.Sprintf("%s\n[%s]", prompt, strings.Join(badges, ", "))
|
||||
}
|
||||
|
||||
var qLen int
|
||||
if len(files) > 0 {
|
||||
qLen = m.appCtrl.RunWithFiles(prompt, files)
|
||||
} else {
|
||||
qLen = m.appCtrl.Run(prompt)
|
||||
}
|
||||
if qLen > 0 {
|
||||
m.queuedMessages = append(m.queuedMessages, displayText)
|
||||
m.layoutDirty = true
|
||||
} else {
|
||||
m.pendingUserPrints = append(m.pendingUserPrints, displayText)
|
||||
m.flushStreamAndPendingUserMessages()
|
||||
}
|
||||
if m.state != stateWorking {
|
||||
m.state = stateWorking
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleEditCommand opens the supplied path in $EDITOR via tea.ExecProcess,
|
||||
// pausing the TUI for the duration of the editor session. The path is
|
||||
// resolved relative to cwd; ~/ and absolute paths are honoured. Non-existent
|
||||
// paths are allowed — most editors will create the file on save.
|
||||
//
|
||||
// On exit an editFileMsg is emitted with the resolved path (or error) so the
|
||||
// Update loop can report the result. The textarea is not touched — use
|
||||
// Ctrl+X e if you want to round-trip a prompt through $EDITOR instead.
|
||||
func (m *AppModel) handleEditCommand(args string) tea.Cmd {
|
||||
path := strings.TrimSpace(args)
|
||||
if path == "" {
|
||||
m.printSystemMessage("Usage: `/edit <path>` — or type `/edit ` and pick a file from the popup.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Strip optional surrounding double-quotes (the autocomplete inserts
|
||||
// these when a path contains spaces).
|
||||
if len(path) >= 2 && strings.HasPrefix(path, `"`) && strings.HasSuffix(path, `"`) {
|
||||
path = path[1 : len(path)-1]
|
||||
}
|
||||
|
||||
// Resolve ~/, relative, and absolute paths against cwd.
|
||||
resolved := path
|
||||
if strings.HasPrefix(resolved, "~/") {
|
||||
if home, err := os.UserHomeDir(); err == nil {
|
||||
resolved = filepath.Join(home, resolved[2:])
|
||||
}
|
||||
}
|
||||
if !filepath.IsAbs(resolved) {
|
||||
cwd, err := os.Getwd()
|
||||
if err == nil {
|
||||
resolved = filepath.Join(cwd, resolved)
|
||||
}
|
||||
}
|
||||
resolved = filepath.Clean(resolved)
|
||||
|
||||
// Reject paths that exist but are directories — $EDITOR semantics vary.
|
||||
if info, err := os.Stat(resolved); err == nil && info.IsDir() {
|
||||
m.printSystemMessage(fmt.Sprintf("`%s` is a directory, not a file.", resolved))
|
||||
return nil
|
||||
}
|
||||
|
||||
editorApp := os.Getenv("VISUAL")
|
||||
if editorApp == "" {
|
||||
editorApp = os.Getenv("EDITOR")
|
||||
}
|
||||
if editorApp == "" {
|
||||
m.printSystemMessage("Set `$EDITOR` or `$VISUAL` to use `/edit`")
|
||||
return nil
|
||||
}
|
||||
|
||||
editorCmd, cmdErr := editor.Command(editorApp, resolved)
|
||||
if cmdErr != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to open editor: %v", cmdErr))
|
||||
return nil
|
||||
}
|
||||
|
||||
return tea.ExecProcess(editorCmd, func(err error) tea.Msg {
|
||||
return editFileMsg{path: resolved, err: err}
|
||||
})
|
||||
}
|
||||
|
||||
// handleExportCommand exports the current session to a file.
|
||||
// Usage: /export — copies the JSONL file to cwd with a descriptive name.
|
||||
//
|
||||
@@ -4351,61 +4795,11 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
return r
|
||||
}, name)
|
||||
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
tmpPath, err := buildShareFile(name, data, sysPromptJSON)
|
||||
if err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to create temp file: %v", err))
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to share session: %v", err))
|
||||
return nil
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the header.
|
||||
// The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
|
||||
if len(lines) > 0 {
|
||||
// Write header (first line)
|
||||
if _, err := tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write system prompt entry
|
||||
if _, err := tmpFile.Write(sysPromptJSON); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write system prompt: %v", err))
|
||||
return nil
|
||||
}
|
||||
if _, err := tmpFile.WriteString("\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write remaining lines
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err := tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to write temp file: %v", err))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ = tmpFile.Close()
|
||||
|
||||
m.printSystemMessage("Uploading session to GitHub Gist...")
|
||||
|
||||
@@ -4431,6 +4825,56 @@ func (m *AppModel) handleShareCommand() tea.Cmd {
|
||||
}
|
||||
}
|
||||
|
||||
// buildShareFile assembles a temp JSONL file containing the session data
|
||||
// with the system-prompt entry inserted after the header line. On success
|
||||
// the caller owns the returned file and must remove it when done; on error
|
||||
// any partially-written temp file has already been cleaned up.
|
||||
func buildShareFile(name string, data, sysPromptJSON []byte) (tmpPath string, err error) {
|
||||
tmpFile, err := os.CreateTemp("", fmt.Sprintf("kit-%s-*.jsonl", name))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath = tmpFile.Name()
|
||||
defer func() {
|
||||
_ = tmpFile.Close()
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
// Write the session data with the system prompt entry inserted after the
|
||||
// header. The header is the first line, so we write:
|
||||
// 1. First line (header) from original data
|
||||
// 2. System prompt entry
|
||||
// 3. Remaining lines from original data
|
||||
lines := strings.Split(string(data), "\n")
|
||||
if len(lines) > 0 && lines[len(lines)-1] == "" {
|
||||
lines = lines[:len(lines)-1] // Remove trailing empty line
|
||||
}
|
||||
if len(lines) == 0 {
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
if _, err = tmpFile.WriteString(lines[0] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.Write(sysPromptJSON); err != nil {
|
||||
return "", fmt.Errorf("write system prompt: %w", err)
|
||||
}
|
||||
if _, err = tmpFile.WriteString("\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
for i := 1; i < len(lines); i++ {
|
||||
if lines[i] == "" {
|
||||
continue // Skip empty lines
|
||||
}
|
||||
if _, err = tmpFile.WriteString(lines[i] + "\n"); err != nil {
|
||||
return "", fmt.Errorf("write temp file: %w", err)
|
||||
}
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// handleImportCommand imports a session from a JSONL file.
|
||||
// Usage: /import path.jsonl
|
||||
func (m *AppModel) handleImportCommand(args string) tea.Cmd {
|
||||
@@ -4646,6 +5090,14 @@ type externalEditorMsg struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// editFileMsg is sent when the user returns from $EDITOR after invoking the
|
||||
// /edit slash command on a specific file. Unlike externalEditorMsg, no text
|
||||
// is read back — the user edited the file directly on disk.
|
||||
type editFileMsg struct {
|
||||
path string
|
||||
err error
|
||||
}
|
||||
|
||||
// shareResultMsg carries the result of an async gist upload.
|
||||
type shareResultMsg struct {
|
||||
err error
|
||||
|
||||
@@ -2,6 +2,7 @@ package ui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -87,6 +88,10 @@ func (s *stubAppController) SteerWithFiles(prompt string, _ []kit.LLMFilePart) i
|
||||
return s.queueLen
|
||||
}
|
||||
|
||||
func (s *stubAppController) PopLastUserMessage() (string, []kit.LLMFilePart, error) {
|
||||
return "", nil, fmt.Errorf("no user message to retry")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Stub child components
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
uicore "github.com/mark3labs/kit/internal/ui/core"
|
||||
)
|
||||
|
||||
// drainCmds runs a tea.Cmd chain back through m.Update like the BubbleTea
|
||||
// event loop, expanding batches, until no further messages are produced.
|
||||
func drainCmds(t *testing.T, m *AppModel, cmd tea.Cmd) *AppModel {
|
||||
t.Helper()
|
||||
queue := []tea.Cmd{cmd}
|
||||
for i := 0; i < 50 && len(queue) > 0; i++ {
|
||||
c := queue[0]
|
||||
queue = queue[1:]
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
msg := c()
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if batch, ok := msg.(tea.BatchMsg); ok {
|
||||
queue = append(queue, batch...)
|
||||
continue
|
||||
}
|
||||
updated, nc := m.Update(msg)
|
||||
m = updated.(*AppModel)
|
||||
_ = m.View()
|
||||
if nc != nil {
|
||||
queue = append(queue, nc)
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func measuredInputHeight(m *AppModel) int {
|
||||
rendered := m.renderInput()
|
||||
if rendered == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(rendered, "\n") + 1
|
||||
}
|
||||
|
||||
// TestPendingThumbnailTriggersLayoutRecompute is a regression test for the bug
|
||||
// where a pasted image's async half-block preview rendered but was clipped off
|
||||
// the bottom of the screen: the thumbnail arrives via thumbnailReadyMsg after
|
||||
// distributeHeight already measured the input region without it. The parent
|
||||
// must mark the layout dirty so the (now taller) input is re-measured.
|
||||
func TestPendingThumbnailTriggersLayoutRecompute(t *testing.T) {
|
||||
// Force a truecolor profile so imagepreview.Render deterministically
|
||||
// produces a thumbnail regardless of the CI terminal's color support.
|
||||
// Without this, a low-color test environment yields an empty preview and
|
||||
// the glyph / height assertions below would flake.
|
||||
t.Setenv("TERM", "xterm-256color")
|
||||
t.Setenv("COLORTERM", "truecolor")
|
||||
t.Setenv("NO_COLOR", "")
|
||||
|
||||
real := NewInputComponent(80, nil)
|
||||
m, _, _ := newTestAppModel(nil)
|
||||
m.input = real
|
||||
m = sendMsg(m, tea.WindowSizeMsg{Width: 80, Height: 24})
|
||||
|
||||
heightBefore := measuredInputHeight(m)
|
||||
|
||||
updated, cmd := m.Update(clipboardImageMsg{image: &uicore.ImageAttachment{
|
||||
Data: makeTestPNG(t, 16, 16),
|
||||
MediaType: "image/png",
|
||||
}})
|
||||
m = updated.(*AppModel)
|
||||
_ = m.View()
|
||||
m = drainCmds(t, m, cmd)
|
||||
|
||||
heightAfter := measuredInputHeight(m)
|
||||
if heightAfter <= heightBefore {
|
||||
t.Errorf("input region should grow to fit the thumbnail (before=%d after=%d)", heightBefore, heightAfter)
|
||||
}
|
||||
|
||||
if !strings.Contains(m.View().Content, "▀") {
|
||||
t.Error("parent View should contain the half-block thumbnail (was clipped or not rendered)")
|
||||
}
|
||||
}
|
||||
+182
-46
@@ -20,17 +20,23 @@ type PopupItem struct {
|
||||
Meta any // opaque data returned on selection
|
||||
}
|
||||
|
||||
// PopupList is a generic, themed, scrollable fuzzy-find popup list. It is
|
||||
// rendered as a centered overlay on top of the normal TUI layout and can be
|
||||
// reused by any feature that needs a selection popup (slash commands, model
|
||||
// selector, session picker, extension-provided lists, etc.).
|
||||
// PopupList is a generic, themed, scrollable popup list used by every
|
||||
// list-style popup in the TUI (slash commands, @file autocomplete, model
|
||||
// picker, session picker, tree navigation, etc.).
|
||||
//
|
||||
// The caller is responsible for:
|
||||
// - Building the initial item list
|
||||
// - Providing a fuzzy-filter callback (or nil for substring matching)
|
||||
// - Handling the result when the user selects or cancels
|
||||
// Two layout modes:
|
||||
// - Centered (default): bordered ~80-col box centered on the screen. Used
|
||||
// for the input-bar popups (/ and @) and the model picker.
|
||||
// - FullScreen: bordered panel filling almost the entire terminal. Used by
|
||||
// /tree, /fork, /sessions and other browse-many-items popups.
|
||||
//
|
||||
// Navigation: up/down to move, enter to select, esc to cancel, type to filter.
|
||||
// Two usage modes:
|
||||
// - Internal state: caller creates the list with items, calls HandleKey for
|
||||
// navigation/search, and PopupList owns the cursor and search string.
|
||||
// Used by selectors like ModelSelector, TreeSelector, SessionSelector.
|
||||
// - External state: caller drives the items / cursor / search themselves
|
||||
// (e.g. InputComponent, where typing in the textarea filters the list).
|
||||
// Caller uses SetItems / SetCursor / SetSearch and only calls Render.
|
||||
type PopupList struct {
|
||||
// Title shown at the top of the popup.
|
||||
Title string
|
||||
@@ -38,20 +44,45 @@ type PopupList struct {
|
||||
Subtitle string
|
||||
// FooterHint overrides the default keyboard-hint footer.
|
||||
FooterHint string
|
||||
// ExtraFooter is appended to the footer line (after the default hint).
|
||||
// Used by selectors to surface mode info like the active filter.
|
||||
ExtraFooter string
|
||||
|
||||
allItems []PopupItem // full unfiltered list
|
||||
filtered []PopupItem // subset matching the current search
|
||||
cursor int
|
||||
search string
|
||||
// FullScreen renders the popup at almost the full terminal size instead
|
||||
// of a centered ~80-col box. Used by tree/session/fork selectors.
|
||||
FullScreen bool
|
||||
|
||||
// ShowSearch toggles the "> <query>" search input line. Default true.
|
||||
ShowSearch bool
|
||||
|
||||
// HideCount suppresses the "(i/N)" count in the footer.
|
||||
HideCount bool
|
||||
|
||||
// MaxVisible caps the number of items visible at once. 0 = derive from
|
||||
// available height.
|
||||
MaxVisible int
|
||||
|
||||
// RenderItem optionally renders a single item row. When nil, the
|
||||
// built-in label + description + active-checkmark renderer is used.
|
||||
// innerWidth is the usable line width inside the popup (after border
|
||||
// and padding). The returned string must already be styled — the
|
||||
// shared selection-row background is applied by the popup only when
|
||||
// RenderItem is nil.
|
||||
RenderItem func(item PopupItem, innerWidth int, isCursor bool) string
|
||||
|
||||
// FilterFunc is called with (query, allItems) and should return the
|
||||
// filtered+scored subset. When nil, a default substring match is used.
|
||||
// filtered+scored subset. When nil, a default substring + fuzzy match
|
||||
// is used. Only consulted in internal-state mode (via HandleKey).
|
||||
FilterFunc func(query string, items []PopupItem) []PopupItem
|
||||
|
||||
width int
|
||||
height int
|
||||
maxVisible int // max items visible at once (0 = auto from height)
|
||||
showSearch bool
|
||||
allItems []PopupItem // full unfiltered list (internal-state mode)
|
||||
filtered []PopupItem // items currently rendered (driven by FilterFunc
|
||||
// in internal-state mode, or set directly via SetItems in external mode)
|
||||
cursor int
|
||||
search string
|
||||
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
// PopupResult is returned by HandleKey to tell the caller what happened.
|
||||
@@ -72,7 +103,7 @@ func NewPopupList(title string, items []PopupItem, width, height int) *PopupList
|
||||
filtered: items,
|
||||
width: width,
|
||||
height: height,
|
||||
showSearch: true,
|
||||
ShowSearch: true,
|
||||
}
|
||||
// Position cursor on the active item if one exists.
|
||||
for i, item := range p.filtered {
|
||||
@@ -90,25 +121,102 @@ func (p *PopupList) SetSize(width, height int) {
|
||||
p.height = height
|
||||
}
|
||||
|
||||
// SetItems replaces the displayed item list and clamps the cursor. Used by
|
||||
// external-state callers (e.g. InputComponent) that filter items themselves.
|
||||
// In internal-state mode, this also replaces the unfiltered backing list.
|
||||
func (p *PopupList) SetItems(items []PopupItem) {
|
||||
p.allItems = items
|
||||
p.filtered = items
|
||||
if p.cursor >= len(p.filtered) {
|
||||
p.cursor = max(len(p.filtered)-1, 0)
|
||||
}
|
||||
if p.cursor < 0 {
|
||||
p.cursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
// SetCursor moves the selection to the given index (clamped to range).
|
||||
func (p *PopupList) SetCursor(i int) {
|
||||
if len(p.filtered) == 0 {
|
||||
p.cursor = 0
|
||||
return
|
||||
}
|
||||
if i < 0 {
|
||||
i = 0
|
||||
}
|
||||
if i >= len(p.filtered) {
|
||||
i = len(p.filtered) - 1
|
||||
}
|
||||
p.cursor = i
|
||||
}
|
||||
|
||||
// Cursor returns the current selection index.
|
||||
func (p *PopupList) Cursor() int { return p.cursor }
|
||||
|
||||
// SetSearch replaces the search string without rebuilding the filtered list.
|
||||
// Used by external-state callers that filter items themselves.
|
||||
func (p *PopupList) SetSearch(s string) { p.search = s }
|
||||
|
||||
// Items returns the currently-visible (filtered) items.
|
||||
func (p *PopupList) Items() []PopupItem { return p.filtered }
|
||||
|
||||
// Search returns the current search string.
|
||||
func (p *PopupList) Search() string { return p.search }
|
||||
|
||||
// dimensions returns the (popupWidth, popupHeight, innerWidth, innerHeight)
|
||||
// the popup will render at, given its current size and FullScreen flag.
|
||||
func (p *PopupList) dimensions() (popupW, popupH, innerW, innerH int) {
|
||||
if p.FullScreen {
|
||||
// Leave a small margin so the border doesn't kiss the screen edge.
|
||||
popupW = max(p.width-2, 20)
|
||||
popupH = max(p.height-2, 10)
|
||||
} else {
|
||||
// Centered: cap at 80 cols, leave a 4-col margin.
|
||||
popupW = max(min(p.width-4, 80), 20)
|
||||
// Height is dynamic — let it grow with content within the screen.
|
||||
popupH = 0
|
||||
}
|
||||
// Border (2) + horizontal padding (4) = 6 chrome cols.
|
||||
innerW = max(popupW-6, 10)
|
||||
if popupH > 0 {
|
||||
// Border (2) + vertical padding (2) = 4 chrome rows.
|
||||
innerH = max(popupH-4, 6)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// visibleCount returns the number of items visible at once.
|
||||
func (p *PopupList) visibleCount() int {
|
||||
if p.maxVisible > 0 {
|
||||
return p.maxVisible
|
||||
if p.MaxVisible > 0 {
|
||||
return p.MaxVisible
|
||||
}
|
||||
// Reserve: title(1) + subtitle(1) + search(1) + separator(1) + footer(2) + border(2) + padding(2) = 10
|
||||
if p.FullScreen {
|
||||
_, _, _, innerH := p.dimensions()
|
||||
// Reserve: title(1) + subtitle(0|1) + search(0|2) + sep(1) + footer(2)
|
||||
overhead := 4
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.ShowSearch {
|
||||
overhead += 2
|
||||
}
|
||||
return max(innerH-overhead, 3)
|
||||
}
|
||||
// Centered: derive from terminal height (legacy behaviour).
|
||||
overhead := 8
|
||||
if p.Subtitle != "" {
|
||||
overhead++
|
||||
}
|
||||
if p.showSearch {
|
||||
overhead += 2 // search line + separator
|
||||
if p.ShowSearch {
|
||||
overhead += 2
|
||||
}
|
||||
return max(p.height/2-overhead, 3)
|
||||
}
|
||||
|
||||
// HandleKey processes a single key event and returns the result. The caller
|
||||
// should inspect PopupResult to decide whether to re-render, close the popup,
|
||||
// or act on a selection.
|
||||
// or act on a selection. Internal-state mode only — external-state callers
|
||||
// drive cursor/search themselves and never call this.
|
||||
//
|
||||
// keyName is the Bubble Tea key string (e.g. "up", "down", "enter", "esc").
|
||||
// keyText is the printable text for character keys (e.g. "a", "1").
|
||||
@@ -191,7 +299,7 @@ func (p *PopupList) HandleKey(keyName, keyText string) PopupResult {
|
||||
// as a centered overlay via lipgloss.Place + overlayContent.
|
||||
func (p *PopupList) Render() string {
|
||||
theme := style.GetTheme()
|
||||
popupWidth := max(min(p.width-4, 80), 20)
|
||||
popupW, popupH, innerW, _ := p.dimensions()
|
||||
popupBg := theme.Background
|
||||
|
||||
popupStyle := lipgloss.NewStyle().
|
||||
@@ -199,11 +307,12 @@ func (p *PopupList) Render() string {
|
||||
BorderForeground(theme.Primary).
|
||||
Background(popupBg).
|
||||
Padding(1, 2).
|
||||
Width(popupWidth).
|
||||
MarginBottom(1)
|
||||
|
||||
// Inner content width: popup minus border (2) and horizontal padding (4).
|
||||
innerWidth := max(popupWidth-6, 10)
|
||||
Width(popupW)
|
||||
if popupH > 0 {
|
||||
popupStyle = popupStyle.Height(popupH)
|
||||
} else {
|
||||
popupStyle = popupStyle.MarginBottom(1)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
@@ -212,7 +321,7 @@ func (p *PopupList) Render() string {
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
b.WriteString(titleStyle.Render(p.Title))
|
||||
b.WriteString("\n")
|
||||
|
||||
@@ -221,17 +330,17 @@ func (p *PopupList) Render() string {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
b.WriteString(subtitleStyle.Render(p.Subtitle))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Search input.
|
||||
if p.showSearch {
|
||||
if p.ShowSearch {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(popupBg).
|
||||
Width(innerWidth)
|
||||
Width(innerW)
|
||||
if p.search != "" {
|
||||
b.WriteString(searchStyle.Render(fmt.Sprintf("> %s", p.search)))
|
||||
} else {
|
||||
@@ -243,7 +352,7 @@ func (p *PopupList) Render() string {
|
||||
sepStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg)
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerWidth)))
|
||||
b.WriteString(sepStyle.Render(strings.Repeat("─", innerW)))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
@@ -251,20 +360,20 @@ func (p *PopupList) Render() string {
|
||||
normalItemBg := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.Text).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
|
||||
selectedItemBg := lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1).
|
||||
Bold(true)
|
||||
|
||||
scrollStyle := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
Foreground(theme.VeryMuted).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
|
||||
vis := p.visibleCount()
|
||||
@@ -274,7 +383,7 @@ func (p *PopupList) Render() string {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(popupBg).
|
||||
Width(innerWidth).
|
||||
Width(innerW).
|
||||
Padding(0, 1)
|
||||
if p.search != "" {
|
||||
items = append(items, emptyStyle.Render("No matches for \""+p.search+"\""))
|
||||
@@ -282,9 +391,14 @@ func (p *PopupList) Render() string {
|
||||
items = append(items, emptyStyle.Render("No items"))
|
||||
}
|
||||
} else {
|
||||
// Center the cursor in the visible window so the user always sees
|
||||
// context above and below. Clamp to bounds.
|
||||
startIdx := 0
|
||||
if p.cursor >= vis {
|
||||
startIdx = p.cursor - vis + 1
|
||||
if len(p.filtered) > vis {
|
||||
startIdx = max(p.cursor-vis/2, 0)
|
||||
if startIdx+vis > len(p.filtered) {
|
||||
startIdx = len(p.filtered) - vis
|
||||
}
|
||||
}
|
||||
endIdx := min(startIdx+vis, len(p.filtered))
|
||||
|
||||
@@ -292,10 +406,27 @@ func (p *PopupList) Render() string {
|
||||
items = append(items, scrollStyle.Render(" ↑ more above"))
|
||||
}
|
||||
|
||||
// Account for the consumed padding (1 left + 1 right = 2 cols)
|
||||
// when rendering item content so RenderItem callbacks can match.
|
||||
itemContentWidth := max(innerW-2, 6)
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
entry := p.filtered[i]
|
||||
isCursor := i == p.cursor
|
||||
|
||||
if p.RenderItem != nil {
|
||||
// Custom renderer: caller produces the inner text. We still
|
||||
// wrap it in a full-width row so the selection highlight
|
||||
// covers the line edge-to-edge.
|
||||
rowStyle := normalItemBg
|
||||
if isCursor {
|
||||
rowStyle = selectedItemBg
|
||||
}
|
||||
content := p.RenderItem(entry, itemContentWidth, isCursor)
|
||||
items = append(items, rowStyle.Render(content))
|
||||
continue
|
||||
}
|
||||
|
||||
itemStyle := normalItemBg
|
||||
if isCursor {
|
||||
itemStyle = selectedItemBg
|
||||
@@ -310,7 +441,7 @@ func (p *PopupList) Render() string {
|
||||
}
|
||||
|
||||
// Build content: indicator + label + description + active checkmark.
|
||||
content := p.renderItemContent(indicator, entry, innerWidth, isCursor)
|
||||
content := p.renderItemContent(indicator, entry, itemContentWidth, isCursor)
|
||||
items = append(items, itemStyle.Render(content))
|
||||
}
|
||||
|
||||
@@ -323,19 +454,24 @@ func (p *PopupList) Render() string {
|
||||
|
||||
// Footer with count and keyboard hints.
|
||||
var footerParts []string
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
if !p.HideCount {
|
||||
footerParts = append(footerParts, fmt.Sprintf("(%d/%d)", p.cursor+1, len(p.filtered)))
|
||||
}
|
||||
|
||||
footerHint := p.FooterHint
|
||||
if footerHint == "" {
|
||||
if innerWidth >= 50 {
|
||||
if innerW >= 50 {
|
||||
footerHint = "↑↓ navigate • enter select • esc cancel • type to filter"
|
||||
} else if innerWidth >= 30 {
|
||||
} else if innerW >= 30 {
|
||||
footerHint = "↑↓ nav • ↵ select • esc"
|
||||
} else {
|
||||
footerHint = "↑↓ ↵ esc"
|
||||
}
|
||||
}
|
||||
footerParts = append(footerParts, footerHint)
|
||||
if p.ExtraFooter != "" {
|
||||
footerParts = append(footerParts, p.ExtraFooter)
|
||||
}
|
||||
|
||||
footer := lipgloss.NewStyle().
|
||||
Background(popupBg).
|
||||
|
||||
@@ -19,28 +19,7 @@ import (
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
// Any @file tokens in the content are highlighted with the theme accent color.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
}
|
||||
|
||||
// Wrap content before passing to herald Alert so long lines break
|
||||
// inside the alert box. Subtract 4 to account for the alert bar
|
||||
// prefix ("│ ") and a small margin.
|
||||
if width > 4 {
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
// Highlight @file tokens with accent color so file references are
|
||||
// visually distinct from surrounding prompt text.
|
||||
content = HighlightFileTokens(content, theme)
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
// UserBlock-related rendering helpers and herald typography.
|
||||
|
||||
// HighlightFileTokens wraps @file tokens in the given text with the theme
|
||||
// accent color so they stand out visually in rendered user messages.
|
||||
@@ -154,44 +133,6 @@ func ErrorBlock(errorMsg string, ty *herald.Typography, theme style.Theme) strin
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// ToolBlock renders a tool execution result with header and body.
|
||||
func ToolBlock(displayName, params, body string, isError bool, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
var icon string
|
||||
iconColor := theme.Success
|
||||
if isError {
|
||||
icon = "×"
|
||||
iconColor = theme.Error
|
||||
} else {
|
||||
icon = "✓"
|
||||
}
|
||||
|
||||
// Style the tool name with color
|
||||
nameColor := theme.Info
|
||||
if isError {
|
||||
nameColor = theme.Error
|
||||
}
|
||||
styledName := lipgloss.NewStyle().Foreground(nameColor).Bold(true).Render(displayName)
|
||||
styledIcon := lipgloss.NewStyle().Foreground(iconColor).Render(icon)
|
||||
|
||||
// Build the content: icon + name + params on first line, then body
|
||||
headerLine := styledIcon + " " + styledName
|
||||
if params != "" {
|
||||
headerLine += " " + lipgloss.NewStyle().Foreground(theme.Muted).Render(params)
|
||||
}
|
||||
|
||||
if strings.TrimSpace(body) == "" {
|
||||
body = ty.Italic("(no output)")
|
||||
}
|
||||
|
||||
// Compose: icon + name + params, then body
|
||||
fullContent := ty.Compose(
|
||||
headerLine,
|
||||
"",
|
||||
body,
|
||||
)
|
||||
return styleMarginBottom(theme, fullContent)
|
||||
}
|
||||
|
||||
// styleMarginBottom applies a 1-line margin bottom using the theme.
|
||||
func styleMarginBottom(theme style.Theme, content string) string {
|
||||
return style.GetCachedStyles().MarginBottom1.Render(content)
|
||||
|
||||
@@ -4,30 +4,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// testTypography creates a herald Typography for tests.
|
||||
func testTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertLabel(herald.AlertTip, ""),
|
||||
herald.WithAlertIcon(herald.AlertTip, ""),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHighlightFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
|
||||
@@ -88,24 +67,25 @@ func TestHighlightFileTokens(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserBlockHighlightsFileTokens(t *testing.T) {
|
||||
// TestHighlightFileTokensInjectsANSI verifies that HighlightFileTokens
|
||||
// preserves the original @file references in the output and wraps each
|
||||
// token with ANSI escape codes for the theme accent color.
|
||||
func TestHighlightFileTokensInjectsANSI(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
ty := testTypography(theme)
|
||||
|
||||
// A user message with @file tokens should contain ANSI escapes around the token.
|
||||
content := "refactor @main.go and @utils.go"
|
||||
result := UserBlock(content, 80, ty, theme)
|
||||
result := HighlightFileTokens(content, theme)
|
||||
|
||||
// The rendered output should contain both file references.
|
||||
// The output should still contain both file references.
|
||||
if !strings.Contains(result, "@main.go") {
|
||||
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @main.go, got:\n%s", result)
|
||||
}
|
||||
if !strings.Contains(result, "@utils.go") {
|
||||
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
|
||||
t.Errorf("HighlightFileTokens output should contain @utils.go, got:\n%s", result)
|
||||
}
|
||||
|
||||
// Verify ANSI codes are present (the tokens are styled).
|
||||
if !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
|
||||
t.Errorf("HighlightFileTokens output should contain ANSI escape codes for styled @file tokens")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,10 +60,13 @@ func NewScrollList(width, height int) *ScrollList {
|
||||
}
|
||||
|
||||
// SetItems replaces the items in the scroll list. If auto-scroll is enabled,
|
||||
// the viewport will scroll to the bottom to show the latest content.
|
||||
// the viewport will scroll to the bottom to show the latest content — EXCEPT
|
||||
// when the user is actively selecting text (mouse button held), in which case
|
||||
// the scroll position is locked so the highlighted content stays under the
|
||||
// cursor. The pending bottom-scroll is deferred to MouseUp.
|
||||
func (s *ScrollList) SetItems(items []MessageItem) {
|
||||
s.items = items
|
||||
if s.autoScroll {
|
||||
if s.autoScroll && !s.sel.MouseDown {
|
||||
s.GotoBottom()
|
||||
}
|
||||
}
|
||||
@@ -157,6 +160,10 @@ func (s *ScrollList) HandleMouseDown(x, y int) bool {
|
||||
// HandleMouseDrag handles mouse motion while button is held.
|
||||
// Updates the selection endpoint for character-level precision.
|
||||
// Returns true if selection was updated.
|
||||
//
|
||||
// Defensively disables auto-scroll on every drag update — even if the
|
||||
// MouseDown handler missed (e.g. click landed in viewport padding), any
|
||||
// active drag means the user is selecting and the viewport must not jump.
|
||||
func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
if !s.sel.MouseDown {
|
||||
return false
|
||||
@@ -171,6 +178,9 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Hard-lock the viewport while dragging.
|
||||
s.autoScroll = false
|
||||
|
||||
s.sel.DragItemIdx = itemIdx
|
||||
s.sel.DragLineIdx = lineIdx
|
||||
s.sel.DragCol = x
|
||||
@@ -178,6 +188,13 @@ func (s *ScrollList) HandleMouseDrag(x, y int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// IsMouseDown reports whether the user currently has the mouse button held
|
||||
// (i.e. a selection drag is in progress). Used by the parent model to avoid
|
||||
// re-enabling auto-scroll during streaming while the user is selecting.
|
||||
func (s *ScrollList) IsMouseDown() bool {
|
||||
return s.sel.MouseDown
|
||||
}
|
||||
|
||||
// HandleMouseUp handles mouse button release.
|
||||
// Returns true if there was an active selection.
|
||||
func (s *ScrollList) HandleMouseUp() bool {
|
||||
@@ -521,6 +538,21 @@ func (s *ScrollList) View() string {
|
||||
for idx := s.offsetIdx; idx < len(s.items) && remainingHeight > 0; idx++ {
|
||||
item := s.items[idx]
|
||||
content := item.Render(s.width)
|
||||
|
||||
// Items that render to an empty string contribute zero height to
|
||||
// the viewport. This MUST match renderedHeight()'s semantics —
|
||||
// otherwise getItemAndLineAtY (which uses renderedHeight) treats
|
||||
// the item as 0 lines while View() emits one blank line via
|
||||
// strings.Split("", "\n") = [""], producing a 1-row downward
|
||||
// drift in mouse hit-testing per empty item between offsetIdx
|
||||
// and the cursor (most visibly streaming-reasoning items before
|
||||
// any reasoning has streamed, which extension widgets surface by
|
||||
// shrinking the scrollback).
|
||||
if content == "" {
|
||||
s.heightCache[item.ID()] = 0
|
||||
continue
|
||||
}
|
||||
|
||||
contentLines := strings.Split(content, "\n")
|
||||
|
||||
// Refresh height cache from the actual render (authoritative).
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
package ui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fakeItem is a deterministic MessageItem for ScrollList tests.
|
||||
type fakeItem struct {
|
||||
id string
|
||||
lines int
|
||||
}
|
||||
|
||||
func (f *fakeItem) ID() string { return f.id }
|
||||
func (f *fakeItem) Render(_ int) string {
|
||||
if f.lines <= 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, f.lines)
|
||||
for i := range parts {
|
||||
parts[i] = fmt.Sprintf("%s-line-%d", f.id, i)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
func (f *fakeItem) Height() int { return f.lines }
|
||||
|
||||
// makeItems builds n fake items of `lines` height each.
|
||||
func makeItems(n, lines int) []MessageItem {
|
||||
out := make([]MessageItem, n)
|
||||
for i := range out {
|
||||
out[i] = &fakeItem{id: fmt.Sprintf("item-%d", i), lines: lines}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TestScrollList_MouseDownPreventsAutoScroll verifies the core fix for the
|
||||
// copy-selection drift bug: while the user has the mouse button held
|
||||
// (drag-selecting), incoming content updates must NOT shift the viewport,
|
||||
// because doing so moves the highlighted content out from under the cursor.
|
||||
func TestScrollList_MouseDownPreventsAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2)) // 40 lines of content into a 10-line viewport
|
||||
// Capture the auto-scrolled-to-bottom position.
|
||||
startOffsetIdx := sl.offsetIdx
|
||||
startOffsetLine := sl.offsetLine
|
||||
|
||||
// User clicks somewhere in the visible area, starting a drag-select.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown should accept a click inside the viewport")
|
||||
}
|
||||
if !sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be true after HandleMouseDown")
|
||||
}
|
||||
|
||||
// New content arrives. With autoScroll still true, SetItems would
|
||||
// normally call GotoBottom() and shift the viewport. The fix should
|
||||
// suppress that while MouseDown is held.
|
||||
sl.SetItems(makeItems(30, 2)) // 60 lines now
|
||||
if sl.offsetIdx != startOffsetIdx || sl.offsetLine != startOffsetLine {
|
||||
t.Errorf("viewport scrolled during active drag: was (%d,%d), now (%d,%d)",
|
||||
startOffsetIdx, startOffsetLine, sl.offsetIdx, sl.offsetLine)
|
||||
}
|
||||
|
||||
// User releases the mouse — drag is over.
|
||||
sl.HandleMouseUp()
|
||||
if sl.IsMouseDown() {
|
||||
t.Fatalf("IsMouseDown should be false after HandleMouseUp")
|
||||
}
|
||||
|
||||
// After release, a fresh content update should resume auto-scrolling
|
||||
// (move the offset to track the new bottom).
|
||||
afterReleaseIdx := sl.offsetIdx
|
||||
afterReleaseLine := sl.offsetLine
|
||||
sl.SetItems(makeItems(50, 2))
|
||||
if sl.offsetIdx == afterReleaseIdx && sl.offsetLine == afterReleaseLine {
|
||||
t.Errorf("autoscroll did not resume after MouseUp: offset stuck at (%d,%d)",
|
||||
afterReleaseIdx, afterReleaseLine)
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_DragDisablesAutoScroll verifies that any successful
|
||||
// HandleMouseDrag call clears autoScroll, even when HandleMouseDown didn't
|
||||
// observe it (e.g. a stale wheel-down event set it back to true mid-stream).
|
||||
func TestScrollList_DragDisablesAutoScroll(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems(makeItems(20, 2))
|
||||
|
||||
// Begin a selection.
|
||||
if !sl.HandleMouseDown(5, 3) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
// Simulate an external code path that re-enabled autoScroll while
|
||||
// MouseDown is still held (the precise condition that caused drift).
|
||||
sl.autoScroll = true
|
||||
|
||||
// Drag motion should hard-lock the viewport again.
|
||||
if !sl.HandleMouseDrag(10, 4) {
|
||||
t.Fatalf("HandleMouseDrag failed")
|
||||
}
|
||||
if sl.autoScroll {
|
||||
t.Errorf("HandleMouseDrag must clear autoScroll to prevent mid-drag jumps")
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_SetItemsRespectsMouseDown is the most direct regression
|
||||
// test: even with autoScroll enabled and new content appended at the
|
||||
// bottom, SetItems must not move the viewport while a mouse drag is in
|
||||
// progress. This is what caused the "highlighting shifts by 1+ rows
|
||||
// during streaming" symptom reported by the user.
|
||||
func TestScrollList_SetItemsRespectsMouseDown(t *testing.T) {
|
||||
sl := NewScrollList(80, 5)
|
||||
sl.SetItems(makeItems(10, 2)) // 20 lines into a 5-line viewport
|
||||
// At bottom.
|
||||
preIdx, preLine := sl.offsetIdx, sl.offsetLine
|
||||
|
||||
// Hold mouse down (no actual drag needed).
|
||||
if !sl.HandleMouseDown(0, 0) {
|
||||
t.Fatalf("HandleMouseDown failed")
|
||||
}
|
||||
|
||||
// Append several more items as if streaming. With the bug, each
|
||||
// SetItems would call GotoBottom and shift the offset.
|
||||
for n := 11; n <= 15; n++ {
|
||||
sl.SetItems(makeItems(n, 2))
|
||||
if sl.offsetIdx != preIdx || sl.offsetLine != preLine {
|
||||
t.Fatalf("viewport drifted during streaming with mouse held: "+
|
||||
"start=(%d,%d) now=(%d,%d) after adding item %d",
|
||||
preIdx, preLine, sl.offsetIdx, sl.offsetLine, n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestScrollList_EmptyItemsDoNotShiftMouseMapping is the regression test
|
||||
// for the second drift bug: items that render to "" must contribute the
|
||||
// same number of rows in View() (zero) as in renderedHeight(), or mouse
|
||||
// hit-testing drifts by one row per empty item between offsetIdx and the
|
||||
// cursor. This was surfaced by extension widgets (e.g. subagent-monitor)
|
||||
// that shrink the scrollback so empty streaming-reasoning items end up
|
||||
// in the visible window.
|
||||
//
|
||||
// Setup: 1 normal item + 1 empty item + 1 normal item. Click on the line
|
||||
// where the third item begins. With the bug, getItemAndLineAtY skips the
|
||||
// empty item (renderedHeight=0) and reports lineIdx pointing one row
|
||||
// past where View() actually painted that line.
|
||||
func TestScrollList_EmptyItemsDoNotShiftMouseMapping(t *testing.T) {
|
||||
sl := NewScrollList(80, 10)
|
||||
sl.SetItems([]MessageItem{
|
||||
&fakeItem{id: "a", lines: 2}, // viewY 0–1
|
||||
&fakeItem{id: "empty", lines: 0}, // renders "" — contributes 0 rows
|
||||
&fakeItem{id: "b", lines: 2}, // viewY 2–3
|
||||
})
|
||||
|
||||
// Render the viewport once so the cache reflects what View() actually
|
||||
// emits (this is the path that previously diverged from renderedHeight
|
||||
// for empty items).
|
||||
rendered := sl.View()
|
||||
lines := strings.Split(rendered, "\n")
|
||||
|
||||
// Sanity: View() must emit exactly height lines.
|
||||
if len(lines) != 10 {
|
||||
t.Fatalf("View() returned %d lines, want 10", len(lines))
|
||||
}
|
||||
// Item b's first line should appear at viewY=2, NOT viewY=3.
|
||||
if !strings.Contains(lines[2], "b-line-0") {
|
||||
t.Errorf("viewY=2 should render b-line-0 (empty item contributes 0 rows), got %q", lines[2])
|
||||
}
|
||||
|
||||
// Now the actual hit-test contract: clicking on viewY=2 must map to
|
||||
// item b line 0 — the same coordinate View() rendered there.
|
||||
idx, line := sl.getItemAndLineAtY(2)
|
||||
if idx != 2 || line != 0 {
|
||||
t.Errorf("getItemAndLineAtY(2) = (%d,%d), want (2,0)", idx, line)
|
||||
}
|
||||
|
||||
// And clicking on the second line of b (viewY=3) must map to b line 1.
|
||||
idx, line = sl.getItemAndLineAtY(3)
|
||||
if idx != 2 || line != 1 {
|
||||
t.Errorf("getItemAndLineAtY(3) = (%d,%d), want (2,1)", idx, line)
|
||||
}
|
||||
}
|
||||
@@ -230,8 +230,10 @@ func FindWordBoundaries(line string, col int) (startCol, endCol int) {
|
||||
|
||||
// HighlightLine applies reverse-video highlighting to a portion of a rendered
|
||||
// line (which may contain ANSI escape codes). startCol/endCol are in display
|
||||
// columns. If startCol == -1, the entire line is highlighted. If startCol ==
|
||||
// endCol, returns the line unchanged.
|
||||
// columns. If startCol == -1, the entire line is highlighted. If endCol ==
|
||||
// -1, the highlight runs from startCol to the end of the line (the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection).
|
||||
// If startCol == endCol, returns the line unchanged.
|
||||
//
|
||||
// Uses ultraviolet ScreenBuffer for cell-level ANSI manipulation.
|
||||
func HighlightLine(line string, startCol, endCol int) string {
|
||||
@@ -250,6 +252,16 @@ func HighlightLine(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (returned by IsLineInRange
|
||||
// for the first line of a multi-line selection). Without this branch,
|
||||
// the start line of a multi-line drag would never be highlighted —
|
||||
// the user perceives this as the selection being shifted one row down
|
||||
// from the cursor, especially when extension widgets shrink the
|
||||
// scrollback and make the start line land on a tall styled block.
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return line
|
||||
}
|
||||
@@ -296,6 +308,11 @@ func ExtractText(line string, startCol, endCol int) string {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
// "From startCol to end of line" sentinel (see HighlightLine).
|
||||
if endCol < 0 {
|
||||
endCol = lineWidth
|
||||
}
|
||||
|
||||
if startCol >= endCol || startCol >= lineWidth {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -357,6 +357,54 @@ func TestHighlightLine_NoSelection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlightLine_EndOfLineSentinel verifies that endCol=-1 is interpreted
|
||||
// as "highlight from startCol to end of line", matching the sentinel
|
||||
// returned by IsLineInRange for the first line of a multi-line selection.
|
||||
//
|
||||
// Regression: without this contract, the start line of any multi-line drag
|
||||
// would silently fall through HighlightLine's startCol >= endCol guard and
|
||||
// render unstyled, making the selection appear to begin one row below the
|
||||
// cursor — the exact "tracking gets shifted" symptom users reported when
|
||||
// extension widgets shrank the scrollback enough that the click landed on a
|
||||
// styled tool-result block.
|
||||
func TestHighlightLine_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
result := HighlightLine(line, 0, -1)
|
||||
if result == line {
|
||||
t.Errorf("endCol=-1 should highlight from startCol to end of line; got unchanged input")
|
||||
}
|
||||
if len(result) <= len(line) {
|
||||
t.Errorf("highlighted result should be longer than plain input (ANSI codes added); got len=%d want > %d", len(result), len(line))
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractText_EndOfLineSentinel mirrors TestHighlightLine_EndOfLineSentinel
|
||||
// for the extraction path used by the clipboard copy.
|
||||
func TestExtractText_EndOfLineSentinel(t *testing.T) {
|
||||
line := "Hello, World!"
|
||||
got := ExtractText(line, 7, -1)
|
||||
want := "World!"
|
||||
if got != want {
|
||||
t.Errorf("ExtractText(line, 7, -1) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLineInRange_StartLineSentinelHighlights composes IsLineInRange with
|
||||
// HighlightLine end-to-end: the start line of a multi-line, single-item
|
||||
// selection must actually emit highlight ANSI codes. This is the contract
|
||||
// the rendering path in scrolllist.View() relies on.
|
||||
func TestIsLineInRange_StartLineSentinelHighlights(t *testing.T) {
|
||||
r := Range{StartItemIdx: 5, EndItemIdx: 5, StartLine: 0, EndLine: 2, StartCol: 0, EndCol: 10}
|
||||
inRange, sc, ec := IsLineInRange(r, 5, 0)
|
||||
if !inRange {
|
||||
t.Fatalf("item 5 line 0 should be in range")
|
||||
}
|
||||
highlighted := HighlightLine("first line of selection", sc, ec)
|
||||
if highlighted == "first line of selection" {
|
||||
t.Errorf("first line of multi-line selection was not highlighted (sc=%d ec=%d)", sc, ec)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiClickDetection verifies the click counting logic.
|
||||
func TestMultiClickDetection(t *testing.T) {
|
||||
s := NewState()
|
||||
|
||||
+131
-304
@@ -5,7 +5,6 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"charm.land/bubbles/v2/key"
|
||||
tea "charm.land/bubbletea/v2"
|
||||
@@ -62,17 +61,14 @@ func (m SessionFilterMode) String() string {
|
||||
// controlCharsRe matches ASCII control characters for stripping from previews.
|
||||
var controlCharsRe = regexp.MustCompile(`[\x00-\x1f\x7f]`)
|
||||
|
||||
// SessionSelectorComponent is a full-screen Bubble Tea component that lets
|
||||
// the user browse and select from available sessions. Modeled after pi's
|
||||
// session picker: right-aligned metadata, background-highlighted selection,
|
||||
// scope/filter toggles, and inline search.
|
||||
// SessionSelectorComponent is a Bubble Tea component that lets the user browse
|
||||
// and select from available sessions. It wraps PopupList in FullScreen mode:
|
||||
// PopupList owns the cursor/search/scroll math/chrome; this component owns
|
||||
// the session list, scope/filter toggles, and delete-confirmation flow.
|
||||
type SessionSelectorComponent struct {
|
||||
allSessions []session.SessionInfo
|
||||
cwdSessions []session.SessionInfo
|
||||
filtered []session.SessionInfo
|
||||
|
||||
cursor int
|
||||
search string
|
||||
filtered []session.SessionInfo // matches popup.Items() 1:1
|
||||
|
||||
scope SessionScopeMode
|
||||
filter SessionFilterMode
|
||||
@@ -80,6 +76,7 @@ type SessionSelectorComponent struct {
|
||||
// currentPath is the active session file path for marking it in the list.
|
||||
currentPath string
|
||||
|
||||
popup *PopupList
|
||||
width int
|
||||
height int
|
||||
active bool
|
||||
@@ -110,7 +107,12 @@ func NewSessionSelector(cwd string, width, height int) *SessionSelectorComponent
|
||||
ss.scope = SessionScopeAll
|
||||
}
|
||||
|
||||
ss.rebuildFiltered()
|
||||
ss.popup = NewPopupList("Resume Session", nil, width, height)
|
||||
ss.popup.FullScreen = true
|
||||
ss.popup.FooterHint = "↑↓ nav • ↵ open • esc cancel • tab scope • ^N named • d delete • type to search"
|
||||
ss.popup.RenderItem = ss.renderEntry
|
||||
|
||||
ss.rebuild()
|
||||
return ss
|
||||
}
|
||||
|
||||
@@ -131,10 +133,11 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
case tea.WindowSizeMsg:
|
||||
ss.width = msg.Width
|
||||
ss.height = msg.Height
|
||||
ss.popup.SetSize(msg.Width, msg.Height)
|
||||
return ss, nil
|
||||
|
||||
case tea.KeyPressMsg:
|
||||
// Delete confirmation mode.
|
||||
// Delete confirmation mode swallows all keys until y/n.
|
||||
if ss.confirmDelete >= 0 {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
@@ -145,7 +148,7 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if err := session.DeleteSession(info.Path); err == nil {
|
||||
name := sessionDisplayName(info)
|
||||
ss.removeSession(info.Path)
|
||||
ss.rebuildFiltered()
|
||||
ss.rebuild()
|
||||
return ss, func() tea.Msg {
|
||||
return SessionDeletedMsg{Name: name}
|
||||
}
|
||||
@@ -159,64 +162,14 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("up"))):
|
||||
if ss.cursor > 0 {
|
||||
ss.cursor--
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("down"))):
|
||||
if ss.cursor < len(ss.filtered)-1 {
|
||||
ss.cursor++
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgup"))):
|
||||
ss.cursor -= ss.visibleHeight()
|
||||
if ss.cursor < 0 {
|
||||
ss.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("pgdown"))):
|
||||
ss.cursor += ss.visibleHeight()
|
||||
if ss.cursor >= len(ss.filtered) {
|
||||
ss.cursor = len(ss.filtered) - 1
|
||||
}
|
||||
if ss.cursor < 0 {
|
||||
ss.cursor = 0
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("home"))):
|
||||
ss.cursor = 0
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("end"))):
|
||||
ss.cursor = max(len(ss.filtered)-1, 0)
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("enter"))):
|
||||
if ss.cursor < len(ss.filtered) {
|
||||
info := ss.filtered[ss.cursor]
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectedMsg{Path: info.Path}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("esc"))):
|
||||
if ss.search != "" {
|
||||
ss.search = ""
|
||||
ss.rebuildFiltered()
|
||||
} else {
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("tab"))):
|
||||
if ss.scope == SessionScopeCwd {
|
||||
ss.scope = SessionScopeAll
|
||||
} else {
|
||||
ss.scope = SessionScopeCwd
|
||||
}
|
||||
ss.rebuildFiltered()
|
||||
ss.rebuild()
|
||||
return ss, nil
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+n"))):
|
||||
if ss.filter == SessionFilterAll {
|
||||
@@ -224,25 +177,48 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
} else {
|
||||
ss.filter = SessionFilterAll
|
||||
}
|
||||
ss.rebuildFiltered()
|
||||
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("d"))):
|
||||
if ss.cursor < len(ss.filtered) {
|
||||
ss.confirmDelete = ss.cursor
|
||||
}
|
||||
ss.rebuild()
|
||||
return ss, nil
|
||||
|
||||
default:
|
||||
if msg.Text != "" && len(msg.Text) == 1 {
|
||||
ch := msg.Text[0]
|
||||
if ch >= 32 && ch < 127 {
|
||||
ss.search += string(ch)
|
||||
ss.rebuildFiltered()
|
||||
case key.Matches(msg, key.NewBinding(key.WithKeys("ctrl+d"))):
|
||||
// Ctrl+D as an explicit delete shortcut. Plain "d" still works
|
||||
// below when the search field is empty so it doesn't conflict
|
||||
// with typing the letter 'd' into a query.
|
||||
if c := ss.popup.Cursor(); c < len(ss.filtered) {
|
||||
ss.confirmDelete = c
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// Plain 'd' triggers delete only when there's no active search
|
||||
// query (otherwise the user would never be able to type 'd' into
|
||||
// a search like "doc").
|
||||
if msg.String() == "d" && !ss.popup.IsSearching() {
|
||||
if c := ss.popup.Cursor(); c < len(ss.filtered) {
|
||||
ss.confirmDelete = c
|
||||
return ss, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate everything else to the popup.
|
||||
result := ss.popup.HandleKey(msg.String(), msg.Text)
|
||||
if result.Changed {
|
||||
ss.syncFiltered()
|
||||
}
|
||||
if result.Selected != nil {
|
||||
cursor := ss.popup.Cursor()
|
||||
if cursor < len(ss.filtered) {
|
||||
info := ss.filtered[cursor]
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectedMsg{Path: info.Path}
|
||||
}
|
||||
}
|
||||
if key.Matches(msg, key.NewBinding(key.WithKeys("backspace"))) && len(ss.search) > 0 {
|
||||
ss.search = ss.search[:len(ss.search)-1]
|
||||
ss.rebuildFiltered()
|
||||
}
|
||||
if result.Cancelled {
|
||||
ss.active = false
|
||||
return ss, func() tea.Msg {
|
||||
return SessionSelectorCancelledMsg{}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,152 +227,17 @@ func (ss *SessionSelectorComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// View implements tea.Model.
|
||||
func (ss *SessionSelectorComponent) View() tea.View {
|
||||
theme := style.GetTheme()
|
||||
|
||||
// Full-screen bordered container - uses entire terminal width and height
|
||||
maxWidth := ss.width - 2 // Small margin on each side
|
||||
if maxWidth < 20 {
|
||||
maxWidth = ss.width
|
||||
}
|
||||
maxHeight := ss.height - 2 // Small margin top/bottom to prevent overflow
|
||||
if maxHeight < 10 {
|
||||
maxHeight = ss.height
|
||||
}
|
||||
horizontalPadding := 1
|
||||
innerWidth := maxWidth - 4 // Account for border (2) + padding (2)
|
||||
innerHeight := maxHeight - 4 // Account for border (2) + padding (2)
|
||||
|
||||
// Container style with border - full width/height like a framed panel
|
||||
containerStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Primary).
|
||||
Background(theme.Background).
|
||||
Padding(1, horizontalPadding).
|
||||
Width(maxWidth).
|
||||
Height(maxHeight)
|
||||
|
||||
var contentBuilder strings.Builder
|
||||
|
||||
// ── Header: title + scope badges ─────────────────────────────
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(theme.Accent).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(titleStyle.Render(fmt.Sprintf("Resume Session (%s)", ss.scope)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Help / keybindings ───────────────────────────────────────
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if innerWidth >= 75 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab: scope N: named D: delete R: rename type to search esc: cancel"))
|
||||
} else if innerWidth >= 50 {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab scope N named D del type to search esc"))
|
||||
} else {
|
||||
contentBuilder.WriteString(helpStyle.Render("tab N D esc"))
|
||||
}
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Search (only shown when active) ──────────────────────────
|
||||
if ss.search != "" {
|
||||
searchStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(searchStyle.Render(fmt.Sprintf("> %s", ss.search)))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Separator line
|
||||
sepWidth := innerWidth
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// ── Delete confirmation ──────────────────────────────────────
|
||||
// Compose dynamic footer extras: scope + filter + (delete confirm).
|
||||
extra := fmt.Sprintf("scope: %s • filter: %s", ss.scope, ss.filter)
|
||||
if ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) {
|
||||
warnStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true).
|
||||
Background(theme.Background)
|
||||
name := sessionDisplayName(ss.filtered[ss.confirmDelete])
|
||||
contentBuilder.WriteString(warnStyle.Render(fmt.Sprintf("Delete %q? (y/N)", truncateRunes(name, 40))))
|
||||
contentBuilder.WriteString("\n")
|
||||
name := truncateRunes(sessionDisplayName(ss.filtered[ss.confirmDelete]), 30)
|
||||
extra = fmt.Sprintf("delete %q? y/N", name)
|
||||
}
|
||||
ss.popup.Title = fmt.Sprintf("Resume Session (%s)", ss.scope)
|
||||
ss.popup.ExtraFooter = extra
|
||||
|
||||
// ── Session list ─────────────────────────────────────────────
|
||||
if len(ss.filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
if ss.search != "" {
|
||||
contentBuilder.WriteString(emptyStyle.Render(fmt.Sprintf("No sessions matching %q", ss.search)))
|
||||
} else if ss.filter == SessionFilterNamed {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No named sessions. Press N to show all."))
|
||||
} else if ss.scope == SessionScopeCwd {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions in current folder. Press tab to view all."))
|
||||
} else {
|
||||
contentBuilder.WriteString(emptyStyle.Render("No sessions found"))
|
||||
}
|
||||
contentBuilder.WriteString("\n")
|
||||
} else {
|
||||
// Compute visible window based on inner container height
|
||||
// Chrome: header(2) + separator(1) + footer separator(1) + footer(1) = 5
|
||||
chromeLines := 5
|
||||
if ss.search != "" {
|
||||
chromeLines++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chromeLines++
|
||||
}
|
||||
visH := max(innerHeight-chromeLines, 3)
|
||||
|
||||
// Center the cursor in the visible window.
|
||||
startIdx := max(0, min(ss.cursor-visH/2, len(ss.filtered)-visH))
|
||||
endIdx := min(startIdx+visH, len(ss.filtered))
|
||||
|
||||
for i := startIdx; i < endIdx; i++ {
|
||||
info := ss.filtered[i]
|
||||
isCursor := i == ss.cursor
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := i == ss.confirmDelete
|
||||
line := ss.renderEntry(info, isCursor, isCurrent, isDeleting, innerWidth)
|
||||
contentBuilder.WriteString(line)
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
|
||||
// Scroll position indicator.
|
||||
if len(ss.filtered) > visH {
|
||||
posStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(posStyle.Render(fmt.Sprintf("(%d/%d)", ss.cursor+1, len(ss.filtered))))
|
||||
contentBuilder.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Footer separator
|
||||
contentBuilder.WriteString(
|
||||
lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background).
|
||||
Render(strings.Repeat("─", sepWidth)))
|
||||
contentBuilder.WriteString("\n")
|
||||
|
||||
// Footer with filter info
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Background(theme.Background)
|
||||
contentBuilder.WriteString(footerStyle.Render(fmt.Sprintf("Filter: %s", ss.filter)))
|
||||
|
||||
// Apply the bordered container
|
||||
content := contentBuilder.String()
|
||||
borderedContent := containerStyle.Render(content)
|
||||
|
||||
v := tea.NewView(borderedContent)
|
||||
rendered := ss.popup.RenderCentered(ss.width, ss.height)
|
||||
v := tea.NewView(rendered)
|
||||
v.AltScreen = true
|
||||
return v
|
||||
}
|
||||
@@ -408,20 +249,9 @@ func (ss *SessionSelectorComponent) IsActive() bool {
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
func (ss *SessionSelectorComponent) visibleHeight() int {
|
||||
// Reserve: title(1) + help(1) + blank(1) + scroll indicator(1) = 4.
|
||||
// Optional: search(1), delete confirm(1).
|
||||
chrome := 4
|
||||
if ss.search != "" {
|
||||
chrome++
|
||||
}
|
||||
if ss.confirmDelete >= 0 {
|
||||
chrome++
|
||||
}
|
||||
return max(ss.height-chrome, 3)
|
||||
}
|
||||
|
||||
func (ss *SessionSelectorComponent) rebuildFiltered() {
|
||||
// rebuild applies the scope and filter selections, then publishes the
|
||||
// resulting session list to the popup.
|
||||
func (ss *SessionSelectorComponent) rebuild() {
|
||||
var source []session.SessionInfo
|
||||
if ss.scope == SessionScopeCwd {
|
||||
source = ss.cwdSessions
|
||||
@@ -439,23 +269,33 @@ func (ss *SessionSelectorComponent) rebuildFiltered() {
|
||||
source = named
|
||||
}
|
||||
|
||||
if ss.search != "" {
|
||||
query := strings.ToLower(ss.search)
|
||||
var matches []session.SessionInfo
|
||||
for _, s := range source {
|
||||
haystack := strings.ToLower(s.Name + " " + s.FirstMessage + " " + s.Cwd)
|
||||
if strings.Contains(haystack, query) {
|
||||
matches = append(matches, s)
|
||||
}
|
||||
// Build PopupItems. The Label holds a haystack string (name + first
|
||||
// message + cwd) so PopupList's default filter can match against any
|
||||
// of those fields. We render each row with a custom RenderItem.
|
||||
items := make([]PopupItem, len(source))
|
||||
for i, s := range source {
|
||||
haystack := strings.TrimSpace(s.Name + " " + s.FirstMessage + " " + s.Cwd)
|
||||
items[i] = PopupItem{
|
||||
Label: haystack,
|
||||
Active: s.Path == ss.currentPath,
|
||||
Meta: s,
|
||||
}
|
||||
ss.filtered = matches
|
||||
} else {
|
||||
ss.filtered = source
|
||||
}
|
||||
ss.popup.SetItems(items)
|
||||
ss.syncFiltered()
|
||||
}
|
||||
|
||||
if ss.cursor >= len(ss.filtered) {
|
||||
ss.cursor = max(len(ss.filtered)-1, 0)
|
||||
// syncFiltered refreshes the filtered slice from popup.Items() so cursor
|
||||
// indices map back to session.SessionInfo for the parent.
|
||||
func (ss *SessionSelectorComponent) syncFiltered() {
|
||||
items := ss.popup.Items()
|
||||
out := make([]session.SessionInfo, 0, len(items))
|
||||
for _, it := range items {
|
||||
if s, ok := it.Meta.(session.SessionInfo); ok {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
ss.filtered = out
|
||||
}
|
||||
|
||||
func (ss *SessionSelectorComponent) removeSession(path string) {
|
||||
@@ -473,87 +313,74 @@ func removeByPath(sessions []session.SessionInfo, path string) []session.Session
|
||||
return result
|
||||
}
|
||||
|
||||
// renderEntry renders a single session line with right-aligned metadata.
|
||||
// Layout: [cursor 2] [message ...variable...] [padding] [count age] [cwd?]
|
||||
func (ss *SessionSelectorComponent) renderEntry(info session.SessionInfo, isCursor, isCurrent, isDeleting bool, width int) string {
|
||||
// renderEntry is the RenderItem callback handed to PopupList. It produces a
|
||||
// single-line entry with left-aligned message text and right-aligned
|
||||
// metadata (message count + relative time, plus optional cwd in "All" scope).
|
||||
//
|
||||
// When isCursor we return a plain (unstyled) string so PopupList's outer
|
||||
// row style can paint one continuous fg+bg span. Mixing inner lipgloss
|
||||
// Render calls with an outer Background() breaks the highlight into bars,
|
||||
// because each inner Render emits an ANSI reset that drops the background.
|
||||
func (ss *SessionSelectorComponent) renderEntry(item PopupItem, innerWidth int, isCursor bool) string {
|
||||
theme := style.GetTheme()
|
||||
info, ok := item.Meta.(session.SessionInfo)
|
||||
if !ok {
|
||||
return item.Label
|
||||
}
|
||||
isCurrent := info.Path == ss.currentPath
|
||||
isDeleting := ss.confirmDelete >= 0 && ss.confirmDelete < len(ss.filtered) &&
|
||||
ss.filtered[ss.confirmDelete].Path == info.Path
|
||||
|
||||
// ── Cursor indicator (2 chars) ───────────────────────────────
|
||||
cursorStr := " "
|
||||
// Cursor indicator (2 cells).
|
||||
indicator := " "
|
||||
if isCursor {
|
||||
cursorStr = lipgloss.NewStyle().Foreground(theme.Accent).Render("> ")
|
||||
indicator = "> "
|
||||
}
|
||||
const cursorW = 2
|
||||
|
||||
// ── Right part: message count + relative time (+ optional cwd) ──
|
||||
// Right-hand metadata.
|
||||
age := relativeTime(info.Modified)
|
||||
msgCount := fmt.Sprintf("%d", info.MessageCount)
|
||||
rightPart := msgCount + " " + age
|
||||
right := fmt.Sprintf("%d %s", info.MessageCount, age)
|
||||
if ss.scope == SessionScopeAll && info.Cwd != "" {
|
||||
shortCwd := shortenPath(info.Cwd)
|
||||
if len(shortCwd) > 25 {
|
||||
shortCwd = "..." + shortCwd[len(shortCwd)-22:]
|
||||
}
|
||||
rightPart = shortCwd + " " + rightPart
|
||||
shortCwd := truncateRunes(shortenPath(info.Cwd), 25)
|
||||
right = shortCwd + " " + right
|
||||
}
|
||||
rightW := utf8.RuneCountInString(rightPart)
|
||||
rightW := lipgloss.Width(right)
|
||||
|
||||
// Message text width: innerWidth minus indicator(2) minus right minus gap(2).
|
||||
availForMsg := max(innerWidth-2-rightW-2, 10)
|
||||
|
||||
// ── Message text ─────────────────────────────────────────────
|
||||
displayText := sessionDisplayName(info)
|
||||
// Strip control characters and collapse whitespace.
|
||||
displayText = controlCharsRe.ReplaceAllString(displayText, " ")
|
||||
displayText = strings.Join(strings.Fields(displayText), " ")
|
||||
displayText = truncateRunes(displayText, availForMsg)
|
||||
|
||||
availableForMsg := max(width-cursorW-rightW-2, 10) // 2 for min spacing
|
||||
displayText = truncateRunes(displayText, availableForMsg)
|
||||
msgW := utf8.RuneCountInString(displayText)
|
||||
msgW := lipgloss.Width(displayText)
|
||||
spacing := max(innerWidth-2-msgW-rightW, 1)
|
||||
|
||||
// ── Style the message ────────────────────────────────────────
|
||||
var msgStyle lipgloss.Style
|
||||
// Selected row: raw string, outer row style paints it.
|
||||
if isCursor {
|
||||
return indicator + displayText + strings.Repeat(" ", spacing) + right
|
||||
}
|
||||
|
||||
// Color the message text by state.
|
||||
var msgStyle, rightStyle lipgloss.Style
|
||||
switch {
|
||||
case isDeleting:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
case isCurrent:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent)
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Accent).Bold(true)
|
||||
case info.Name != "":
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Warning)
|
||||
default:
|
||||
msgStyle = lipgloss.NewStyle().Foreground(theme.Text)
|
||||
}
|
||||
|
||||
// ── Style the right part ─────────────────────────────────────
|
||||
rightColor := theme.Muted
|
||||
if isDeleting {
|
||||
rightColor = theme.Error
|
||||
}
|
||||
var styledRight string
|
||||
|
||||
// ── Assemble with spacing ────────────────────────────────────
|
||||
spacing := max(width-cursorW-msgW-rightW, 1)
|
||||
|
||||
// If selected, use inverted colors like PopupList
|
||||
if isCursor {
|
||||
// Inverted colors for selected item
|
||||
msgStyle = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Background).
|
||||
Bold(true)
|
||||
styledRight = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(rightColor).
|
||||
Render(rightPart)
|
||||
cursorStr = lipgloss.NewStyle().
|
||||
Background(theme.Primary).
|
||||
Foreground(theme.Accent).
|
||||
Render("> ")
|
||||
rightStyle = lipgloss.NewStyle().Foreground(theme.Error)
|
||||
} else {
|
||||
styledRight = lipgloss.NewStyle().Foreground(rightColor).Render(rightPart)
|
||||
rightStyle = lipgloss.NewStyle().Foreground(theme.Muted)
|
||||
}
|
||||
|
||||
styledMsg := msgStyle.Render(displayText)
|
||||
line := cursorStr + styledMsg + strings.Repeat(" ", spacing) + styledRight
|
||||
|
||||
return line
|
||||
return indicator + msgStyle.Render(displayText) + strings.Repeat(" ", spacing) + rightStyle.Render(right)
|
||||
}
|
||||
|
||||
// --- Package helpers ---
|
||||
@@ -570,7 +397,7 @@ func sessionDisplayName(info session.SessionInfo) string {
|
||||
return "(empty session)"
|
||||
}
|
||||
|
||||
// truncateRunes truncates a string to at most maxRunes runes, appending "..."
|
||||
// truncateRunes truncates a string to at most maxRunes runes, appending "…"
|
||||
// if truncated.
|
||||
func truncateRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
|
||||
@@ -211,106 +211,11 @@ func DefaultTheme() Theme {
|
||||
}
|
||||
}
|
||||
|
||||
// StyleCard creates a lipgloss style for card-like containers with rounded borders,
|
||||
// padding, and appropriate width. Used for grouping related content in a visually
|
||||
// distinct box.
|
||||
func StyleCard(width int, theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(theme.Border).
|
||||
Padding(1, 2).
|
||||
MarginBottom(1)
|
||||
}
|
||||
|
||||
// IsDarkBackground returns the cached terminal background detection result.
|
||||
func IsDarkBackground() bool {
|
||||
return isDarkBg
|
||||
}
|
||||
|
||||
// StyleHeader creates a lipgloss style for primary headers using the theme's
|
||||
// primary color with bold text for emphasis and hierarchy.
|
||||
func StyleHeader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Primary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleSubheader creates a lipgloss style for secondary headers using the theme's
|
||||
// secondary color with bold text, providing visual hierarchy below primary headers.
|
||||
func StyleSubheader(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Secondary).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleMuted creates a lipgloss style for de-emphasized text using muted colors
|
||||
// and italic formatting, suitable for supplementary or less important information.
|
||||
func StyleMuted(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Italic(true)
|
||||
}
|
||||
|
||||
// StyleSuccess creates a lipgloss style for success messages using green colors
|
||||
// with bold text to indicate successful operations or positive outcomes.
|
||||
func StyleSuccess(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleError creates a lipgloss style for error messages using red colors
|
||||
// with bold text to ensure visibility of problems or failures.
|
||||
func StyleError(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Error).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleWarning creates a lipgloss style for warning messages using yellow/amber
|
||||
// colors with bold text to draw attention to potential issues or cautions.
|
||||
func StyleWarning(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Warning).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// StyleInfo creates a lipgloss style for informational messages using blue colors
|
||||
// with bold text for general notifications and status updates.
|
||||
func StyleInfo(theme Theme) lipgloss.Style {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(theme.Info).
|
||||
Bold(true)
|
||||
}
|
||||
|
||||
// CreateSeparator generates a horizontal separator line with the specified width,
|
||||
// character, and color. Useful for visually dividing sections of content in the UI.
|
||||
func CreateSeparator(width int, char string, c color.Color) string {
|
||||
return lipgloss.NewStyle().
|
||||
Foreground(c).
|
||||
Width(width).
|
||||
Render(lipgloss.PlaceHorizontal(width, lipgloss.Center, char))
|
||||
}
|
||||
|
||||
// CreateProgressBar generates a visual progress bar with filled and empty segments
|
||||
// based on the percentage complete. The bar uses Unicode block characters for smooth
|
||||
// appearance and theme colors to indicate progress.
|
||||
func CreateProgressBar(width int, percentage float64, theme Theme) string {
|
||||
filled := int(float64(width) * percentage / 100)
|
||||
empty := width - filled
|
||||
|
||||
filledBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Success).
|
||||
Render(lipgloss.PlaceHorizontal(filled, lipgloss.Left, "█"))
|
||||
|
||||
emptyBar := lipgloss.NewStyle().
|
||||
Foreground(theme.Muted).
|
||||
Render(lipgloss.PlaceHorizontal(empty, lipgloss.Left, "░"))
|
||||
|
||||
return filledBar + emptyBar
|
||||
}
|
||||
|
||||
// CreateBadge generates a styled badge or label with inverted colors (text on
|
||||
// colored background) for highlighting important tags, statuses, or categories.
|
||||
func CreateBadge(text string, c color.Color) string {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user