mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
199 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a8e684dff | |||
| 7ef99ac60f | |||
| a67f514560 | |||
| b6bb35cb71 | |||
| 4e82fac442 | |||
| 5ec2217b0f | |||
| 8a851723ba | |||
| 53b628c5f8 | |||
| e1c94cb362 | |||
| ecf95b52e1 | |||
| 0641c92acc | |||
| 3bb20f5283 | |||
| 633fa38b2b | |||
| f905cee48c | |||
| 182c10ea1a | |||
| fcaa52bf1c | |||
| 7e6455732c | |||
| 71301a9035 | |||
| 0974d37ab2 | |||
| 398e825df8 | |||
| 3c51c20be7 | |||
| 25410af440 | |||
| 26c9f009f9 | |||
| e068487ff7 | |||
| 0ffb0ba788 | |||
| 65c6e9f797 | |||
| 68d798d2f4 | |||
| eefd5565f8 | |||
| 9d1b8a102e | |||
| f57e045c69 | |||
| eb5da28a15 | |||
| cd8e2a7654 | |||
| 64da1caf41 | |||
| 7eaeafff8c | |||
| 8ed8d23c73 | |||
| 2de98d32be | |||
| 83127467c5 | |||
| e07c94f49d | |||
| b87146a284 | |||
| 186d9f7f44 | |||
| 3a8ffc2104 | |||
| e54570162e | |||
| 34bb97a40e | |||
| f5c1a16f8a | |||
| b29d7d2166 | |||
| 3ea0db69ea | |||
| 4304a5e899 | |||
| 4019c1e4f7 | |||
| 30ad7c1d0b | |||
| e33564c569 | |||
| 5ff28445fd | |||
| 13d177e5d0 | |||
| 3ffc995f27 | |||
| b2bd016135 | |||
| 812dedaea2 | |||
| f65b6737f2 | |||
| 5d45aa196b | |||
| debb39f56c | |||
| 7ce6f4fd9e | |||
| c2f2bdb3d3 | |||
| 201d14804e | |||
| 7e54710d4a | |||
| 88870be4d2 | |||
| 46bf809715 | |||
| e19e9642a2 | |||
| 32675b8b35 | |||
| aecce001ee | |||
| 32d73171fd | |||
| 265fd2ec0c | |||
| efebf2eba6 | |||
| f7b655ae33 | |||
| 35982b41ad | |||
| 788e3b71fd | |||
| 3496bc2684 | |||
| 997c7d15ff | |||
| 83246e47d5 | |||
| 50e7b78c33 | |||
| b937af3056 | |||
| a5e995c750 | |||
| e95e08a699 | |||
| bcaf92f62a | |||
| ead4afbfe6 | |||
| 685aaf207f | |||
| 76ff6c9639 | |||
| 1cf24ee5de | |||
| c9637090fa | |||
| 0ff0ff42ab | |||
| a4fb32ff2b | |||
| 7d2f078111 | |||
| b0b66941ab | |||
| cbb7387a72 | |||
| 19430b0ecb | |||
| 8e3cfeede5 | |||
| 4fa5775974 | |||
| 4e7d823ee4 | |||
| 7a16c76adc | |||
| 70a21ee73a | |||
| 28d2de8f39 | |||
| 7f192ae850 | |||
| 9f6746ded9 | |||
| 7514d3a0ff | |||
| c83281a52b | |||
| 4515bb92c2 | |||
| e326b84204 | |||
| 1b93049b8e | |||
| 4912449dda | |||
| b70cce4f34 | |||
| 4c566836b2 | |||
| bb3261883a | |||
| 512d0f16ce | |||
| 8159431ce4 | |||
| 9f9f265fb3 | |||
| 9d38349091 | |||
| fec8bac800 | |||
| e76f5f3d45 | |||
| 1ad493c5c7 | |||
| ea6ddc8792 | |||
| 6d4e8bcec5 | |||
| e2ed345280 | |||
| e542eb797e | |||
| e631fc1b17 | |||
| 290c5a4774 | |||
| 287d60c31e | |||
| 3d45d98895 | |||
| db4be4f9a2 | |||
| 80093e69ed | |||
| ef519ba517 | |||
| d79eb1f0fa | |||
| ac8ee6525d | |||
| e35e8382d6 | |||
| fbb3408a25 | |||
| 44fed9a647 | |||
| e7f11487b9 | |||
| 054c417603 | |||
| 94d62a6ef0 | |||
| 91e6dfd2c8 | |||
| b6a0c4b44c | |||
| 8eb0fa855a | |||
| 3bf696c546 | |||
| 3e461a0539 | |||
| a2ece01ecf | |||
| 623c9fb5ad | |||
| 139506f336 | |||
| 6d424554ad | |||
| 5a3d3fdd7d | |||
| c91225629d | |||
| 5a71cde5ff | |||
| 044d3eb206 | |||
| 80f3a642a3 | |||
| 26f0969e3e | |||
| 4af75901b5 | |||
| 49ff4c0678 | |||
| b0802a5c32 | |||
| dfe65ca227 | |||
| d4ec756ce5 | |||
| 2971e73ee8 | |||
| 5aa6c9e116 | |||
| bca08476de | |||
| 6a599d86af | |||
| fd6f200659 | |||
| b295a25946 | |||
| f0e4e2f757 | |||
| d25249506a | |||
| 971521f534 | |||
| 8c00682367 | |||
| 58caf155c1 | |||
| 3f08bf2424 | |||
| 9fbbab05f6 | |||
| b0991c7aa6 | |||
| 9c90563765 | |||
| f36166bee5 | |||
| 879e81f9b5 | |||
| 727b42acfe | |||
| 4830981570 | |||
| dcfebafcc5 | |||
| 1f5c103667 | |||
| 4caa8ba3dc | |||
| 15ef8ad78b | |||
| 551f2710d9 | |||
| 67bda5cad5 | |||
| 01d7d754ef | |||
| c6304f1e92 | |||
| bc3c733ae3 | |||
| 428ee2b8be | |||
| eb1d7fd07e | |||
| 1e3e5cafd3 | |||
| 0b93e58fb9 | |||
| 2bb01ed72c | |||
| b6ecc36ea1 | |||
| d4f27bc912 | |||
| f12e195390 | |||
| b68b3dd0bf | |||
| 48521bf76d | |||
| 16df3a738c | |||
| 9d0b8c8cef | |||
| d9326fcf21 | |||
| 22c479277e | |||
| 8ae204f12f | |||
| 8b1665a4ce |
@@ -0,0 +1,79 @@
|
||||
name: Bug Report
|
||||
description: Report a bug or issue with Kit
|
||||
title: "fix: "
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Bug Description
|
||||
description: What happened? What did you expect to happen?
|
||||
placeholder: |
|
||||
The BorderColor field in ToolRenderConfig is documented but never applied
|
||||
during tool rendering. I expected the tool block to render with my custom
|
||||
color, but it uses the default styling instead.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide clear steps to reproduce the issue
|
||||
placeholder: |
|
||||
1. Create an extension with `api.RegisterToolRenderer(ext.ToolRenderConfig{...})`
|
||||
2. Set `BorderColor: "#89b4fa"` in the config
|
||||
3. Run a tool that uses this renderer
|
||||
4. Observe the border color is not applied
|
||||
render: markdown
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: code
|
||||
attributes:
|
||||
label: Relevant Code / Configuration
|
||||
description: Paste any code, configuration, or error messages
|
||||
placeholder: |
|
||||
```go
|
||||
api.RegisterToolRenderer(ext.ToolRenderConfig{
|
||||
ToolName: "bash",
|
||||
DisplayName: "Shell",
|
||||
BorderColor: "#a6e3a1", // This is ignored!
|
||||
Background: "#1e1e2e", // This is ignored!
|
||||
})
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: input
|
||||
id: component
|
||||
attributes:
|
||||
label: Affected Component
|
||||
description: Which part of Kit is affected?
|
||||
placeholder: e.g., extensions, ui, tool rendering, session management
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Kit Version
|
||||
description: What version of Kit are you running?
|
||||
placeholder: e.g., v0.1.0, commit hash, or "main"
|
||||
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Any other context, proposed fixes, or related issues
|
||||
placeholder: |
|
||||
The issue appears to be in `internal/ui/messages.go:RenderToolMessage()`
|
||||
which ignores the BorderColor and Background fields from ToolRendererData.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been reported yet
|
||||
required: true
|
||||
- label: I've tested with the latest version of Kit
|
||||
required: false
|
||||
@@ -0,0 +1,11 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Kit Documentation
|
||||
url: https://github.com/mark3labs/kit/tree/main/www/pages
|
||||
about: Check the documentation before filing an issue
|
||||
- name: Extension Examples
|
||||
url: https://github.com/mark3labs/kit/tree/main/examples/extensions
|
||||
about: See working extension examples for reference
|
||||
- name: Discussions
|
||||
url: https://github.com/mark3labs/kit/discussions
|
||||
about: For questions, ideas, or general discussion
|
||||
@@ -0,0 +1,40 @@
|
||||
name: Documentation Issue
|
||||
description: Report missing, incorrect, or unclear documentation
|
||||
title: "docs: "
|
||||
labels: ["documentation"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Documentation Issue
|
||||
description: What's wrong or missing in the documentation?
|
||||
placeholder: |
|
||||
The ToolRenderConfig documentation mentions BorderColor and Background fields,
|
||||
but the code doesn't actually use them. The docs should either be updated
|
||||
to reflect reality, or the bug should be fixed.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: location
|
||||
attributes:
|
||||
label: Documentation Location
|
||||
description: Where is the affected documentation?
|
||||
placeholder: e.g., README.md, examples/extensions/tool-renderer-demo.go, pkg/kit docs
|
||||
|
||||
- type: textarea
|
||||
id: suggestion
|
||||
attributes:
|
||||
label: Suggested Improvement
|
||||
description: How should the documentation be improved?
|
||||
placeholder: |
|
||||
Add a note that BorderColor and Background are not yet implemented,
|
||||
or fix the bug and document the correct behavior.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've checked that this documentation issue still exists in the latest version
|
||||
required: true
|
||||
@@ -0,0 +1,64 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature or enhancement for Kit
|
||||
title: "feat: "
|
||||
labels: ["enhancement"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Feature Description
|
||||
description: What would you like to see added or changed?
|
||||
placeholder: |
|
||||
I'd like to be able to customize the border color of tool result blocks
|
||||
dynamically based on the tool type or result status.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation / Use Case
|
||||
description: Why is this feature needed? What problem does it solve?
|
||||
placeholder: |
|
||||
When running multiple tools in sequence, it's hard to visually distinguish
|
||||
between file reads (blue), shell commands (green), and errors (red)
|
||||
without custom border colors.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposed
|
||||
attributes:
|
||||
label: Proposed Implementation
|
||||
description: How do you think this should work? (optional)
|
||||
placeholder: |
|
||||
Extend `ToolRenderConfig` to accept a function that receives the tool
|
||||
result and returns a color based on the content:
|
||||
|
||||
```go
|
||||
BorderColorFunc: func(result string, isError bool) string {
|
||||
if isError {
|
||||
return "#f38ba8"
|
||||
}
|
||||
return "#89b4fa"
|
||||
}
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: checkboxes
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
options:
|
||||
- label: I've considered workarounds or alternative approaches
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been requested yet
|
||||
required: true
|
||||
- label: This feature aligns with Kit's design philosophy (TUI-first, extension-based)
|
||||
required: false
|
||||
@@ -3,6 +3,7 @@
|
||||
.env
|
||||
.kit/*
|
||||
!.kit/extensions/
|
||||
!.kit/prompts/
|
||||
aidocs/
|
||||
*.log
|
||||
/kit
|
||||
|
||||
@@ -28,11 +28,15 @@ type lintResult struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint on Go file edits")
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
@@ -43,30 +47,72 @@ func Init(api ext.API) {
|
||||
return nil
|
||||
}
|
||||
|
||||
report := runGoDiagnostics(ctx.CWD, absPath)
|
||||
|
||||
// Check if there are issues and add explicit prompt for the LLM to react
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
var enhanced string
|
||||
if hasIssues {
|
||||
enhanced = e.Content + "\n\n" + report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review the issues above and fix them before proceeding."
|
||||
} else {
|
||||
enhanced = e.Content + "\n\n" + report
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Show TUI message block for diagnostics visibility (only if there are issues)
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("File: %s", filepath.Base(absPath)))
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
msgLines = append(msgLines, "", "⚠️ Please fix these issues before proceeding.")
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
@@ -78,9 +124,16 @@ func Init(api ext.API) {
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
|
||||
return &ext.ToolResultResult{Content: &enhanced}
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -106,18 +159,6 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGoDiagnostics(cwd, absPath string) string {
|
||||
gopls := runGopls(cwd, absPath)
|
||||
lint := runGolangCILint(cwd, "./...")
|
||||
|
||||
return fmt.Sprintf(
|
||||
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
filepath.Base(absPath),
|
||||
formatToolResult(gopls, "No diagnostics."),
|
||||
formatToolResult(lint, "No lint issues."),
|
||||
)
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
@@ -178,7 +219,9 @@ func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
lines = append(lines, emptyFallback)
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
@@ -197,17 +240,15 @@ func truncate(s string, max int) string {
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
// Extract gopls section
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
// Count non-empty lines excluding the header and "No diagnostics." message
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." {
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
@@ -215,7 +256,6 @@ func countIssues(report string) (goplsCount, lintCount int) {
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
// Count non-empty lines excluding the header and "No lint issues." message
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
// - The extension runner serializes handler calls per-extension, so
|
||||
// concurrent subagent events cannot race on this shared state.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -43,7 +45,8 @@ const (
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// Package-level state — safe because the runner serializes all handler
|
||||
// invocations for the same extension (per-extension reentrant mutex).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
@@ -282,8 +285,8 @@ func Init(api ext.API) {
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
// Remove the entry — build a new slice to avoid aliasing bugs
|
||||
newEntries := make([]*submonEntry, 0, len(submonEntries))
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
description: Run ACP smoke test against opencode/kimi-k2.5 to verify JSON-RPC stdio works
|
||||
---
|
||||
|
||||
Run the ACP smoke test to verify the Kit ACP server works correctly over JSON-RPC stdio with streaming responses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. Build the kit binary:
|
||||
```bash
|
||||
go build -o output/kit ./cmd/kit
|
||||
```
|
||||
|
||||
2. Run the smoke test Python script against opencode/kimi-k2.5:
|
||||
```bash
|
||||
python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
3. Verify the output shows:
|
||||
- `session/new` returns a valid `sessionId`
|
||||
- `session/prompt` streams `agent_thought_chunk` notifications (reasoning)
|
||||
- `session/prompt` streams `agent_message_chunk` notifications (response)
|
||||
- Final result has `stopReason: "end_turn"`
|
||||
- `✓ SMOKE TEST PASSED` at the end
|
||||
|
||||
4. If the test fails, check:
|
||||
- `output/kit` binary exists and is executable
|
||||
- `OPENCODE_API_KEY` or `OPENCODE_ZEN_API_KEY` environment variable is set
|
||||
- `scripts/acp_smoke_test.py` exists
|
||||
- The model `opencode/kimi-k2.5` is available (`kit models opencode | grep kimi-k2.5`)
|
||||
|
||||
5. For testing with a different model, edit the script or set the `MODEL` variable:
|
||||
```bash
|
||||
MODEL=anthropic/claude-sonnet-4-5 python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
The smoke test exercises the full ACP protocol: session lifecycle, streaming notifications, and tool-free prompt completion.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
description: Stage, commit, and push changes with an auto-generated conventional commit message
|
||||
---
|
||||
|
||||
Review the current git status and diff, then stage all changes, write a concise conventional commit message, commit, and push to the current branch.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check status**: `git status` — understand what has changed
|
||||
2. **Review the diff**: `git diff` (and `git diff --cached` if anything is already staged) — read the actual changes
|
||||
3. **Stage everything**: `git add -A`
|
||||
4. **Craft the commit message** following Conventional Commits:
|
||||
- Format: `<type>(<scope>): <short summary>`
|
||||
- Types: `feat`, `fix`, `refactor`, `chore`, `docs`, `test`, `perf`, `build`
|
||||
- Scope: optional, the subsystem affected (e.g. `ui`, `cmd`, `config`)
|
||||
- Summary: imperative mood, lowercase, no trailing period, ≤72 chars
|
||||
- Body: add a blank line then bullet points for non-trivial changes
|
||||
- Do **not** include "Generated by" or similar noise
|
||||
5. **Commit**: `git commit -m "<message>"`
|
||||
6. **Push**: `git push`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the actual diff — do not guess from filenames alone
|
||||
- Prefer one well-scoped commit; do not split unless the changes are clearly unrelated
|
||||
- Keep the subject line under 72 characters
|
||||
- Use the body to explain *what* and *why*, not *how*
|
||||
- If there is nothing to commit, say so and stop
|
||||
|
||||
$@
|
||||
@@ -0,0 +1,86 @@
|
||||
---
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
| Field | Required | Purpose |
|
||||
|-------|----------|---------|
|
||||
| **Feature Description** | Yes | What should be added or changed |
|
||||
| **Motivation / Use Case** | Yes | Why is this needed? What problem does it solve? |
|
||||
| **Proposed Implementation** | No | How do you think this should work? |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$+`
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
2. **Ask clarifying questions** if needed:
|
||||
- "What problem does this solve for you?"
|
||||
- "How would you expect this to work?"
|
||||
- "Are there similar features in other tools you use?"
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `feat: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Good examples:
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `feat: support custom themes per extension`
|
||||
- `feat: add fuzzy matching to model selector`
|
||||
- Bad examples:
|
||||
- `Feature request: can we have...` (too vague)
|
||||
- `It would be nice if...` (not imperative)
|
||||
|
||||
4. **Build the body** with the template fields:
|
||||
|
||||
**Feature Description:**
|
||||
- Clear statement of what to add/change
|
||||
- Be specific about the behavior
|
||||
- Include UI/UX details if relevant
|
||||
|
||||
**Motivation / Use Case:**
|
||||
- What problem does this solve?
|
||||
- Current workaround (if any) and why it's insufficient
|
||||
- Who benefits from this feature?
|
||||
|
||||
**Proposed Implementation** (optional but helpful):
|
||||
- High-level approach
|
||||
- API changes if applicable
|
||||
- Example usage code
|
||||
|
||||
5. **Create the issue**:
|
||||
```bash
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
```
|
||||
|
||||
6. **Confirm success**:
|
||||
- Show the issue URL and number
|
||||
- Mention it was created with the feature_request template
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Focus on the *problem* first, then the solution
|
||||
- Include concrete examples of how the feature would be used
|
||||
- Consider edge cases and mention them
|
||||
- If proposing API changes, show before/after code
|
||||
- Check if similar features exist in related tools (mention them for reference)
|
||||
- Align with Kit's philosophy: TUI-first, extension-based, keyboard-driven
|
||||
|
||||
## Example
|
||||
|
||||
User: `/feature-request I want to be able to customize tool border colors dynamically`
|
||||
|
||||
You:
|
||||
1. Title: `feat: dynamic border colors for tool results based on status`
|
||||
2. Body:
|
||||
- **Feature Description**: Allow `ToolRenderConfig` to accept a function that determines border color based on tool result content or status, enabling dynamic visual feedback.
|
||||
- **Motivation**: When running multiple tools, it's hard to distinguish file reads (blue), shell commands (green), and errors (red) without custom colors per result.
|
||||
- **Proposed Implementation**: Add `BorderColorFunc` callback that receives `(result string, isError bool)` and returns a color string.
|
||||
|
||||
3. Execute: `gh issue create --template feature_request --title "feat: ..." --body "..."`
|
||||
4. Confirm: Created issue #43 using feature_request template
|
||||
@@ -0,0 +1,100 @@
|
||||
---
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
This repository has structured issue templates. You MUST use the appropriate template:
|
||||
|
||||
| Type | Template | Use For |
|
||||
|------|----------|---------|
|
||||
| `bug` | `bug_report` | Something is broken, not working as expected |
|
||||
| `feat` | `feature_request` | New feature, enhancement, improvement |
|
||||
| `docs` | `documentation` | Missing, incorrect, or unclear documentation |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$+`:
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
2. **Ask clarifying questions** if critical info is missing:
|
||||
- For bugs: "What were you doing when this happened?" (reproduction steps)
|
||||
- For features: "What problem does this solve?" (motivation)
|
||||
- For docs: "Where did you look for this information?" (location)
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `<type>: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Examples:
|
||||
- `fix: ToolRenderConfig BorderColor ignored during rendering`
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `docs: clarify extension widget lifecycle`
|
||||
|
||||
4. **File the issue** using the template:
|
||||
```bash
|
||||
# For bugs
|
||||
gh issue create --template bug_report --title "fix: ..." --body "..."
|
||||
|
||||
# For features
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
|
||||
# For documentation
|
||||
gh issue create --template documentation --title "docs: ..." --body "..."
|
||||
```
|
||||
|
||||
The template will guide the user through the required fields. You need to provide:
|
||||
- **Bug reports**: Description, reproduction steps, expected vs actual behavior
|
||||
- **Feature requests**: Description, motivation/use case, optional proposed implementation
|
||||
- **Documentation**: Description, location of docs, suggested improvement
|
||||
|
||||
5. **Confirm success** by showing:
|
||||
- The issue URL
|
||||
- The issue number
|
||||
- Which template was used
|
||||
|
||||
## Template Field Guide
|
||||
|
||||
### Bug Report (`bug_report`)
|
||||
Required fields in the body:
|
||||
- **Bug Description** - what happened vs expected
|
||||
- **Steps to Reproduce** - numbered list to recreate the bug
|
||||
- **Relevant Code** - code snippets, configuration, error messages
|
||||
- **Component** - which part of Kit (ui, extensions, session, etc.)
|
||||
- **Version** - Kit version or commit hash
|
||||
|
||||
### Feature Request (`feature_request`)
|
||||
Required fields in the body:
|
||||
- **Feature Description** - what to add/change
|
||||
- **Motivation / Use Case** - why this is needed
|
||||
- **Proposed Implementation** - how it could work (optional)
|
||||
|
||||
### Documentation (`documentation`)
|
||||
Required fields in the body:
|
||||
- **Documentation Issue** - what's wrong or missing
|
||||
- **Documentation Location** - file or URL where docs exist
|
||||
- **Suggested Improvement** - how to fix the docs
|
||||
|
||||
## Guidelines
|
||||
|
||||
- ALWAYS use `--template <name>` instead of bare `gh issue create`
|
||||
- Include file paths and line numbers when you know them
|
||||
- Use triple backticks for code blocks
|
||||
- Keep the body factual - avoid speculation unless in "Proposed Fix" section
|
||||
- If you're unsure about technical details, say so in the issue
|
||||
- For UI bugs, describe what you see vs what you expect
|
||||
- For API bugs, include the relevant struct/function names
|
||||
|
||||
## Example Usage
|
||||
|
||||
User: `/file-issue The ToolRenderConfig BorderColor field is documented but never used in rendering`
|
||||
|
||||
You:
|
||||
1. Determine this is a **bug** (documented field doesn't work)
|
||||
2. Use `--template bug_report`
|
||||
3. Gather: reproduction steps (register renderer with BorderColor), expected (custom color), actual (default color)
|
||||
4. Create issue with title `fix: ToolRenderConfig BorderColor and Background fields are ignored`
|
||||
5. Confirm: Created issue #42 using bug_report template
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
A prompt template is a `.md` file in `.kit/prompts/` (project-local) or `~/.kit/prompts/` (global).
|
||||
It becomes a `/slug` slash command in the kit input box — typed as `/filename` with optional arguments.
|
||||
|
||||
## File format
|
||||
|
||||
```
|
||||
---
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Keep prompts action-oriented — they should tell kit *what to do*, not just *what to think about*
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
description: Semantic version tagging workflow - analyzes commits and tags releases
|
||||
---
|
||||
|
||||
# Release Tagging Workflow
|
||||
|
||||
Tag a new version of this Go project following semantic versioning.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch remote tags**: `git fetch --tags origin`
|
||||
|
||||
2. **Find latest version**: `git tag -l | sort -V | tail -5` to see recent tags
|
||||
|
||||
3. **Analyze changes since last tag**:
|
||||
- `git log <latest-tag>..HEAD --oneline` - list commits
|
||||
- `git diff <latest-tag>..HEAD --stat` - see file stats
|
||||
- `git diff <latest-tag>..HEAD --name-only` - see changed files
|
||||
|
||||
4. **Determine version bump** (Semantic Versioning):
|
||||
- **MAJOR (X.0.0)**: Breaking API changes, incompatible modifications
|
||||
- **MINOR (0.X.0)**: New features, backward-compatible additions
|
||||
- **PATCH (0.0.X)**: Bug fixes, backward-compatible fixes
|
||||
|
||||
Look for indicators:
|
||||
- `feat:` or `feature:` commits → MINOR
|
||||
- `fix:` or `bugfix:` commits → PATCH
|
||||
- `breaking:` or `BREAKING CHANGE:` → MAJOR
|
||||
- Breaking API changes in `pkg/` or public interfaces → MAJOR
|
||||
- New commands, flags, or features → MINOR
|
||||
- Documentation-only changes → PATCH (or skip)
|
||||
|
||||
5. **Calculate new version**: Increment appropriate segment, reset lower segments to 0
|
||||
|
||||
6. **Draft tag message**:
|
||||
- Summarize key changes from commits
|
||||
- Group by type (Features, Fixes, Breaking Changes)
|
||||
- Keep concise but informative
|
||||
|
||||
7. **Create annotated tag**: `git tag -a vX.Y.Z -m "vX.Y.Z - <summary>\n\n<detailed list>"`
|
||||
|
||||
8. **Push tag**: `git push origin vX.Y.Z`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Always fetch remote tags first to avoid conflicts
|
||||
- Use annotated tags (`-a`) with descriptive messages
|
||||
- Follow semver strictly - when in doubt, prefer conservative bump (patch over minor)
|
||||
- For Go projects, changes to `pkg/` or exported APIs warrant careful version consideration
|
||||
- If no changes since last tag, suggest skipping the release
|
||||
- Include commit summaries in the tag message body
|
||||
|
||||
## Example Tag Message Format
|
||||
|
||||
```
|
||||
v0.30.1 - Bug fixes for model handling and UI improvements
|
||||
|
||||
Fixes:
|
||||
- Properly handle think tags from Qwen/DeepSeek models
|
||||
- Handle custom provider model persistence and bare model names
|
||||
|
||||
Improvements:
|
||||
- UI style refactoring and cleanup
|
||||
```
|
||||
|
||||
Wait for the user to confirm the version and message before executing tag commands.
|
||||
|
||||
---
|
||||
|
||||
$@
|
||||
@@ -1,22 +1,3 @@
|
||||
<!-- OPENSPEC:START -->
|
||||
# OpenSpec Instructions
|
||||
|
||||
These instructions are for AI assistants working in this project.
|
||||
|
||||
Always open `@/openspec/AGENTS.md` when the request:
|
||||
- Mentions planning or proposals (words like proposal, spec, change, plan)
|
||||
- Introduces new capabilities, breaking changes, architecture shifts, or big performance/security work
|
||||
- Sounds ambiguous and you need the authoritative spec before coding
|
||||
|
||||
Use `@/openspec/AGENTS.md` to learn:
|
||||
- How to create and apply change proposals
|
||||
- Spec format and conventions
|
||||
- Project structure and guidelines
|
||||
|
||||
Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
|
||||
<!-- OPENSPEC:END -->
|
||||
|
||||
# KIT Agent Guidelines
|
||||
|
||||
## Build/Test Commands
|
||||
@@ -42,6 +23,33 @@ Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
- **Extension system** (`internal/extensions/`): Yaegi-interpreted Go, 13 lifecycle events, custom tools/commands/widgets/overlays/editor interceptors
|
||||
- **TUI** (`internal/ui/`): Bubble Tea v2 parent-child model (`AppModel` → `InputComponent`, `StreamComponent`, etc.)
|
||||
- **Decoupling pattern**: `cmd/root.go` has converter functions (e.g. `widgetProviderForUI()`) that bridge `internal/extensions/` types to `internal/ui/` types — the UI never imports extensions directly
|
||||
- **Public SDK** (`pkg/kit/`): The public-facing Go SDK for embedding Kit as a library. See rules below.
|
||||
|
||||
## Public SDK (`pkg/kit/`) Rules
|
||||
|
||||
`pkg/kit/` is the **public API surface** consumed by external Go developers. All exported symbols, types, function names, and godoc comments in this package are part of the SDK contract.
|
||||
|
||||
### No Dependency Name Leakage
|
||||
Internal dependency names (e.g. `charm.land/fantasy`, library-specific jargon) **must not** appear in:
|
||||
- **Exported function/method names** — use generic terms (`LLM`, `Provider`, `Message`) instead of library names
|
||||
- **Exported type names** — type aliases should use domain names (e.g. `LLMMessage`, not `FantasyMessage`)
|
||||
- **Godoc comments** on exported symbols — these are visible in `go doc` output and pkg.go.dev
|
||||
- **Struct field names and tags** on exported types
|
||||
|
||||
Using dependency types directly in **function bodies** (private implementation) is fine — that's invisible to SDK consumers.
|
||||
|
||||
### Naming Conventions for SDK Symbols
|
||||
- Type aliases re-exporting dependency types: use `LLM*` prefix (e.g. `LLMMessage`, `LLMUsage`, `LLMResponse`)
|
||||
- Conversion helpers: use `ConvertToLLM*` / `ConvertFromLLM*` (not the dependency name)
|
||||
- Provider queries: use `GetLLMProviders` (not `GetFantasyProviders`)
|
||||
- When wrapping internal methods, the `pkg/kit/` name should be dependency-agnostic even if the `internal/` method still uses the old name
|
||||
|
||||
### Deprecation Pattern
|
||||
When renaming a public SDK symbol, keep the old name as a deprecated wrapper for one release cycle:
|
||||
```go
|
||||
// Deprecated: Use NewName instead.
|
||||
func OldName() { return NewName() }
|
||||
```
|
||||
|
||||
## Key Patterns
|
||||
|
||||
@@ -92,3 +100,21 @@ Positional args are the prompt. `@file` args attach file content. Key flags: `--
|
||||
- Never guess or manually search the filesystem for external projects
|
||||
- Example: `btca ask -r https://github.com/user/repo -q "How does X work?"`
|
||||
- See `.agents/skills/btca-cli/SKILL.md` for full btca usage
|
||||
|
||||
## BTCA Configured Resources
|
||||
The following external repositories are configured in `btca.config.jsonc` for research:
|
||||
|
||||
- bubbletea
|
||||
- lipgloss
|
||||
- bubbles
|
||||
- glamour
|
||||
- fantasy
|
||||
- catwalk
|
||||
- crush
|
||||
- pi
|
||||
- iteratr
|
||||
- yaegi
|
||||
- acp-go-sdk
|
||||
- opencode
|
||||
- herald
|
||||
- herald-md
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Autoscroll Fix - Final Summary
|
||||
|
||||
## Root Cause
|
||||
|
||||
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
|
||||
|
||||
### The Problem
|
||||
|
||||
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
|
||||
2. The `Height()` method returns `0` when `cachedRender == ""`
|
||||
3. `GotoBottom()` was calling:
|
||||
```go
|
||||
itemHeight := item.Height() // Returns 0 for reasoning
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width) // Renders but doesn't cache (reasoning)
|
||||
itemHeight = item.Height() // Still returns 0!
|
||||
}
|
||||
```
|
||||
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
|
||||
|
||||
## The Solution
|
||||
|
||||
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
|
||||
|
||||
```go
|
||||
// OLD: item.Height() which checks cached render
|
||||
itemHeight := item.Height()
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width)
|
||||
itemHeight = item.Height() // Still might be 0!
|
||||
}
|
||||
|
||||
// NEW: Calculate from rendered string directly
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
```
|
||||
|
||||
This works for **all** items regardless of whether they cache their render or not.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### `internal/ui/scrolllist.go`
|
||||
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
|
||||
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
|
||||
|
||||
### `internal/ui/model.go`
|
||||
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
|
||||
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
|
||||
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
|
||||
|
||||
## Testing Results
|
||||
|
||||
✅ **Test prompt**: "explore this repo"
|
||||
|
||||
**Before fix**:
|
||||
- Autoscroll stopped after reasoning block completed
|
||||
- Viewport stuck showing end of reasoning ("Thought for 203ms")
|
||||
- Assistant response streamed off-screen below
|
||||
|
||||
**After fix**:
|
||||
- Autoscroll works throughout reasoning block
|
||||
- Autoscroll continues during reasoning → assistant transition
|
||||
- Viewport stays at bottom showing latest assistant content
|
||||
- Final position shows end of response (build commands section)
|
||||
|
||||
## Behavior Verified
|
||||
|
||||
1. ✅ Streaming text auto-scrolls to bottom
|
||||
2. ✅ Works across reasoning → assistant transition
|
||||
3. ✅ Manual scroll up (PgUp) disables autoscroll
|
||||
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
|
||||
5. ✅ Accurate positioning with no offset errors
|
||||
|
||||
## Performance Note
|
||||
|
||||
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
|
||||
- `Render()` is already optimized with caching for non-reasoning items
|
||||
- `GotoBottom()` is only called during content updates (not every frame)
|
||||
- Reasoning blocks need to render anyway for live duration updates
|
||||
- This matches iteratr's approach of ensuring items are rendered before height calculations
|
||||
@@ -18,7 +18,8 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash (with interactive sudo password prompt), read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **Smart @ Attachments**: Binary files auto-detected via MIME type, MCP resources via `@mcp:server:uri`
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -125,8 +126,13 @@ model: anthropic/claude-sonnet-latest
|
||||
max-tokens: 4096
|
||||
temperature: 0.7
|
||||
stream: true
|
||||
thinking-level: off # off, minimal, low, medium, high
|
||||
```
|
||||
|
||||
All of the above keys can also be set programmatically via the SDK
|
||||
(`kit.Options.MaxTokens`, `Options.Temperature`, `Options.ThinkingLevel`, etc.)
|
||||
without touching config files — see [SDK options](#with-options).
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
@@ -186,11 +192,13 @@ mcpServers:
|
||||
--no-prompt-templates Disable prompt template loading
|
||||
|
||||
# Generation parameters
|
||||
--max-tokens Maximum tokens in response (default: 4096)
|
||||
--max-tokens Maximum tokens in response (default: 8192, auto-raised up to 32768 for models with larger known output limits)
|
||||
--temperature Randomness 0.0-1.0 (default: 0.7)
|
||||
--top-p Nucleus sampling 0.0-1.0 (default: 0.95)
|
||||
--top-k Limit top K tokens (default: 40)
|
||||
--stop-sequences Custom stop sequences (comma-separated)
|
||||
--frequency-penalty Penalize frequent tokens 0.0-2.0 (default: 0.0)
|
||||
--presence-penalty Penalize present tokens 0.0-2.0 (default: 0.0)
|
||||
--thinking-level Extended thinking level: off, minimal, low, medium, high (default: off)
|
||||
|
||||
# System
|
||||
@@ -209,7 +217,7 @@ kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
kit update-models [source] # Update model database (from models.dev, URL, file, or 'embedded')
|
||||
|
||||
# Extension management
|
||||
@@ -307,41 +315,50 @@ kit -e examples/extensions/minimal.go
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
**Bridged SDK APIs** (NEW): Extensions can now access internal SDK capabilities:
|
||||
- **Tree Navigation**: Navigate conversation history (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), summarize branches (`SummarizeBranch`), and implement fresh context loops (`CollapseBranch`)
|
||||
- **Skill Loading**: Dynamically load and inject skills at runtime (`LoadSkill`, `DiscoverSkills`, `InjectSkillAsContext`)
|
||||
- **Template Parsing**: Parse and render templates with `{{variables}}` (`ParseTemplate`, `RenderTemplate`), parse CLI-style arguments (`ParseArguments`, `SimpleParseArguments`), and evaluate model conditionals (`EvaluateModelConditional`, `RenderWithModelConditionals`)
|
||||
- **Model Resolution**: Resolve model fallback chains (`ResolveModelChain`), query model capabilities (`GetModelCapabilities`, `CheckModelAvailable`), and extract provider/model ID (`GetCurrentProvider`, `GetCurrentModelID`)
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
|
||||
- `minimal.go` - Clean UI with custom footer
|
||||
- `auto-commit.go` - Auto-commit on shutdown
|
||||
- `bookmark.go` - Bookmark conversations
|
||||
- `branded-output.go` - Branded output rendering
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
- `inline-bash.go` - Inline bash execution
|
||||
- `interactive-shell.go` - Interactive shell integration
|
||||
- `kit-kit.go` - Kit-in-Kit (sub-agent spawning)
|
||||
- `lsp-diagnostics.go` - LSP diagnostic integration
|
||||
- `notify.go` - Desktop notifications
|
||||
- `overlay-demo.go` - Modal dialogs
|
||||
- `permission-gate.go` - Permission gating for tools
|
||||
- `pirate.go` - Pirate-themed personality
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
- `tool-renderer-demo.go` - Custom tool call rendering
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
- [`minimal.go`](examples/extensions/minimal.go) - Clean UI with custom footer
|
||||
- [`auto-commit.go`](examples/extensions/auto-commit.go) - Auto-commit on shutdown
|
||||
- [`bookmark.go`](examples/extensions/bookmark.go) - Bookmark conversations
|
||||
- [`branded-output.go`](examples/extensions/branded-output.go) - Branded output rendering
|
||||
- [`bridge-demo.go`](examples/extensions/bridge_demo.go) - Bridged SDK API demo (tree navigation, skills, templates, model resolution)
|
||||
- [`compact-notify.go`](examples/extensions/compact-notify.go) - Notification on compaction
|
||||
- [`confirm-destructive.go`](examples/extensions/confirm-destructive.go) - Confirm destructive operations
|
||||
- [`context-inject.go`](examples/extensions/context-inject.go) - Inject context into conversations
|
||||
- [`conversation-manager.go`](examples/extensions/conversation-manager.go) - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- [`custom-editor-demo.go`](examples/extensions/custom-editor-demo.go) - Vim-like modal editor
|
||||
- [`dev-reload.go`](examples/extensions/dev-reload.go) - Development live-reload
|
||||
- [`header-footer-demo.go`](examples/extensions/header-footer-demo.go) - Custom headers and footers
|
||||
- [`inline-bash.go`](examples/extensions/inline-bash.go) - Inline bash execution
|
||||
- [`interactive-shell.go`](examples/extensions/interactive-shell.go) - Interactive shell integration
|
||||
- [`kit-kit.go`](examples/extensions/kit-kit.go) - Kit-in-Kit (sub-agent spawning)
|
||||
- [`lsp-diagnostics.go`](examples/extensions/lsp-diagnostics.go) - LSP diagnostic integration
|
||||
- [`notify.go`](examples/extensions/notify.go) - Desktop notifications
|
||||
- [`overlay-demo.go`](examples/extensions/overlay-demo.go) - Modal dialogs
|
||||
- [`permission-gate.go`](examples/extensions/permission-gate.go) - Permission gating for tools
|
||||
- [`pirate.go`](examples/extensions/pirate.go) - Pirate-themed personality
|
||||
- [`plan-mode.go`](examples/extensions/plan-mode.go) - Read-only planning mode
|
||||
- [`project-rules.go`](examples/extensions/project-rules.go) - Project-specific rules
|
||||
- [`prompt-demo.go`](examples/extensions/prompt-demo.go) - Interactive prompts (select/confirm/input)
|
||||
- [`prompt-templates.go`](examples/extensions/prompt-templates.go) - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- [`protected-paths.go`](examples/extensions/protected-paths.go) - Path protection for sensitive files
|
||||
- [`subagent-widget.go`](examples/extensions/subagent-widget.go) - Multi-agent orchestration with status widget
|
||||
- [`subagent-test.go`](examples/extensions/subagent-test.go) - Subagent testing utilities
|
||||
- [`summarize.go`](examples/extensions/summarize.go) - Conversation summarization
|
||||
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
|
||||
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
|
||||
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
|
||||
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
|
||||
|
||||
Also see `.kit/extensions/go-edit-lint.go` (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
@@ -398,7 +415,7 @@ func TestMyExtension(t *testing.T) {
|
||||
- `AssertPrinted()`, `AssertPrintedContains()` — Verify output
|
||||
- `AssertToolRegistered()`, `AssertCommandRegistered()` — Verify registration
|
||||
|
||||
See `examples/extensions/tool-logger_test.go` for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
See [`examples/extensions/tool-logger_test.go`](examples/extensions/tool-logger_test.go) for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
|
||||
### Prompt Templates
|
||||
|
||||
@@ -420,10 +437,13 @@ Focus on $1 specifically.
|
||||
|
||||
**Argument placeholders:**
|
||||
- `$1`, `$2`, etc. — Individual arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments (zero or more)
|
||||
- `$+` — All arguments (one or more required; error if none given)
|
||||
- `${@:2}` — Arguments from position 2 onwards
|
||||
- `${@:1:3}` — 3 arguments starting at position 1
|
||||
|
||||
Placeholders inside fenced code blocks (```) and inline code spans are ignored.
|
||||
|
||||
Disable templates with `--no-prompt-templates` or load a specific template with `--prompt-template <name>`.
|
||||
|
||||
## Session Management
|
||||
@@ -469,9 +489,18 @@ During an interactive session, use these slash commands:
|
||||
| `/import <path>` | Import and switch to a session from a JSONL file |
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/fork` | Fork to new session from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
|
||||
### Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Description |
|
||||
|----------|-------------|
|
||||
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
|
||||
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
|
||||
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
|
||||
| `↑` / `↓` | Navigate prompt history |
|
||||
|
||||
## Go SDK
|
||||
|
||||
Embed Kit in your Go applications:
|
||||
@@ -494,7 +523,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
@@ -517,13 +546,32 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
Streaming: true,
|
||||
Quiet: true,
|
||||
|
||||
// Generation parameters (override env/config/per-model defaults)
|
||||
MaxTokens: 16384, // 0 = auto-resolve (env → config → per-model → 8192 floor)
|
||||
ThinkingLevel: "medium", // "off", "low", "medium", "high"
|
||||
Temperature: ptr(float32(0.2)), // pointer so 0.0 != unset; nil = provider default
|
||||
TopP: nil, // nil = leave provider/per-model default
|
||||
TopK: nil,
|
||||
FrequencyPenalty: nil,
|
||||
PresencePenalty: nil,
|
||||
|
||||
// Provider configuration (override env/config without reaching into viper)
|
||||
ProviderAPIKey: "sk-...", // "" = use config / provider env var
|
||||
ProviderURL: "https://proxy.internal/v1", // "" = provider default
|
||||
TLSSkipVerify: false, // only takes effect when true
|
||||
|
||||
// Session options
|
||||
SessionPath: "./session.jsonl", // Open specific session
|
||||
Continue: true, // Resume most recent session
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
@@ -532,26 +580,91 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
**Generation & provider fields** (added in v0.55+) let SDK consumers configure
|
||||
Kit entirely in-code without `viper.Set()` workarounds or shipping a `.kit.yml`.
|
||||
Precedence is `Options` > `KIT_*` env vars > `.kit.yml` > per-model defaults
|
||||
(`modelSettings` / `customModels`) > provider-level defaults. Sampling params
|
||||
are pointer types so explicit `0.0` is distinguishable from "leave alone"; a
|
||||
non-zero `MaxTokens` suppresses automatic right-sizing the same way `--max-tokens`
|
||||
does on the CLI.
|
||||
|
||||
### MCP OAuth (remote MCP servers)
|
||||
|
||||
When a remote MCP server returns 401, Kit runs the full OAuth flow (dynamic
|
||||
client registration → PKCE → token exchange → persistence) but delegates the
|
||||
user-facing step — showing the authorization URL and receiving the callback —
|
||||
to an `MCPAuthHandler` that you pass explicitly via `Options.MCPAuthHandler`.
|
||||
If nil, OAuth is disabled and the authorization-required error surfaces to the
|
||||
caller; the SDK never auto-opens a browser or binds a localhost port.
|
||||
|
||||
```go
|
||||
// CLI/TUI apps: opens the system browser + prints status to stderr.
|
||||
authHandler, _ := kit.NewCLIMCPAuthHandler()
|
||||
defer authHandler.Close()
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPAuthHandler: authHandler,
|
||||
})
|
||||
|
||||
// Custom UX: reuse the SDK's port + callback server, supply your own
|
||||
// presentation via OnAuthURL (TUI modal, QR code, web redirect, etc.).
|
||||
// h, _ := kit.NewDefaultMCPAuthHandler()
|
||||
// h.OnAuthURL = func(server, authURL string) { myUI.Show(server, authURL) }
|
||||
//
|
||||
// Full control (web apps, daemons): implement kit.MCPAuthHandler yourself —
|
||||
// no localhost binding, no side effects.
|
||||
```
|
||||
|
||||
Tokens are persisted to `$XDG_CONFIG_HOME/.kit/mcp_tokens.json` by default; swap
|
||||
in a custom `MCPTokenStoreFactory` for encrypted, DB-backed, or in-memory
|
||||
storage. See the [SDK options docs](/sdk/options#mcp-oauth-authorization) for
|
||||
the full matrix.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" description:"Search query"`
|
||||
}
|
||||
|
||||
searchTool := kit.NewTool("search", "Search the codebase",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("Found: ..."), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
println("Calling tool:", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
println("Tool failed:", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Tool call started
|
||||
println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Tool call completed
|
||||
if isError {
|
||||
println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Streaming text chunk
|
||||
print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -715,7 +828,7 @@ Use `custom/custom` when pointing Kit at any OpenAI-compatible endpoint with `--
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through fantasy's `openaicompat` provider and supports:
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through the `openaicompat` provider and supports:
|
||||
|
||||
- Zero cost tracking (input/output = 0)
|
||||
- 262K context window, 65K output limit
|
||||
|
||||
@@ -76,6 +76,18 @@
|
||||
"name": "opencode",
|
||||
"url": "https://github.com/anomalyco/opencode",
|
||||
"branch": "dev"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald",
|
||||
"url": "https://github.com/indaco/herald",
|
||||
"branch": "main"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald-md",
|
||||
"url": "https://github.com/indaco/herald-md",
|
||||
"branch": "main"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
"github.com/mark3labs/kit/internal/acpserver"
|
||||
@@ -54,6 +55,8 @@ func runACP(cmd *cobra.Command, _ []string) error {
|
||||
conn.SetLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
// Also set charmbracelet/log level for acpserver package logging
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
// Wait for either the client to disconnect or a signal.
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ func printAllProviders(showAll bool) error {
|
||||
if showAll {
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
} else {
|
||||
providerIDs = kit.GetFantasyProviders()
|
||||
providerIDs = kit.GetLLMProviders()
|
||||
}
|
||||
sort.Strings(providerIDs)
|
||||
|
||||
|
||||
+967
-217
File diff suppressed because it is too large
Load Diff
@@ -41,7 +41,6 @@ func BuildAppOptions(mcpConfig *config.Config, modelName string, serverNames, to
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
Quiet: quietFlag,
|
||||
Debug: viper.GetBool("debug"),
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +130,6 @@ func SetupCLIForNonInteractive(k *kit.Kit) (*ui.CLI, error) {
|
||||
Agent: agentAdapter,
|
||||
ModelString: viper.GetString("model"),
|
||||
Debug: viper.GetBool("debug"),
|
||||
Compact: viper.GetBool("compact"),
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestAllExtensions_Load is a smoke test that verifies every single-file
|
||||
// example extension in this directory can be loaded by the Yaegi interpreter
|
||||
// without errors. This catches syntax errors, missing symbols, bad imports,
|
||||
// and Init signature mismatches.
|
||||
func TestAllExtensions_Load(t *testing.T) {
|
||||
files := extensionFiles(t)
|
||||
|
||||
for _, file := range files {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
ext := harness.LoadFile(file)
|
||||
if ext == nil {
|
||||
t.Fatalf("%s: extension should not be nil after loading", file)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Logf("successfully loaded %d extensions", len(files))
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// extensionFiles returns all single-file extensions in the current directory.
|
||||
// It skips test files, the test template, and files without an Init function.
|
||||
func extensionFiles(t *testing.T) []string {
|
||||
t.Helper()
|
||||
|
||||
skip := map[string]bool{
|
||||
"extension_test_template.go": true,
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(".")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read directory: %v", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() || filepath.Ext(name) != ".go" {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") || skip[name] {
|
||||
continue
|
||||
}
|
||||
src, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", name, err)
|
||||
}
|
||||
if !strings.Contains(string(src), "func Init(") {
|
||||
continue
|
||||
}
|
||||
files = append(files, name)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
t.Fatal("no extensions found — check the directory")
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
// TestAllExtensions_Lifecycle verifies that every extension survives a full
|
||||
// SessionStart → SessionShutdown round-trip without errors.
|
||||
func TestAllExtensions_Lifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "smoke-test-session",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart error: %v", err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_CommandSanity checks that every registered command has
|
||||
// a non-empty name, a non-empty description, no spaces in the name, no
|
||||
// leading slash, a non-nil Execute function, and no duplicate names.
|
||||
func TestAllExtensions_CommandSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
cmds := harness.RegisteredCommands()
|
||||
seen := make(map[string]bool)
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == "" {
|
||||
t.Error("command has empty name")
|
||||
}
|
||||
if strings.Contains(cmd.Name, " ") {
|
||||
t.Errorf("command %q contains spaces", cmd.Name)
|
||||
}
|
||||
if strings.HasPrefix(cmd.Name, "/") {
|
||||
t.Errorf("command %q has leading slash (framework adds it)", cmd.Name)
|
||||
}
|
||||
if cmd.Description == "" {
|
||||
t.Errorf("command %q has empty description", cmd.Name)
|
||||
}
|
||||
if cmd.Execute == nil {
|
||||
t.Errorf("command %q has nil Execute function", cmd.Name)
|
||||
}
|
||||
if seen[cmd.Name] {
|
||||
t.Errorf("duplicate command name %q", cmd.Name)
|
||||
}
|
||||
seen[cmd.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ToolSanity checks that every registered tool has a
|
||||
// non-empty name, a non-empty description, at least one executor, valid
|
||||
// JSON in its Parameters field, and no duplicate names.
|
||||
func TestAllExtensions_ToolSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
tools := harness.RegisteredTools()
|
||||
seen := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "" {
|
||||
t.Error("tool has empty name")
|
||||
}
|
||||
if tool.Description == "" {
|
||||
t.Errorf("tool %q has empty description", tool.Name)
|
||||
}
|
||||
if tool.Execute == nil && tool.ExecuteWithContext == nil {
|
||||
t.Errorf("tool %q has no executor (both Execute and ExecuteWithContext are nil)", tool.Name)
|
||||
}
|
||||
if tool.Parameters != "" && !json.Valid([]byte(tool.Parameters)) {
|
||||
t.Errorf("tool %q has invalid JSON in Parameters: %s", tool.Name, tool.Parameters)
|
||||
}
|
||||
if seen[tool.Name] {
|
||||
t.Errorf("duplicate tool name %q", tool.Name)
|
||||
}
|
||||
seen[tool.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ZeroValueEvents fires every event type (as zero-value
|
||||
// structs) at each extension and verifies no errors are returned. Extensions
|
||||
// should be resilient to events they don't handle and to events with empty
|
||||
// fields.
|
||||
func TestAllExtensions_ZeroValueEvents(t *testing.T) {
|
||||
// Build the set of zero-value events for every event type.
|
||||
zeroEvents := []extensions.Event{
|
||||
extensions.ToolCallEvent{},
|
||||
extensions.ToolExecutionStartEvent{},
|
||||
extensions.ToolExecutionEndEvent{},
|
||||
extensions.ToolOutputEvent{},
|
||||
extensions.ToolResultEvent{},
|
||||
extensions.InputEvent{},
|
||||
extensions.BeforeAgentStartEvent{},
|
||||
extensions.AgentStartEvent{},
|
||||
extensions.AgentEndEvent{},
|
||||
extensions.MessageStartEvent{},
|
||||
extensions.MessageUpdateEvent{},
|
||||
extensions.MessageEndEvent{},
|
||||
extensions.SessionStartEvent{},
|
||||
extensions.SessionShutdownEvent{},
|
||||
extensions.ModelChangeEvent{},
|
||||
extensions.ContextPrepareEvent{},
|
||||
extensions.BeforeForkEvent{},
|
||||
extensions.BeforeSessionSwitchEvent{},
|
||||
extensions.BeforeCompactEvent{},
|
||||
extensions.SubagentStartEvent{},
|
||||
extensions.SubagentChunkEvent{},
|
||||
extensions.SubagentEndEvent{},
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for _, ev := range zeroEvents {
|
||||
_, err := harness.Emit(ev)
|
||||
if err != nil {
|
||||
t.Errorf("event %T returned error: %v", ev, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_WidgetSanity emits SessionStart and then checks that
|
||||
// any widgets set during initialization have non-empty IDs and valid
|
||||
// placements.
|
||||
func TestAllExtensions_WidgetSanity(t *testing.T) {
|
||||
validPlacements := map[extensions.WidgetPlacement]bool{
|
||||
"above": true,
|
||||
"below": true,
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
// Trigger SessionStart so extensions that set widgets on init do so.
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "widget-sanity-test",
|
||||
})
|
||||
|
||||
// Widgets is an exported field on MockContext; reads are safe
|
||||
// here because Emit returned synchronously.
|
||||
for id, w := range harness.Context().Widgets {
|
||||
if w.ID == "" {
|
||||
t.Errorf("widget stored with key %q has empty ID", id)
|
||||
}
|
||||
if w.ID != id {
|
||||
t.Errorf("widget key %q doesn't match widget ID %q", id, w.ID)
|
||||
}
|
||||
if !validPlacements[w.Placement] {
|
||||
t.Errorf("widget %q has invalid placement %q (want \"above\" or \"below\")", id, w.Placement)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_IdempotentLifecycle verifies that receiving SessionStart
|
||||
// twice and SessionShutdown twice doesn't cause errors — extensions should
|
||||
// be defensive about repeated lifecycle events.
|
||||
func TestAllExtensions_IdempotentLifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "idempotent-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build ignore
|
||||
|
||||
// bridge_demo.go - Demonstrates the new bridged SDK APIs for extensions.
|
||||
// This extension showcases tree navigation, skill loading, template parsing,
|
||||
// and model resolution capabilities.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
discoveredSkills []ext.Skill
|
||||
currentBranch []ext.TreeNode
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Register /tree-info command to demonstrate tree navigation
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree-info",
|
||||
Description: "Show current conversation tree information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
info := fmt.Sprintf("Current branch has %d nodes:\n", len(branch))
|
||||
for i, node := range branch {
|
||||
info += fmt.Sprintf(" [%d] %s (%s): %s...\n", i, node.Type, node.ID[:8], truncate(node.Content, 40))
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /discover-skills command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "discover-skills",
|
||||
Description: "Discover and list available skills",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("discovery failed: %s", result.Error)
|
||||
}
|
||||
discoveredSkills = result.Skills
|
||||
|
||||
info := fmt.Sprintf("Discovered %d skills:\n", len(result.Skills))
|
||||
for _, s := range result.Skills {
|
||||
info += fmt.Sprintf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /parse-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "parse-template",
|
||||
Description: "Parse a template and show extracted variables",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "Hello {{name}}, welcome to {{place}}!"
|
||||
}
|
||||
tpl := ctx.ParseTemplate("demo", args)
|
||||
info := fmt.Sprintf("Template: %s\nVariables: %v", tpl.Content, tpl.Variables)
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /render-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "render-template",
|
||||
Description: "Render a template with variables (usage: /render-template name=John place=Kit)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
tpl := ctx.ParseTemplate("demo", "Hello {{name}}, welcome to {{place}}!")
|
||||
vars := ctx.ParseArguments(args, ext.ArgumentPattern{
|
||||
Flags: map[string]string{"name": "name", "place": "place"},
|
||||
})
|
||||
rendered := ctx.RenderTemplate(tpl, vars.Vars)
|
||||
ctx.PrintInfo("Rendered: " + rendered)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /check-model command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "check-model",
|
||||
Description: "Check model capabilities and availability",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
model := args
|
||||
if model == "" {
|
||||
model = ctx.Model
|
||||
}
|
||||
|
||||
available := ctx.CheckModelAvailable(model)
|
||||
caps, err := ctx.GetModelCapabilities(model)
|
||||
|
||||
info := fmt.Sprintf("Model: %s\n", model)
|
||||
info += fmt.Sprintf("Available: %v\n", available)
|
||||
if err == "" {
|
||||
info += fmt.Sprintf("Provider: %s\n", caps.Provider)
|
||||
info += fmt.Sprintf("Context Limit: %d\n", caps.ContextLimit)
|
||||
info += fmt.Sprintf("Reasoning: %v\n", caps.Reasoning)
|
||||
} else {
|
||||
info += fmt.Sprintf("Error: %s\n", err)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /resolve-chain command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "resolve-chain",
|
||||
Description: "Resolve a model chain (usage: /resolve-chain claude-opus,gpt-4o,claude-sonnet)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "anthropic/claude-opus-4,anthropic/claude-sonnet-4,openai/gpt-4o"
|
||||
}
|
||||
prefs := ctx.SimpleParseArguments(args, 1)
|
||||
chain := []string{}
|
||||
if len(prefs) > 1 {
|
||||
// Split the first arg by comma
|
||||
for _, p := range strings.Split(prefs[1], ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
chain = append(chain, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(chain)
|
||||
info, _ := json.MarshalIndent(result, "", " ")
|
||||
ctx.PrintInfo("Resolution Result:\n" + string(info))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /test-conditional command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "test-conditional",
|
||||
Description: "Test model conditional rendering",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
content := `<if-model is="claude-*">This is for Claude models<else>This is for other models</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content)
|
||||
ctx.PrintInfo("Input: " + content)
|
||||
ctx.PrintInfo("Output: " + rendered)
|
||||
ctx.PrintInfo(fmt.Sprintf("Current model matches 'claude-*': %v", ctx.EvaluateModelConditional("claude-*")))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnSessionStart: discover skills automatically
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error == "" && len(result.Skills) > 0 {
|
||||
discoveredSkills = result.Skills
|
||||
ctx.SetStatus("bridge-demo", fmt.Sprintf("%d skills", len(result.Skills)), 50)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
//go:build ignore
|
||||
|
||||
// conversation-manager.go - Advanced conversation tree navigation and management.
|
||||
// This extension demonstrates:
|
||||
// - Tree navigation (GetTreeNode, GetCurrentBranch, NavigateTo)
|
||||
// - Branch summarization and collapsing
|
||||
// - Interactive tree exploration
|
||||
//
|
||||
// Commands:
|
||||
// /tree - Show conversation tree structure
|
||||
// /branch - Show current branch path
|
||||
// /goto <entry-id> - Navigate to a specific entry
|
||||
// /summarize <n> - Summarize last N messages
|
||||
// /fresh-context - Collapse branch and start fresh
|
||||
// /loop <n> <prompt> - Execute prompt N times with fresh context each iteration
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
loopActive bool
|
||||
loopCount int
|
||||
loopCurrent int
|
||||
loopPrompt string
|
||||
loopStartNode string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// /tree - Show tree structure
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree",
|
||||
Description: "Show conversation tree structure",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showTree(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /branch - Show current branch
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "branch",
|
||||
Description: "Show current conversation branch",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showBranch(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /goto - Navigate to entry
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "goto",
|
||||
Description: "Navigate to a specific entry ID (usage: /goto <entry-id>)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
ctx.PrintError("Usage: /goto <entry-id>")
|
||||
return "", nil
|
||||
}
|
||||
result := ctx.NavigateTo(args)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Navigation failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
ctx.PrintInfo(fmt.Sprintf("Navigated to entry: %s", args))
|
||||
|
||||
// Show the node we navigated to
|
||||
node := ctx.GetTreeNode(args)
|
||||
if node != nil {
|
||||
ctx.PrintInfo(fmt.Sprintf("Entry type: %s, Role: %s", node.Type, node.Role))
|
||||
}
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /summarize - Summarize recent messages
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "summarize",
|
||||
Description: "Summarize last N messages (usage: /summarize [n=5])",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
n := 5
|
||||
if args != "" {
|
||||
if parsed, err := strconv.Atoi(args); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 2 {
|
||||
ctx.PrintError("Not enough messages to summarize")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Find range to summarize
|
||||
startIdx := len(branch) - n - 1
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx := len(branch) - 1
|
||||
|
||||
fromID := branch[startIdx].ID
|
||||
toID := branch[endIdx].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Summarizing messages %d to %d...", startIdx, endIdx))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Summary",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /fresh-context - Collapse and restart
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "fresh-context",
|
||||
Description: "Collapse conversation to summary and start fresh",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 3 {
|
||||
ctx.PrintError("Not enough context to collapse")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Keep first message (system), summarize rest
|
||||
fromID := branch[1].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo("Generating summary for context collapse...")
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Collapse the branch
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo("Context collapsed. Starting fresh with summary.")
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Collapsed Context",
|
||||
})
|
||||
|
||||
// Set a widget showing we're in fresh mode
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "fresh-context",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "🌱 Fresh Context Mode - Previous conversation collapsed"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#a6e3a1"},
|
||||
})
|
||||
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /loop - Execute with fresh context each iteration
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "loop",
|
||||
Description: "Execute prompt N times with fresh context (usage: /loop 5 analyze this code)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if loopActive {
|
||||
ctx.PrintError("Loop already in progress. Wait for completion.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
ctx.PrintError("Usage: /loop <count> <prompt>")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(parts[0])
|
||||
if err != nil || count <= 0 || count > 10 {
|
||||
ctx.PrintError("Invalid count (must be 1-10)")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
loopCount = count
|
||||
loopCurrent = 0
|
||||
loopPrompt = parts[1]
|
||||
loopActive = true
|
||||
|
||||
// Store current branch position
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 {
|
||||
loopStartNode = branch[len(branch)-1].ID
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Starting loop: %d iterations", loopCount))
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: 0/%d - %s", loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Start first iteration
|
||||
executeLoopIteration(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnAgentEnd handles loop continuation
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if !loopActive {
|
||||
return
|
||||
}
|
||||
|
||||
loopCurrent++
|
||||
|
||||
if loopCurrent >= loopCount {
|
||||
// Loop complete
|
||||
loopActive = false
|
||||
ctx.RemoveWidget("loop-progress")
|
||||
ctx.PrintInfo(fmt.Sprintf("✅ Loop complete: %d/%d iterations", loopCurrent, loopCount))
|
||||
|
||||
// Show final summary
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 && loopStartNode != "" {
|
||||
summary := ctx.SummarizeBranch(loopStartNode, branch[len(branch)-1].ID)
|
||||
if summary != "" {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Loop Summary",
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update progress
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: %d/%d - %s", loopCurrent, loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Collapse previous iteration for fresh context
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) >= 2 {
|
||||
// Find the user messages (look for the one before the last assistant message)
|
||||
// We want to collapse from the user message that started this iteration
|
||||
// to the last assistant response
|
||||
var collapseStartIdx = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if branch[i].Role == "assistant" {
|
||||
// Found the last assistant message, now find the user message before it
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if branch[j].Role == "user" {
|
||||
collapseStartIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if collapseStartIdx >= 0 {
|
||||
fromID := branch[collapseStartIdx].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Collapsing iteration %d for fresh context...", loopCurrent))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
if summary != "" {
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if result.Success {
|
||||
ctx.PrintInfo("Context collapsed successfully")
|
||||
} else {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay to let UI update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger next iteration
|
||||
executeLoopIteration(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// showTree displays the conversation tree structure
|
||||
func showTree(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("Tree is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Conversation Tree (%d nodes):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
prefix := " "
|
||||
if i == len(branch)-1 {
|
||||
prefix = "▶ " // Current node
|
||||
} else {
|
||||
prefix = " "
|
||||
}
|
||||
|
||||
roleIcon := "💬"
|
||||
switch node.Role {
|
||||
case "user":
|
||||
roleIcon = "👤"
|
||||
case "assistant":
|
||||
roleIcon = "🤖"
|
||||
case "system":
|
||||
roleIcon = "⚙️"
|
||||
}
|
||||
|
||||
content := truncate(node.Content, 50)
|
||||
if node.Type == "branch_summary" {
|
||||
roleIcon = "📋"
|
||||
content = "[Summary] " + truncate(node.Content, 40)
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s%s %s: %s (%s...)\n", prefix, roleIcon, node.Role, node.ID[:8], content))
|
||||
|
||||
// Show children count if any
|
||||
children := ctx.GetChildren(node.ID)
|
||||
if len(children) > 0 {
|
||||
output.WriteString(fmt.Sprintf(" └─ %d branch(es)\n", len(children)))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Tree View",
|
||||
})
|
||||
}
|
||||
|
||||
// showBranch displays the current branch path
|
||||
func showBranch(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("No active branch")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Current Branch (%d nodes from root to leaf):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
marker := " "
|
||||
if i == len(branch)-1 {
|
||||
marker = "▶ " // Current leaf
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s[%d] %s (%s): %s\n",
|
||||
marker, i, node.Type, node.ID[:8], truncate(node.Content, 40)))
|
||||
}
|
||||
|
||||
// Show current node details
|
||||
leaf := branch[len(branch)-1]
|
||||
output.WriteString(fmt.Sprintf("\nCurrent Leaf:\n"))
|
||||
output.WriteString(fmt.Sprintf(" ID: %s\n", leaf.ID))
|
||||
output.WriteString(fmt.Sprintf(" Type: %s\n", leaf.Type))
|
||||
output.WriteString(fmt.Sprintf(" Role: %s\n", leaf.Role))
|
||||
output.WriteString(fmt.Sprintf(" Model: %s\n", leaf.Model))
|
||||
output.WriteString(fmt.Sprintf(" Children: %d\n", len(leaf.Children)))
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#cba6f7",
|
||||
Subtitle: "conversation-manager · Branch View",
|
||||
})
|
||||
}
|
||||
|
||||
// executeLoopIteration triggers the next loop iteration
|
||||
func executeLoopIteration(ctx ext.Context) {
|
||||
iterationPrompt := fmt.Sprintf("[%d/%d] %s", loopCurrent+1, loopCount, loopPrompt)
|
||||
ctx.SendMessage(iterationPrompt)
|
||||
}
|
||||
|
||||
// truncate helper
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -7,10 +7,12 @@
|
||||
// development: edit your extension source, then type /reload to pick up
|
||||
// changes immediately.
|
||||
//
|
||||
// Event handlers, slash commands, tool renderers, message renderers, and
|
||||
// keyboard shortcuts update immediately. Extension-defined tools are NOT
|
||||
// updated (they are baked into the agent at creation time and require a
|
||||
// restart).
|
||||
// Note: Extensions in autoloaded directories (~/.config/kit/extensions/
|
||||
// and .kit/extensions/) are automatically reloaded on save. The /reload
|
||||
// command is useful for extensions loaded via -e from other locations.
|
||||
//
|
||||
// Event handlers, slash commands, tool definitions, tool renderers,
|
||||
// message renderers, and keyboard shortcuts all update immediately.
|
||||
//
|
||||
// Commands:
|
||||
// /reload — hot-reload all extensions from disk
|
||||
|
||||
@@ -10,13 +10,21 @@ import (
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// re matches !{...} with non-greedy content.
|
||||
var re = regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
|
||||
// Init expands inline bash expressions in user prompts before they reach the
|
||||
// LLM. Text like !{git branch --show-current} is replaced with the command's
|
||||
// stdout.
|
||||
// LLM. Text like !{git rev-parse --abbrev-ref HEAD} is replaced with the
|
||||
// command's stdout.
|
||||
//
|
||||
// In interactive mode the expansion happens at submit time via an editor
|
||||
// interceptor, so the expanded text is also visible in the user message
|
||||
// block on screen. In non-interactive mode (CLI, script, queue) the
|
||||
// expansion happens via OnInput transform.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "Fix the tests on !{git branch --show-current}"
|
||||
// "Fix the tests on !{git rev-parse --abbrev-ref HEAD}"
|
||||
// → "Fix the tests on main"
|
||||
//
|
||||
// "The current directory is !{pwd}"
|
||||
@@ -24,29 +32,59 @@ import (
|
||||
//
|
||||
// Usage: kit -e examples/extensions/inline-bash.go
|
||||
func Init(api ext.API) {
|
||||
// Matches !{...} with non-greedy content.
|
||||
re := regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
// ── Interactive mode: editor interceptor ──────────────────────────
|
||||
// Intercept Enter / Ctrl+D so we can expand !{...} BEFORE the
|
||||
// SubmitMsg is created. This ensures the expanded text appears in
|
||||
// the user message block on screen as well as in the LLM prompt.
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
if !ctx.Interactive {
|
||||
return
|
||||
}
|
||||
ctx.SetEditor(ext.EditorConfig{
|
||||
HandleKey: func(key string, currentText string) ext.EditorKeyAction {
|
||||
if (key == "enter" || key == "ctrl+d") && re.MatchString(currentText) {
|
||||
expanded := expand(currentText)
|
||||
// Clear the textarea asynchronously — calling
|
||||
// SetEditorText synchronously from inside Update()
|
||||
// would deadlock the BubbleTea event loop.
|
||||
go ctx.SetEditorText("")
|
||||
return ext.EditorKeyAction{
|
||||
Type: ext.EditorKeySubmit,
|
||||
SubmitText: expanded,
|
||||
}
|
||||
}
|
||||
return ext.EditorKeyAction{Type: ext.EditorKeyPassthrough}
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// ── Non-interactive fallback: OnInput transform ──────────────────
|
||||
// For CLI, script, and queue sources the editor interceptor is not
|
||||
// active, so we fall back to OnInput which still rewrites the
|
||||
// prompt text sent to the LLM.
|
||||
api.OnInput(func(ev ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if !re.MatchString(ev.Text) {
|
||||
if ev.Source == "interactive" || !re.MatchString(ev.Text) {
|
||||
return nil
|
||||
}
|
||||
|
||||
expanded := re.ReplaceAllStringFunc(ev.Text, func(match string) string {
|
||||
// Extract the command between !{ and }.
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
|
||||
return &ext.InputResult{
|
||||
Action: "transform",
|
||||
Text: expanded,
|
||||
Text: expand(ev.Text),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// expand replaces every !{cmd} in text with the command's stdout.
|
||||
// On error the original !{cmd} token is preserved.
|
||||
func expand(text string) string {
|
||||
return re.ReplaceAllStringFunc(text, func(match string) string {
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -168,6 +168,10 @@ var (
|
||||
// Test
|
||||
pendingTest *PendingTest
|
||||
|
||||
// Typing indicator
|
||||
typingTicker *time.Ticker
|
||||
typingStop chan struct{}
|
||||
|
||||
// Latest context for background goroutines
|
||||
latestCtx ext.Context
|
||||
latestCtxSet bool
|
||||
@@ -203,8 +207,23 @@ func configDir() string {
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func configPath() string {
|
||||
return filepath.Join(configDir(), "kit-telegram.json")
|
||||
// Prefer project-local config, fall back to global config.
|
||||
local := filepath.Join(configDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
return local
|
||||
}
|
||||
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(global); err == nil {
|
||||
return global
|
||||
}
|
||||
// Neither exists — return local path (will be created on connect).
|
||||
return local
|
||||
}
|
||||
|
||||
func failureLogDir() string {
|
||||
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func tgSendChatAction(token string, chatID int64, action string) error {
|
||||
_, err := telegramRequest(token, "sendChatAction", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"action": action,
|
||||
}, 15)
|
||||
return err
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Error classification
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -637,6 +664,48 @@ func clearHealthTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Typing indicator
|
||||
// ──────────────────────────────────────────────
|
||||
|
||||
func startTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
return
|
||||
}
|
||||
cfg := config
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return
|
||||
}
|
||||
token := cfg.BotToken
|
||||
chatID := cfg.ChatID
|
||||
typingTicker = time.NewTicker(4 * time.Second)
|
||||
typingStop = make(chan struct{})
|
||||
// Send immediately, then every 4 seconds.
|
||||
go func() {
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
for {
|
||||
select {
|
||||
case <-typingTicker.C:
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
case <-typingStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
typingTicker.Stop()
|
||||
close(typingStop)
|
||||
typingTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Polling lifecycle
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -908,7 +977,7 @@ func summarizeToolAction(toolName string, inputJSON string) string {
|
||||
return "searching " + getStr("pattern", "text")
|
||||
case "ls":
|
||||
return "listing " + getStr("path", "directory")
|
||||
case "spawn_subagent":
|
||||
case "subagent":
|
||||
return "spawning subagent"
|
||||
default:
|
||||
return "using " + toolName
|
||||
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
sendShutdownDisconnectedMessage()
|
||||
stopTypingLoop()
|
||||
stopPolling()
|
||||
clearHealthTimer()
|
||||
clearFooter()
|
||||
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
report("run.start", fmt.Sprintf("runId=%d", run.ID))
|
||||
startTypingLoop()
|
||||
ensureProgressMessage()
|
||||
updateProgressMessage()
|
||||
})
|
||||
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
|
||||
run := activeRun
|
||||
mu.Unlock()
|
||||
|
||||
stopTypingLoop()
|
||||
|
||||
if run != nil {
|
||||
// Capture final response from event
|
||||
if e.Response != "" {
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
// lsp-diagnostics.go — LSP-powered diagnostics for Kit's edit tool.
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits,
|
||||
// following the same pattern used by Charm's crush editor:
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits:
|
||||
// 1. After an edit, notify the LSP server of the file change
|
||||
// 2. Wait for the server to publish fresh diagnostics
|
||||
// 3. Append diagnostic output to the edit tool's result
|
||||
@@ -412,7 +410,7 @@ func (c *lspClient) changeFile(absPath, content string) {
|
||||
}
|
||||
|
||||
// waitForDiagnostics polls until the server publishes new diagnostics or
|
||||
// the timeout elapses. Mirrors crush's WaitForDiagnostics pattern.
|
||||
// the timeout elapses.
|
||||
func (c *lspClient) waitForDiagnostics(timeout time.Duration) {
|
||||
c.diagMu.Lock()
|
||||
startVersion := c.diagVersion
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build ignore
|
||||
|
||||
// prompt-templates.go - Frontmatter-driven prompt templates with model switching.
|
||||
// This extension demonstrates the new bridged SDK APIs:
|
||||
// - Tree navigation for conversation management
|
||||
// - Template parsing with {{variable}} substitution
|
||||
// - Model resolution with fallback chains
|
||||
// - Skill injection
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create ~/.config/kit/prompts/debug.md with frontmatter:
|
||||
// ---
|
||||
// description: Debug Python code
|
||||
// model: claude-sonnet-4-20250514
|
||||
// skill: python
|
||||
// ---
|
||||
// Help me debug this Python code: {{input}}
|
||||
//
|
||||
// 2. In Kit: /debug my_script.py
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// PromptTemplate represents a loaded template with frontmatter
|
||||
type PromptTemplate struct {
|
||||
Name string
|
||||
Description string
|
||||
Model string
|
||||
Skill string
|
||||
Content string
|
||||
Variables []string
|
||||
Path string
|
||||
}
|
||||
|
||||
var (
|
||||
templates = make(map[string]PromptTemplate)
|
||||
templateDir string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Determine template directory
|
||||
home, _ := os.UserHomeDir()
|
||||
templateDir = filepath.Join(home, ".config", "kit", "prompts")
|
||||
|
||||
// Ensure directory exists
|
||||
os.MkdirAll(templateDir, 0755)
|
||||
|
||||
// Register commands
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "reload-templates",
|
||||
Description: "Reload prompt templates from disk",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
loadTemplates(ctx)
|
||||
ctx.PrintInfo(fmt.Sprintf("Loaded %d templates from %s", len(templates), templateDir))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Dynamic template commands are registered after loading
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
loadTemplates(ctx)
|
||||
registerTemplateCommands(api, ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadTemplates discovers and loads all template files
|
||||
func loadTemplates(ctx ext.Context) {
|
||||
templates = make(map[string]PromptTemplate)
|
||||
|
||||
entries, err := os.ReadDir(templateDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(templateDir, entry.Name())
|
||||
tpl, err := loadTemplateFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(entry.Name(), ".md")
|
||||
templates[name] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplateFile parses a template with YAML frontmatter
|
||||
func loadTemplateFile(path string) (PromptTemplate, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return PromptTemplate{}, err
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
tpl := PromptTemplate{Path: path}
|
||||
|
||||
// Parse frontmatter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) == 2 {
|
||||
frontmatter := strings.TrimSpace(parts[0])
|
||||
body := strings.TrimSpace(parts[1])
|
||||
|
||||
// Simple line-by-line frontmatter parsing
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, found := strings.Cut(line, ":")
|
||||
if found {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
switch key {
|
||||
case "description":
|
||||
tpl.Description = value
|
||||
case "model":
|
||||
tpl.Model = value
|
||||
case "skill":
|
||||
tpl.Skill = value
|
||||
}
|
||||
}
|
||||
}
|
||||
tpl.Content = body
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
|
||||
// Parse {{variables}} using simple string parsing
|
||||
// (Can't use ctx.ParseTemplate here since we're in Init, not a handler)
|
||||
var vars []string
|
||||
for {
|
||||
start := strings.Index(tpl.Content, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(tpl.Content[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
varName := strings.TrimSpace(tpl.Content[start+2 : start+end])
|
||||
vars = append(vars, varName)
|
||||
tpl.Content = tpl.Content[:start] + "{{" + varName + "}}" + tpl.Content[start+end+2:]
|
||||
}
|
||||
tpl.Variables = vars
|
||||
|
||||
return tpl, nil
|
||||
}
|
||||
|
||||
// registerTemplateCommands dynamically registers commands for each template
|
||||
func registerTemplateCommands(api ext.API, ctx ext.Context) {
|
||||
for name, tpl := range templates {
|
||||
// Skip if already registered (we'd need to track this)
|
||||
tplCopy := tpl // Capture for closure
|
||||
nameCopy := name
|
||||
|
||||
// Build description with metadata
|
||||
desc := tplCopy.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("Run %s template", nameCopy)
|
||||
}
|
||||
if tplCopy.Model != "" {
|
||||
desc += fmt.Sprintf(" [%s", tplCopy.Model)
|
||||
if tplCopy.Skill != "" {
|
||||
desc += fmt.Sprintf(" +%s", tplCopy.Skill)
|
||||
}
|
||||
desc += "]"
|
||||
}
|
||||
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: nameCopy,
|
||||
Description: desc,
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
return executeTemplate(ctx, tplCopy, args)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplate runs a template with the given arguments
|
||||
func executeTemplate(ctx ext.Context, tpl PromptTemplate, args string) (string, error) {
|
||||
// Store original model for restoration
|
||||
originalModel := ctx.Model
|
||||
|
||||
// 1. Resolve and switch model if specified
|
||||
if tpl.Model != "" {
|
||||
// Parse model chain (comma-separated)
|
||||
preferences := strings.Split(tpl.Model, ",")
|
||||
for i := range preferences {
|
||||
preferences[i] = strings.TrimSpace(preferences[i])
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(preferences)
|
||||
if result.Error != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Model resolution failed: %s", result.Error))
|
||||
// Continue with current model
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Switching to model: %s", result.Model))
|
||||
if err := ctx.SetModel(result.Model); err != nil {
|
||||
ctx.PrintError(fmt.Sprintf("Failed to switch model: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Inject skill if specified
|
||||
if tpl.Skill != "" {
|
||||
err := ctx.InjectSkillAsContext(tpl.Skill)
|
||||
if err != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Skill injection failed: %s", err))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Injected skill: %s", tpl.Skill))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Parse and render template
|
||||
parsed := ctx.ParseTemplate(tpl.Name, tpl.Content)
|
||||
|
||||
// Build variable map
|
||||
vars := make(map[string]string)
|
||||
|
||||
// Simple argument parsing: first arg is $1 (input), rest is $@
|
||||
if len(parsed.Variables) > 0 {
|
||||
argsList := ctx.SimpleParseArguments(args, len(parsed.Variables))
|
||||
for i, varName := range parsed.Variables {
|
||||
if i < len(parsed.Variables) && i+1 < len(argsList) {
|
||||
vars[varName] = argsList[i+1]
|
||||
}
|
||||
}
|
||||
// If single variable, use full args
|
||||
if len(parsed.Variables) == 1 && vars[parsed.Variables[0]] == "" {
|
||||
vars[parsed.Variables[0]] = args
|
||||
}
|
||||
}
|
||||
|
||||
// Render with model conditionals
|
||||
content := ctx.RenderWithModelConditionals(tpl.Content)
|
||||
rendered := ctx.RenderTemplate(ext.PromptTemplate{Name: tpl.Name, Content: content, Variables: parsed.Variables}, vars)
|
||||
|
||||
// 4. Send the rendered prompt
|
||||
ctx.SendMessage(rendered)
|
||||
|
||||
// 5. Schedule model restoration after turn completes
|
||||
// We use a goroutine to wait and restore
|
||||
if tpl.Model != "" && originalModel != "" {
|
||||
go func() {
|
||||
// Note: In a real implementation, we'd use OnAgentEnd event
|
||||
// For now, the user can manually switch back
|
||||
ctx.SetStatus("template-mode", fmt.Sprintf("Template: %s (model will restore)", tpl.Name), 20)
|
||||
}()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Executing template: %s", tpl.Name), nil
|
||||
}
|
||||
@@ -130,6 +130,58 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_ConcurrentSubagents verifies no panics when multiple
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 5 subagents concurrently
|
||||
done := make(chan struct{}, 5)
|
||||
for i := range 5 {
|
||||
go func(idx int) {
|
||||
defer func() { done <- struct{}{} }()
|
||||
|
||||
callID := fmt.Sprintf("concurrent-%d", idx)
|
||||
task := fmt.Sprintf("concurrent task %d", idx)
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
})
|
||||
|
||||
// Emit many chunks rapidly
|
||||
for j := range 20 {
|
||||
_, _ = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("agent %d chunk %d", idx, j),
|
||||
})
|
||||
}
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
Response: "done",
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 5 {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Allow any final processing
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
|
||||
@@ -37,7 +37,7 @@ func Init(api ext.API) {
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
"The LLM can also use the subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
//go:build ignore
|
||||
|
||||
// sudo-handler.go - Extension to handle sudo password prompts securely
|
||||
//
|
||||
// This extension intercepts bash commands containing "sudo" and:
|
||||
// 1. Checks if sudo credentials are already cached (via sudo -n)
|
||||
// 2. If not cached, prompts the user for their password (with masking)
|
||||
// 3. Temporarily sets SUDO_PASSWORD environment variable for execution
|
||||
// 4. The bash tool automatically uses sudo -S -p '' to pipe the password
|
||||
//
|
||||
// Usage: kit -e examples/extensions/sudo-handler.go
|
||||
//
|
||||
// Security notes:
|
||||
// - Password is only stored in memory for the duration of the session
|
||||
// - Password is never logged or displayed
|
||||
// - Each session requires re-authentication (sudo -k is used)
|
||||
// - The SUDO_PASSWORD env var is set only during tool execution
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
// cachedPassword stores the sudo password for the session
|
||||
cachedPassword string
|
||||
// hasCachedPassword tracks if we have a valid cached password
|
||||
hasCachedPassword bool
|
||||
// mu protects cached password access
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Init sets up the sudo handler extension
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName != "bash" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the command from tool input
|
||||
var input struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.Input), &input); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if command contains sudo
|
||||
if !containsSudo(input.Command) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we already have cached credentials
|
||||
mu.RLock()
|
||||
password := cachedPassword
|
||||
hasCached := hasCachedPassword
|
||||
mu.RUnlock()
|
||||
|
||||
if hasCached {
|
||||
// Use cached password
|
||||
os.Setenv("SUDO_PASSWORD", password)
|
||||
return nil
|
||||
}
|
||||
|
||||
// No cached password - prompt user
|
||||
result := ctx.PromptInput(ext.PromptInputConfig{
|
||||
Message: "🔐 Sudo password required for:\n " + truncateCommand(input.Command, 60),
|
||||
Placeholder: "Enter your password",
|
||||
})
|
||||
|
||||
if result.Cancelled {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "Sudo password prompt cancelled by user",
|
||||
}
|
||||
}
|
||||
|
||||
if result.Value == "" {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "No password provided",
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the password for this session
|
||||
mu.Lock()
|
||||
cachedPassword = result.Value
|
||||
hasCachedPassword = true
|
||||
mu.Unlock()
|
||||
|
||||
// Set environment variable for the bash tool to use
|
||||
os.Setenv("SUDO_PASSWORD", result.Value)
|
||||
|
||||
// Show confirmation (without revealing password)
|
||||
ctx.PrintInfo("Sudo password cached for this session")
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Clear cached password when session ends
|
||||
api.OnSessionShutdown(func(event ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
mu.Lock()
|
||||
cachedPassword = ""
|
||||
hasCachedPassword = false
|
||||
mu.Unlock()
|
||||
os.Unsetenv("SUDO_PASSWORD")
|
||||
})
|
||||
}
|
||||
|
||||
// containsSudo checks if the command contains sudo as a command (not in a string)
|
||||
func containsSudo(command string) bool {
|
||||
// Simple check for sudo as a word, not inside quotes or as part of another word
|
||||
lower := strings.ToLower(command)
|
||||
|
||||
// Check for sudo at start or after separators
|
||||
patterns := []string{
|
||||
"sudo ",
|
||||
"sudo\t",
|
||||
";sudo ",
|
||||
"&& sudo ",
|
||||
"|| sudo ",
|
||||
"| sudo ",
|
||||
"$(sudo ",
|
||||
"`sudo ",
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if command starts with sudo
|
||||
if strings.HasPrefix(lower, "sudo ") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// truncateCommand truncates a long command for display
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
@@ -1,97 +1,95 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.1
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.5
|
||||
charm.land/fantasy v0.17.2
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.48.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/term v0.42.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.3 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.3 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.0 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
@@ -105,41 +103,39 @@ require (
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
go.opentelemetry.io/otel v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.42.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/glamour v1.0.0
|
||||
github.com/charmbracelet/x/ansi v0.11.6
|
||||
github.com/charmbracelet/x/ansi v0.11.7
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
)
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.5 h1:TQlLFqxo39AAHSVuOhJ5D3nH7O9Nk8JGinsfWQ4y1U4=
|
||||
charm.land/bubbletea/v2 v2.0.5/go.mod h1:dvbsYZD+MHkdIZl+Z67D212hEvB+GII2tfH8f9SnoDw=
|
||||
charm.land/fantasy v0.17.2 h1:ojTMufMxY/PVH7TzYUxht2SVkvD90iCTJfmPR6c8BR8=
|
||||
charm.land/fantasy v0.17.2/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
@@ -34,42 +34,40 @@ github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
|
||||
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -80,8 +78,6 @@ github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex
|
||||
github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q=
|
||||
github.com/charmbracelet/fang v1.0.0 h1:jESBY40agJOlLYnnv9jE0mLqDGTxEk0hkOnx7YGyRlQ=
|
||||
github.com/charmbracelet/fang v1.0.0/go.mod h1:P5/DNb9DddQ0Z0dbc0P3ol4/ix5Po7Ofr2KMBfAqoCo=
|
||||
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
|
||||
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
|
||||
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
@@ -90,24 +86,26 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 h1:hzWNs3UQRSUTS6YCbLaQnwqKBFXT5Yh1OOw6+26apqg=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502/go.mod h1:mkUCcxn9w9j89JJp3pOza5tmDQZPgIB75UfmQlFYvas=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e h1:O5hZFj55wZQWxMiRtQLa3uLKhZGZGS/j8M3OXinQlrw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4jxsAuEs=
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIRg8gGWwk=
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca h1:62yAoS1Ynbuzwcn1LkNBxi3IMF5p0E0cHCoaLOOmN9w=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 h1:6F/6bu5nBLjodsvaU5xAszTaxtHrDU5UiJarpMPZj48=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 h1:RiZFY92Ug9iz1CenzxSSQla2Z3WflsR7bIuXq40JlpU=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -177,41 +175,40 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.0 h1:i7L3U2yurg+xhokITtJ0k+mjHnXqkoyz8ju5Wb7W8Oc=
|
||||
github.com/kaptinlin/go-i18n v0.4.0/go.mod h1:njA6x0+4MWGcLWT0KLrwekhRPmze1Hnstf2+VJFzwpM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.6 h1:UUMqZGFAk7nOzQsYAxvgygm4wpDp/nwXxA4VP9mCPCs=
|
||||
github.com/kaptinlin/jsonschema v0.7.6/go.mod h1:GGk/oE+F1lWUfYrzKaCf4QWZmMdytt0LL4XdFEFB0LE=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 h1:RBlHVWgZyoxTcUgGWBsl2AcyScq/urqbLZvzgryTmSI=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18/go.mod h1:ntI3154RnqJgr7GaC+vZBnIExl2V3sv9selvRNNEM24=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 h1:a0ufTd5liiUubIGeGxpSTnNS8ZSrN4DV01/wGFmfzMs=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20/go.mod h1:FqdEPfQLkqVBX7OBRMPgYwUPvKYJohFD9Ok1BMzCfIo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.48.0 h1:o+MXuGW/HCeR2ny5LcAcZQn2bo6I2xaZMEHnpRG+dtw=
|
||||
github.com/mark3labs/mcp-go v0.48.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -222,8 +219,6 @@ github.com/muesli/mango-cobra v1.3.0 h1:vQy5GvPg3ndOSpduxutqFoINhWk3vD5K2dXo5E8p
|
||||
github.com/muesli/mango-cobra v1.3.0/go.mod h1:Cj1ZrBu3806Qw7UjxnAUgE+7tllUBj1NCLQDwwGx19E=
|
||||
github.com/muesli/mango-pflag v0.2.0 h1:QViokgKDZQCzKhYe1zH8D+UlPJzBSGoP9yx0hBG0t5k=
|
||||
github.com/muesli/mango-pflag v0.2.0/go.mod h1:X9LT1p/pbGA1wjvEbtwnixujKErkP0jVmrxwrw3fL0Y=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
@@ -236,8 +231,6 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgm
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -279,59 +272,56 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 h1:RmoJA1ujG+/lRGNfUnOMfhCy5EipVMyvUE+KNbPbTlw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+226
-16
@@ -7,8 +7,11 @@ package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
@@ -20,7 +23,6 @@ import (
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
conn *acp.AgentSideConnection
|
||||
@@ -111,13 +113,20 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
)
|
||||
}
|
||||
|
||||
// Extract text from prompt content blocks.
|
||||
promptText := extractPromptText(params.Prompt)
|
||||
if promptText == "" {
|
||||
// Extract text and file attachments from prompt content blocks.
|
||||
promptText, files := extractPromptContent(params.Prompt)
|
||||
if promptText == "" && len(files) == 0 {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams("empty prompt")
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText))
|
||||
// If we have files but no text prompt, add a default prompt
|
||||
// This is required because the underlying LLM library needs a non-empty prompt
|
||||
// when there are no previous messages in the conversation.
|
||||
if promptText == "" && len(files) > 0 {
|
||||
promptText = "Please analyze the attached file."
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
@@ -129,7 +138,13 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
defer unsub()
|
||||
|
||||
// Run the prompt through Kit's full turn lifecycle.
|
||||
_, err := sess.kit.PromptResult(promptCtx, promptText)
|
||||
// Use PromptResultWithFiles when file attachments are present.
|
||||
var err error
|
||||
if len(files) > 0 {
|
||||
_, err = sess.kit.PromptResultWithFiles(promptCtx, promptText, files)
|
||||
} else {
|
||||
_, err = sess.kit.PromptResult(promptCtx, promptText)
|
||||
}
|
||||
if err != nil {
|
||||
if promptCtx.Err() != nil {
|
||||
return acp.PromptResponse{
|
||||
@@ -162,6 +177,24 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// SetSessionModel changes the active model for a session.
|
||||
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
|
||||
modelID := string(params.ModelId)
|
||||
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
|
||||
|
||||
if err := sess.kit.SetModel(ctx, modelID); err != nil {
|
||||
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
|
||||
}
|
||||
|
||||
return acp.SetSessionModelResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event streaming: Kit events → ACP SessionUpdate notifications
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -231,19 +264,196 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractPromptText extracts the concatenated text content from ACP content
|
||||
// blocks. Non-text blocks are ignored for now.
|
||||
func extractPromptText(blocks []acp.ContentBlock) string {
|
||||
var text string
|
||||
for _, block := range blocks {
|
||||
if block.Text != nil {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
// extractPromptContent extracts text and file attachments from ACP content blocks.
|
||||
// It converts supported content blocks (image, audio, resource) to Kit's LLMFilePart.
|
||||
func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart) {
|
||||
var textParts []string
|
||||
var files []kit.LLMFilePart
|
||||
|
||||
log.Debug("acp: extracting content", "blocks", len(blocks))
|
||||
|
||||
for i, block := range blocks {
|
||||
switch {
|
||||
// Text content
|
||||
case block.Text != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "text", "len", len(block.Text.Text))
|
||||
textParts = append(textParts, block.Text.Text)
|
||||
|
||||
// Image data (base64)
|
||||
case block.Image != nil:
|
||||
mimeType := block.Image.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png" // Default fallback
|
||||
}
|
||||
text += block.Text.Text
|
||||
log.Debug("acp: content block", "index", i, "type", "image", "mime", mimeType, "data_len", len(block.Image.Data))
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Image.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "image.png",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode image", "error", err)
|
||||
}
|
||||
|
||||
// Audio data (base64)
|
||||
case block.Audio != nil:
|
||||
mimeType := block.Audio.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "audio/wav" // Default fallback
|
||||
}
|
||||
log.Debug("acp: content block", "index", i, "type", "audio", "mime", mimeType)
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Audio.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "audio.wav",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode audio", "error", err)
|
||||
}
|
||||
|
||||
// Embedded resource (text or binary file content)
|
||||
case block.Resource != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "resource")
|
||||
res := block.Resource.Resource
|
||||
// Text resource - append as text content with file reference
|
||||
if res.TextResourceContents != nil {
|
||||
uri := res.TextResourceContents.Uri
|
||||
content := res.TextResourceContents.Text
|
||||
mimeType := "text/plain"
|
||||
if res.TextResourceContents.MimeType != nil {
|
||||
mimeType = *res.TextResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: text resource", "uri", uri, "mime", mimeType, "len", len(content))
|
||||
// Text files are included as formatted text, NOT as FilePart
|
||||
// FilePart is for binary files (images, audio, PDFs) only
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, content))
|
||||
}
|
||||
// Binary resource (base64 blob) - these become FilePart
|
||||
if res.BlobResourceContents != nil {
|
||||
uri := res.BlobResourceContents.Uri
|
||||
mimeType := "application/octet-stream"
|
||||
if res.BlobResourceContents.MimeType != nil {
|
||||
mimeType = *res.BlobResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: binary resource", "uri", uri, "mime", mimeType, "blob_len", len(res.BlobResourceContents.Blob))
|
||||
if data, err := base64.StdEncoding.DecodeString(res.BlobResourceContents.Blob); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode binary resource", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Resource link (file reference without embedded content)
|
||||
case block.ResourceLink != nil:
|
||||
uri := block.ResourceLink.Uri
|
||||
name := block.ResourceLink.Name
|
||||
log.Debug("acp: content block", "index", i, "type", "resource_link", "uri", uri, "name", name)
|
||||
// For resource links, we'll try to read the file from disk
|
||||
// This requires the file URI to be accessible (file:// scheme)
|
||||
if content, err := readResourceFromURI(uri); err == nil {
|
||||
// Detect if it's a text file or binary file
|
||||
mimeType := "text/plain"
|
||||
if block.ResourceLink.MimeType != nil {
|
||||
mimeType = *block.ResourceLink.MimeType
|
||||
}
|
||||
log.Debug("acp: resource link loaded", "uri", uri, "mime", mimeType, "size", len(content))
|
||||
|
||||
// Only create FilePart for binary files (images, audio, PDFs, etc.)
|
||||
// Text files are included as formatted text in the message
|
||||
if isTextMimeType(mimeType) || looksLikeText(content) {
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, string(content)))
|
||||
} else {
|
||||
// Binary file - create FilePart for models that support it
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: content,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If we can't read it, include as a text reference
|
||||
log.Debug("acp: resource link failed to load", "uri", uri, "error", err)
|
||||
textParts = append(textParts, fmt.Sprintf("[Referenced file: %s]", uri))
|
||||
}
|
||||
|
||||
default:
|
||||
log.Debug("acp: content block", "index", i, "type", "unknown/unhandled")
|
||||
}
|
||||
}
|
||||
return text
|
||||
|
||||
// Debug log the extracted content
|
||||
for i, f := range files {
|
||||
log.Debug("acp: extracted file", "index", i, "filename", f.Filename, "mime", f.MediaType, "size", len(f.Data))
|
||||
}
|
||||
|
||||
return strings.Join(textParts, "\n"), files
|
||||
}
|
||||
|
||||
// isTextMimeType returns true if the MIME type indicates text content.
|
||||
func isTextMimeType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "text/") ||
|
||||
mimeType == "application/json" ||
|
||||
mimeType == "application/xml" ||
|
||||
mimeType == "application/javascript" ||
|
||||
mimeType == "application/typescript" ||
|
||||
mimeType == "application/x-sh" ||
|
||||
mimeType == "application/x-python" ||
|
||||
mimeType == "application/x-yaml" ||
|
||||
mimeType == "application/x-toml"
|
||||
}
|
||||
|
||||
// looksLikeText checks if the content appears to be text (not binary).
|
||||
// It samples the first 512 bytes and checks for null bytes or high
|
||||
// concentration of non-printable characters.
|
||||
func looksLikeText(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check first 512 bytes (or less if file is smaller)
|
||||
sampleSize := min(len(data), 512)
|
||||
sample := data[:sampleSize]
|
||||
|
||||
// Count non-printable characters
|
||||
nonPrintable := 0
|
||||
for _, b := range sample {
|
||||
// Null byte indicates binary
|
||||
if b == 0 {
|
||||
return false
|
||||
}
|
||||
// Count control characters (except common whitespace)
|
||||
if b < 32 && b != '\n' && b != '\r' && b != '\t' {
|
||||
nonPrintable++
|
||||
}
|
||||
}
|
||||
|
||||
// If more than 30% non-printable, consider it binary
|
||||
return float64(nonPrintable)/float64(sampleSize) < 0.3
|
||||
}
|
||||
|
||||
// extractFilenameFromURI extracts a filename from a file URI or path.
|
||||
func extractFilenameFromURI(uri string) string {
|
||||
// Handle file:// URIs
|
||||
uri = strings.TrimPrefix(uri, "file://")
|
||||
// Extract basename
|
||||
if idx := strings.LastIndex(uri, "/"); idx >= 0 {
|
||||
return uri[idx+1:]
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// readResourceFromURI attempts to read file content from a file:// URI.
|
||||
func readResourceFromURI(uri string) ([]byte, error) {
|
||||
if !strings.HasPrefix(uri, "file://") {
|
||||
return nil, fmt.Errorf("unsupported URI scheme: %s", uri)
|
||||
}
|
||||
path := uri[7:] // Remove file:// prefix
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON tool args string into a map for
|
||||
|
||||
@@ -62,8 +62,8 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
if kitInstance.HasExtensions() {
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
@@ -121,31 +121,31 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionFilePath() },
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.GetExtensionOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.SetExtensionOption(name, value) },
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.EmitExtensionCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.GetExtensionToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.SetExtensionActiveTools(names) },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
@@ -173,7 +173,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -188,15 +188,15 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(name)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.ReloadExtensions() },
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
sess := &acpSession{
|
||||
|
||||
+473
-145
@@ -25,19 +25,39 @@ type AgentConfig struct {
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the Fantasy agent. Used by the extensions system
|
||||
// before it is passed to the LLM agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ExtraTools are additional tools to include alongside core and MCP tools.
|
||||
// Used by extensions to register custom tools.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -63,6 +83,10 @@ type ToolCallContentHandler func(content string)
|
||||
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
|
||||
type ReasoningDeltaHandler func(delta string)
|
||||
|
||||
// ReasoningCompleteHandler is a function type for handling reasoning/thinking completion.
|
||||
// Called when the last reasoning token has been processed, before text streaming starts.
|
||||
type ReasoningCompleteHandler func()
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
@@ -70,15 +94,33 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// PasswordPromptHandler is a function type for password prompts.
|
||||
// Used by the bash tool when sudo requires a password. The handler receives
|
||||
// a prompt message and returns the password and whether it was cancelled.
|
||||
// Note: This is an alias for core.PasswordPromptCallback.
|
||||
type PasswordPromptHandler = core.PasswordPromptCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
// by a tool-role message with results, or a final assistant message with text).
|
||||
// This enables incremental session persistence so that progress is saved as
|
||||
// it happens rather than only at the end of the turn.
|
||||
type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// Additional tools from external MCP servers can be loaded alongside core tools.
|
||||
//
|
||||
// When MCP servers are configured, tool loading happens in the background so the
|
||||
// agent (and UI) can start immediately. The first LLM call automatically waits
|
||||
// for MCP tools to finish loading before proceeding.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
fantasyAgent fantasy.Agent
|
||||
@@ -92,6 +134,24 @@ type Agent struct {
|
||||
coreTools []fantasy.AgentTool
|
||||
extraTools []fantasy.AgentTool
|
||||
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool // stored for SetModel rebuild
|
||||
|
||||
// providerOptions and modelConfig are stored for rebuilding the fantasy
|
||||
// agent when MCP tools arrive asynchronously or on SetModel.
|
||||
providerOptions fantasy.ProviderOptions
|
||||
skipMaxOutputTokens bool
|
||||
modelConfig *models.ProviderConfig
|
||||
|
||||
// authHandler and tokenStoreFactory are stored from AgentConfig so that
|
||||
// AddMCPServer() can propagate them when creating a new MCPToolManager
|
||||
// at runtime (i.e. when no MCP servers were configured at init time).
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -100,54 +160,50 @@ type GenerateWithLoopResult struct {
|
||||
FinalResponse *fantasy.Response
|
||||
// ConversationMessages contains all messages in the conversation including tool calls and results
|
||||
ConversationMessages []fantasy.Message
|
||||
// Messages contains the conversation as custom content blocks (crush-style)
|
||||
// Messages contains the conversation as custom content blocks
|
||||
Messages []message.Message
|
||||
// TotalUsage contains aggregate token usage across all steps
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
// PersistedMessageCount is the number of new messages (beyond the original
|
||||
// input) that were already persisted incrementally via OnStepMessages during
|
||||
// generation. The caller should skip these when doing post-generation
|
||||
// persistence to avoid duplicates.
|
||||
PersistedMessageCount int
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
|
||||
// External MCP tools are loaded from the config if any MCP servers are configured.
|
||||
// If MCP servers are configured, their tools are loaded in the background —
|
||||
// the agent returns immediately and is usable with core tools only. The first
|
||||
// LLM call (GenerateWithLoop) automatically waits for MCP tools to finish
|
||||
// loading and rebuilds the agent with the full tool set.
|
||||
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider via fantasy
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Register core tools (direct fantasy implementations, no MCP overhead).
|
||||
// Register core tools (direct AgentTool implementations, no MCP overhead).
|
||||
// Use caller-provided tools if set, otherwise default to all core tools.
|
||||
coreTools := agentConfig.CoreTools
|
||||
if len(coreTools) == 0 {
|
||||
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
|
||||
var coreTools []fantasy.AgentTool
|
||||
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
|
||||
// Explicitly zero tools - chat-only mode
|
||||
coreTools = nil
|
||||
} else if len(agentConfig.CoreTools) > 0 {
|
||||
// Custom tools provided - use them
|
||||
coreTools = agentConfig.CoreTools
|
||||
} else {
|
||||
// Default: load all core tools
|
||||
coreTools = core.AllTools()
|
||||
}
|
||||
|
||||
// Build the combined tool list: core tools + any external MCP tools
|
||||
// Build the initial tool list: core tools + extension tools (no MCP yet).
|
||||
allTools := make([]fantasy.AgentTool, len(coreTools))
|
||||
copy(allTools, coreTools)
|
||||
|
||||
// Load external MCP tools if configured
|
||||
var toolManager *tools.MCPToolManager
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager = tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
// MCP tool loading failures are non-fatal; core tools still work
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
} else {
|
||||
mcpTools := toolManager.GetTools()
|
||||
allTools = append(allTools, mcpTools...)
|
||||
}
|
||||
}
|
||||
|
||||
// Append any extra tools provided by extensions.
|
||||
if len(agentConfig.ExtraTools) > 0 {
|
||||
allTools = append(allTools, agentConfig.ExtraTools...)
|
||||
@@ -158,7 +214,149 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
allTools = agentConfig.ToolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Build fantasy agent options
|
||||
// Build agent options
|
||||
agentOpts := buildAgentOptions(agentConfig, providerResult, allTools)
|
||||
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
}
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
providerOptions: providerResult.ProviderOptions,
|
||||
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
// The mcpReady channel is closed when loading completes (success or failure).
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager := tools.NewMCPToolManager()
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
if agentConfig.TokenStoreFactory != nil {
|
||||
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
|
||||
}
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
// Set per-server loaded callback if provided.
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(a.mcpReady)
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
a.mcpErr = err
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// WaitForMCPTools blocks until background MCP tool loading completes.
|
||||
// Returns nil if no MCP servers are configured or if loading succeeded.
|
||||
// Returns the loading error if all servers failed. Safe to call multiple times.
|
||||
func (a *Agent) WaitForMCPTools() error {
|
||||
if a.mcpReady == nil {
|
||||
return nil
|
||||
}
|
||||
<-a.mcpReady
|
||||
return a.mcpErr
|
||||
}
|
||||
|
||||
// MCPToolsReady returns true if MCP tool loading has completed (or was never
|
||||
// started). This is a non-blocking check useful for UI status display.
|
||||
func (a *Agent) MCPToolsReady() bool {
|
||||
if a.mcpReady == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-a.mcpReady:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMCPTools waits for MCP tools to load and rebuilds the fantasy agent
|
||||
// with the full tool set. Called lazily before the first LLM call.
|
||||
// This is idempotent — subsequent calls after the first rebuild are no-ops.
|
||||
func (a *Agent) ensureMCPTools() {
|
||||
if a.mcpReady == nil {
|
||||
return
|
||||
}
|
||||
<-a.mcpReady
|
||||
|
||||
// If there are MCP tools, rebuild the fantasy agent to include them.
|
||||
if a.toolManager != nil && len(a.toolManager.GetTools()) > 0 {
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// Nil out the channel so future calls are instant no-ops and we
|
||||
// don't rebuild again.
|
||||
a.mcpReady = nil
|
||||
}
|
||||
|
||||
// rebuildFantasyAgent reconstructs the fantasy agent with the current full
|
||||
// tool set (core + MCP + extension tools). Used after MCP tools arrive
|
||||
// asynchronously and by SetModel.
|
||||
func (a *Agent) rebuildFantasyAgent() {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
providerResult := &models.ProviderResult{
|
||||
Model: a.model,
|
||||
ProviderOptions: a.providerOptions,
|
||||
SkipMaxOutputTokens: a.skipMaxOutputTokens,
|
||||
}
|
||||
agentOpts := buildAgentOptions(&AgentConfig{
|
||||
ModelConfig: a.modelConfig,
|
||||
SystemPrompt: a.systemPrompt,
|
||||
MaxSteps: a.maxSteps,
|
||||
}, providerResult, allTools)
|
||||
|
||||
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
|
||||
}
|
||||
|
||||
// buildAgentOptions constructs the fantasy.AgentOption slice from config,
|
||||
// provider result, and the combined tool list. Shared by NewAgent,
|
||||
// rebuildFantasyAgent, and SetModel.
|
||||
func buildAgentOptions(agentConfig *AgentConfig, providerResult *models.ProviderResult, allTools []fantasy.AgentTool) []fantasy.AgentOption {
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
if agentConfig.SystemPrompt != "" {
|
||||
@@ -196,33 +394,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.ModelConfig.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
|
||||
}
|
||||
}
|
||||
|
||||
// Create the fantasy agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
if agentConfig.ModelConfig.FrequencyPenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
|
||||
}
|
||||
if agentConfig.ModelConfig.PresencePenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
|
||||
}
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
toolManager: toolManager,
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
}, nil
|
||||
return agentOpts
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
|
||||
@@ -231,38 +411,54 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
// Fantasy handles the tool call loop internally. We map fantasy's rich callback system
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
onPasswordPrompt PasswordPromptHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Wait for background MCP tool loading to complete and rebuild the
|
||||
// fantasy agent with the full tool set. This is a no-op when no MCP
|
||||
// servers are configured or tools have already been integrated.
|
||||
a.ensureMCPTools()
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
}
|
||||
|
||||
// Fantasy requires the current user input as Prompt, with prior messages as history.
|
||||
// Inject password prompt handler into context for use by bash tool.
|
||||
if onPasswordPrompt != nil {
|
||||
ctx = core.ContextWithPasswordPrompt(ctx, onPasswordPrompt)
|
||||
}
|
||||
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
// Extract the last user message text and files as the prompt, and pass everything
|
||||
// before it as Messages. Files (e.g. clipboard images) are passed via the Files
|
||||
// field so Fantasy includes them in the API request.
|
||||
// field so the agent includes them in the API request.
|
||||
prompt, files, history := splitPromptAndHistory(messages)
|
||||
|
||||
// Track current tool call info for callbacks
|
||||
var currentToolName string
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. Fantasy only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
@@ -270,12 +466,16 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
// on cancellation. Fantasy's Stream() discards accumulated steps
|
||||
// on cancellation. The agent's Stream() discards accumulated steps
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
@@ -292,6 +492,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning/thinking complete callback
|
||||
OnReasoningEnd: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text streaming callback
|
||||
OnTextDelta: func(id, text string) error {
|
||||
if ctx.Err() != nil {
|
||||
@@ -308,7 +519,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolName = tc.ToolName
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
@@ -349,6 +559,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// persisted even if a later step is cancelled.
|
||||
completedStepMessages = append(completedStepMessages, step.Messages...)
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -379,7 +596,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
@@ -396,15 +613,20 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
@@ -422,19 +644,25 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
partialMessages = append(partialMessages, messages...)
|
||||
partialMessages = append(partialMessages, completedStepMessages...)
|
||||
return &GenerateWithLoopResult{
|
||||
ConversationMessages: partialMessages,
|
||||
ConversationMessages: partialMessages,
|
||||
PersistedMessageCount: persistedCount,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fire the response callback for callers that use it (e.g. non-streaming
|
||||
// callers that still want the final response notification).
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// Fire the response callback so callers (e.g. the TUI) can reset
|
||||
// streaming state. This must fire even when the response text is
|
||||
// empty (e.g. reasoning-only responses) so the UI properly resets
|
||||
// the stream component and avoids duplicate content on the next
|
||||
// flush.
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
r := convertAgentResult(result, messages)
|
||||
r.PersistedMessageCount = persistedCount
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Non-streaming path with no callbacks — use the simpler Generate call.
|
||||
@@ -447,18 +675,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For non-streaming, fire the response callback with the final text
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// For non-streaming, fire the response callback so callers can reset
|
||||
// streaming state (see streaming path comment above).
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
_ = currentToolName // satisfy compiler for non-streaming path
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
}
|
||||
|
||||
// splitPromptAndHistory extracts the last user message as the prompt string,
|
||||
// and returns everything before it as conversation history. Fantasy's agent
|
||||
// and returns everything before it as conversation history. The agent's
|
||||
// requires the current turn's input as Prompt (string), with prior messages
|
||||
// passed separately as Messages (history).
|
||||
func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePart, []fantasy.Message) {
|
||||
@@ -501,8 +728,8 @@ func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePa
|
||||
return "", nil, messages
|
||||
}
|
||||
|
||||
// convertAgentResult converts a fantasy AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the legacy fantasy.Message slice and the new custom content blocks.
|
||||
// convertAgentResult converts an AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the message slice and the new custom content blocks.
|
||||
func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.Message) *GenerateWithLoopResult {
|
||||
// Collect all conversation messages: original + all step messages
|
||||
var allFantasyMessages []fantasy.Message
|
||||
@@ -515,7 +742,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
// Convert to custom content blocks
|
||||
var allMessages []message.Message
|
||||
for _, fm := range allFantasyMessages {
|
||||
allMessages = append(allMessages, message.FromFantasyMessage(fm))
|
||||
allMessages = append(allMessages, message.FromLLMMessage(fm))
|
||||
}
|
||||
|
||||
return &GenerateWithLoopResult{
|
||||
@@ -527,7 +754,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolResultText extracts the text and error status from a fantasy ToolResultContent.
|
||||
// extractToolResultText extracts the text and error status from a ToolResultContent.
|
||||
// For core tools, the result is already clean text (no MCP JSON wrapping).
|
||||
// For MCP tools, it unwraps the MCP content structure.
|
||||
func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
@@ -540,7 +767,7 @@ func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
return errResult.Error.Error(), true
|
||||
}
|
||||
|
||||
// Get text directly from the Fantasy result type.
|
||||
// Get text directly from the result type.
|
||||
if textResult, ok := tr.Result.(fantasy.ToolResultOutputContentText); ok {
|
||||
// Try to unwrap MCP JSON structure (for external MCP tools).
|
||||
// Core tools return plain text, so this is a no-op for them.
|
||||
@@ -592,7 +819,7 @@ func (a *Agent) GetTools() []fantasy.AgentTool {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
@@ -618,6 +845,72 @@ func (a *Agent) GetExtensionToolCount() int {
|
||||
return len(a.extraTools)
|
||||
}
|
||||
|
||||
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
|
||||
// tools) and rebuilds the internal agent with the updated tool list. The
|
||||
// model, system prompt, and all other configuration are preserved.
|
||||
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
|
||||
a.extraTools = extraTools
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent. Returns the number of tools loaded.
|
||||
// If the agent has no tool manager (no MCP servers were configured at init),
|
||||
// one is created automatically.
|
||||
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
if a.authHandler != nil {
|
||||
a.toolManager.SetAuthHandler(a.authHandler)
|
||||
}
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
}
|
||||
|
||||
count, err := a.toolManager.AddServer(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but only if it was wired. Ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
|
||||
func (a *Agent) RemoveMCPServer(name string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
err := a.toolManager.RemoveServer(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPToolManager returns the underlying MCP tool manager.
|
||||
// Returns nil if no MCP servers have been configured.
|
||||
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
|
||||
return a.toolManager
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
@@ -631,78 +924,88 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
|
||||
// system prompt, and configuration are preserved. The old provider is closed
|
||||
// if it has a closer. Returns the previous model string for notification.
|
||||
// GetMCPPrompts returns all prompts discovered from connected MCP servers.
|
||||
// Returns nil if no MCP servers are configured or no prompts were found.
|
||||
func (a *Agent) GetMCPPrompts() []tools.MCPPrompt {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetPrompts()
|
||||
}
|
||||
|
||||
// GetMCPPrompt retrieves and expands a specific prompt from an MCP server.
|
||||
// This is a lazy call — the server is contacted each time.
|
||||
func (a *Agent) GetMCPPrompt(ctx context.Context, serverName, promptName string, args map[string]string) (*tools.MCPPromptResult, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.GetPrompt(ctx, serverName, promptName, args)
|
||||
}
|
||||
|
||||
// GetMCPResources returns all resources discovered from connected MCP servers.
|
||||
func (a *Agent) GetMCPResources() []tools.MCPResource {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetResources()
|
||||
}
|
||||
|
||||
// ReadMCPResource reads a specific resource from an MCP server by URI.
|
||||
func (a *Agent) ReadMCPResource(ctx context.Context, serverName, uri string) (*tools.MCPResourceContent, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.ReadResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SubscribeMCPResource subscribes to change notifications for a resource.
|
||||
func (a *Agent) SubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.SubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// UnsubscribeMCPResource cancels change notifications for a resource.
|
||||
func (a *Agent) UnsubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.UnsubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
|
||||
// it has a closer.
|
||||
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
|
||||
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
|
||||
// before the first LLM call).
|
||||
a.ensureMCPTools()
|
||||
|
||||
providerResult, err := models.CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Rebuild tool list (same as NewAgent).
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild fantasy agent options.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
|
||||
}
|
||||
if a.maxSteps > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithStopConditions(
|
||||
fantasy.StepCountIs(a.maxSteps),
|
||||
))
|
||||
}
|
||||
|
||||
// Pass provider-specific options (e.g. OpenAI Responses API reasoning settings).
|
||||
if providerResult.ProviderOptions != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithProviderOptions(providerResult.ProviderOptions))
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTemperature(float64(*config.Temperature)))
|
||||
}
|
||||
if config.TopP != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopP(float64(*config.TopP)))
|
||||
}
|
||||
if config.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
|
||||
}
|
||||
|
||||
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Close old provider.
|
||||
if a.providerCloser != nil {
|
||||
_ = a.providerCloser.Close()
|
||||
}
|
||||
|
||||
// Update model info on MCP tool manager.
|
||||
if a.toolManager != nil {
|
||||
a.toolManager.SetModel(providerResult.Model)
|
||||
}
|
||||
|
||||
// Swap fields.
|
||||
a.fantasyAgent = newFantasyAgent
|
||||
a.model = providerResult.Model
|
||||
a.providerCloser = providerResult.Closer
|
||||
a.providerOptions = providerResult.ProviderOptions
|
||||
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
|
||||
a.modelConfig = config
|
||||
|
||||
// Update system prompt when the config carries one (from per-model
|
||||
// settings or the global config). This allows model-specific system
|
||||
// prompts to take effect on model switch.
|
||||
if config.SystemPrompt != "" {
|
||||
a.systemPrompt = config.SystemPrompt
|
||||
}
|
||||
|
||||
// Update provider type.
|
||||
if config.ModelString != "" {
|
||||
@@ -711,16 +1014,41 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild the fantasy agent with the new model and current tool set.
|
||||
a.rebuildFantasyAgent()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModel returns the underlying fantasy LanguageModel.
|
||||
// GetModel returns the underlying LanguageModel.
|
||||
func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// GetMaxTokens returns the effective max output tokens the agent currently
|
||||
// sends to the LLM provider, after per-model defaults, right-sizing, and any
|
||||
// Anthropic thinking-budget adjustments. Returns 0 when no ModelConfig is
|
||||
// attached (e.g. early init) or when the provider suppresses the parameter
|
||||
// (e.g. Codex OAuth), which allows callers to differentiate "default" from
|
||||
// "explicitly capped".
|
||||
func (a *Agent) GetMaxTokens() int {
|
||||
if a.skipMaxOutputTokens {
|
||||
return 0
|
||||
}
|
||||
if a.modelConfig == nil {
|
||||
return 0
|
||||
}
|
||||
return a.modelConfig.MaxTokens
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources.
|
||||
// If MCP tools are still loading in the background, Close waits for them
|
||||
// to finish before closing connections to avoid resource leaks.
|
||||
func (a *Agent) Close() error {
|
||||
// Wait for background MCP loading to finish before closing connections.
|
||||
if a.mcpReady != nil {
|
||||
<-a.mcpReady
|
||||
}
|
||||
var toolErr error
|
||||
if a.toolManager != nil {
|
||||
toolErr = a.toolManager.Close()
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// mockModel is a minimal LanguageModel that satisfies the interface
|
||||
// without making real API calls. Used to test tool management wiring.
|
||||
type mockModel struct{}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) Provider() string { return "mock" }
|
||||
func (m *mockModel) Model() string { return "mock-model" }
|
||||
|
||||
// testdataDir returns the absolute path to the tools testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// mockAuthHandler is a minimal MCPAuthHandler for testing that auth handler
|
||||
// propagation works without requiring a real OAuth server.
|
||||
type mockAuthHandler struct {
|
||||
redirectURI string
|
||||
}
|
||||
|
||||
func (h *mockAuthHandler) RedirectURI() string { return h.redirectURI }
|
||||
func (h *mockAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// newTestAgent creates a minimal Agent with a mock model and no core tools,
|
||||
// suitable for testing MCP server management without an API key.
|
||||
func newTestAgent() *Agent {
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Initially no MCP tools.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Add a server.
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
// Verify tools are in the agent's tool list.
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
allTools := a.GetTools()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range allTools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo' in agent tools")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet' in agent tools")
|
||||
}
|
||||
|
||||
// Verify loaded server names.
|
||||
names := a.GetLoadedServerNames()
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected 'echo' in loaded server names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add then remove.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools removed.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Verify agent's tool list has no MCP tools.
|
||||
for _, tool := range a.GetTools() {
|
||||
if strings.Contains(tool.Info().Name, "echo__") {
|
||||
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
err := a.RemoveMCPServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no tool manager exists")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no MCP servers loaded") {
|
||||
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist.
|
||||
if a.GetMCPToolManager() == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add → Remove → Add cycle.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First add failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-add failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgent_AddMCPServer_InheritsAuthHandler verifies that AddMCPServer()
|
||||
// propagates the agent's authHandler and tokenStoreFactory to a newly created
|
||||
// MCPToolManager (fix for issue #3).
|
||||
func TestAgent_AddMCPServer_InheritsAuthHandler(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := &mockAuthHandler{redirectURI: "http://localhost:9999/oauth/callback"}
|
||||
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
authHandler: handler,
|
||||
tokenStoreFactory: nil, // nil is fine; we just test authHandler propagation
|
||||
}
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist and have the auth handler set.
|
||||
tm := a.GetMCPToolManager()
|
||||
if tm == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
|
||||
// Verify the auth handler was propagated by checking the field directly.
|
||||
if tm.GetAuthHandler() == nil {
|
||||
t.Fatal("Expected auth handler to be propagated to tool manager")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/anthropic"
|
||||
)
|
||||
|
||||
// cacheControlOptions returns provider options for Anthropic cache control.
|
||||
// This is used at the message level to avoid type conflicts with provider-level options.
|
||||
func cacheControlOptions() fantasy.ProviderOptions {
|
||||
return anthropic.NewProviderCacheControlOptions(&anthropic.ProviderCacheControlOptions{
|
||||
CacheControl: anthropic.CacheControl{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyCacheControlToMessages adds cache control to specific messages.
|
||||
// Anthropic allows max 4 cache blocks per request.
|
||||
// Counts existing cache blocks and only adds new ones up to the limit.
|
||||
func applyCacheControlToMessages(messages []fantasy.Message) []fantasy.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original slice
|
||||
result := make([]fantasy.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
cacheOpts := cacheControlOptions()
|
||||
maxCacheBlocks := 4
|
||||
|
||||
// Helper to check if message already has cache control
|
||||
hasCache := func(msg fantasy.Message) bool {
|
||||
if msg.ProviderOptions == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := msg.ProviderOptions["anthropic"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Count existing cache blocks
|
||||
existingCacheCount := 0
|
||||
for _, msg := range result {
|
||||
if hasCache(msg) {
|
||||
existingCacheCount++
|
||||
}
|
||||
}
|
||||
|
||||
// If we're already at or over the limit, don't add more
|
||||
if existingCacheCount >= maxCacheBlocks {
|
||||
return result
|
||||
}
|
||||
|
||||
// How many new cache blocks can we add?
|
||||
remaining := maxCacheBlocks - existingCacheCount
|
||||
|
||||
// First: find and cache the last system message (most important)
|
||||
lastSystemIdx := -1
|
||||
for i, msg := range result {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if lastSystemIdx >= 0 && remaining > 0 && !hasCache(result[lastSystemIdx]) {
|
||||
result[lastSystemIdx].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
// Second: cache the most recent messages (up to remaining limit)
|
||||
// Work backwards from the end to prioritize recent context
|
||||
for i := len(result) - 1; i >= 0 && remaining > 0; i-- {
|
||||
if hasCache(result[i]) {
|
||||
continue
|
||||
}
|
||||
result[i].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
+27
-10
@@ -36,13 +36,26 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ToolWrapper wraps the combined tool list before agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -50,15 +63,19 @@ type AgentCreationOptions struct {
|
||||
// Returns the created agent or an error if creation fails.
|
||||
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
|
||||
agentConfig := &AgentConfig{
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
manager *tools.MCPToolManager
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: t.tool.Name,
|
||||
Description: t.tool.Description,
|
||||
Parameters: t.tool.Parameters,
|
||||
Required: t.tool.Required,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(result.Content), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(result.Content), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
|
||||
// mcpToolsToAgentTools converts a slice of MCPTool to fantasy.AgentTool
|
||||
// implementations that route execution through the MCPToolManager.
|
||||
func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManager) []fantasy.AgentTool {
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
}
|
||||
+15
-4
@@ -1,6 +1,17 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// SteerMessage carries a steering prompt and optional file attachments
|
||||
// (e.g. clipboard images) through the steer channel.
|
||||
type SteerMessage struct {
|
||||
Text string
|
||||
Files []fantasy.FilePart
|
||||
}
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
@@ -11,7 +22,7 @@ type steerConsumedKey struct{}
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan SteerMessage) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
@@ -23,8 +34,8 @@ func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.C
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
func steerChFromContext(ctx context.Context) <-chan SteerMessage {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan SteerMessage)
|
||||
return ch
|
||||
}
|
||||
|
||||
|
||||
+451
-62
@@ -3,8 +3,11 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -17,7 +20,7 @@ import (
|
||||
// queueItem holds a prompt and optional image attachments for the execution queue.
|
||||
type queueItem struct {
|
||||
Prompt string
|
||||
Files []fantasy.FilePart
|
||||
Files []kit.LLMFilePart
|
||||
}
|
||||
|
||||
// App is the application-layer orchestrator. It owns the agentic loop,
|
||||
@@ -66,11 +69,20 @@ type App struct {
|
||||
// rootCtx/rootCancel are used to signal shutdown to all goroutines.
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
// initialMessages may be nil or empty for a fresh session.
|
||||
func New(opts Options, initialMessages []fantasy.Message) *App {
|
||||
func New(opts Options, initialMessages []kit.LLMMessage) *App {
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
return &App{
|
||||
opts: opts,
|
||||
@@ -114,9 +126,8 @@ func (a *App) Run(prompt string) int {
|
||||
// If the app is idle the prompt executes immediately; otherwise it is queued.
|
||||
// Returns the current queue depth (0 = started immediately, >0 = queued).
|
||||
//
|
||||
// Satisfies ui.AppController (via RunWithImages which converts ImageAttachment
|
||||
// to fantasy.FilePart).
|
||||
func (a *App) RunWithFiles(prompt string, files []fantasy.FilePart) int {
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) RunWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -151,6 +162,24 @@ func (a *App) CancelCurrentStep() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// IsBusy returns true when the agent is currently processing a turn.
|
||||
func (a *App) IsBusy() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.busy
|
||||
}
|
||||
|
||||
// Abort cancels the current agent step (if running) and clears the queue.
|
||||
// Unlike InterruptAndSend, no new message is injected — the agent simply
|
||||
// stops. Safe to call when idle (no-op).
|
||||
func (a *App) Abort() {
|
||||
a.mu.Lock()
|
||||
a.queue = a.queue[:0]
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// QueueLength returns the number of prompts currently waiting in the queue.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
@@ -176,6 +205,15 @@ func (a *App) QueueLength() int {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
return a.SteerWithFiles(prompt, nil)
|
||||
}
|
||||
|
||||
// SteerWithFiles injects a steering message with optional file attachments
|
||||
// (e.g. pasted images) into the currently running agent turn. Behaves like
|
||||
// Steer but includes file parts alongside the text.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -184,8 +222,8 @@ func (a *App) Steer(prompt string) int {
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
// Not busy — start immediately, same as RunWithFiles().
|
||||
item := queueItem{Prompt: prompt, Files: files}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
@@ -200,7 +238,7 @@ func (a *App) Steer(prompt string) int {
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
a.opts.Kit.InjectSteerWithFiles(prompt, files)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -259,6 +297,17 @@ func (a *App) ClearMessages() {
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadMessagesFromTree clears the in-memory message store and reloads it
|
||||
// from the tree session's current branch. Unlike ClearMessages, this does NOT
|
||||
// reset the tree session's leaf pointer. Used after Branch() to sync the
|
||||
// store with the new branch position.
|
||||
func (a *App) ReloadMessagesFromTree() {
|
||||
a.store.Clear()
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
// GetTreeSession returns the tree session manager, or nil if not configured.
|
||||
func (a *App) GetTreeSession() *session.TreeManager {
|
||||
return a.opts.TreeSession
|
||||
@@ -280,7 +329,7 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,12 +340,12 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) AddContextMessage(text string) {
|
||||
msg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(msg)
|
||||
kitMsg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(kitMsg)
|
||||
|
||||
// Persist to tree session if active.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
_, _ = ts.AppendFantasyMessage(msg)
|
||||
_, _ = ts.AppendLLMMessage(fantasy.NewUserMessage(text))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -334,6 +383,15 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
@@ -346,7 +404,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetFantasyMessages())
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
@@ -359,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("cannot compact while the agent is working")
|
||||
}
|
||||
if a.opts.Kit == nil {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("SDK instance not available")
|
||||
}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
if onError != nil {
|
||||
onError(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
|
||||
if onError != nil {
|
||||
onError("nothing to compact")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -367,6 +497,12 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// response text to stdout. No intermediate events are emitted. Blocks until
|
||||
// the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
return a.RunOnceWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithFiles executes a single agent step synchronously with optional
|
||||
// multimodal file attachments. Prints the response to stdout and returns.
|
||||
func (a *App) RunOnceWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -374,7 +510,7 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -389,6 +525,12 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
// full TurnResult without printing anything. This is used by --json mode to
|
||||
// capture structured output for serialization.
|
||||
func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
return a.RunOnceResultWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceResultWithFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments and returns the full TurnResult.
|
||||
func (a *App) RunOnceResultWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -396,7 +538,7 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
return a.executeStep(stepCtx, prompt, nil, nil)
|
||||
return a.executeStep(stepCtx, prompt, nil, files)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplay executes a single agent step synchronously, sending
|
||||
@@ -410,6 +552,12 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
//
|
||||
// Blocks until the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn func(tea.Msg)) error {
|
||||
return a.RunOnceWithDisplayAndFiles(ctx, prompt, eventFn, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplayAndFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments, sending intermediate display events.
|
||||
func (a *App) RunOnceWithDisplayAndFiles(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -417,7 +565,7 @@ func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn fun
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -483,11 +631,10 @@ func (a *App) drainQueue(first queueItem) {
|
||||
a.mu.Lock()
|
||||
items = append(items, a.queue...)
|
||||
a.queue = a.queue[:0] // Clear the queue
|
||||
queueLen := len(a.queue)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Send queue updated event (queue is now empty)
|
||||
a.sendEvent(QueueUpdatedEvent{Length: queueLen})
|
||||
// Notify UI: all queued messages have been consumed into this batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
@@ -500,8 +647,8 @@ func (a *App) drainQueue(first queueItem) {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
@@ -520,6 +667,11 @@ func (a *App) drainQueue(first queueItem) {
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if hasMore {
|
||||
// Notify UI: these newly queued messages have been consumed into the next batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
}
|
||||
|
||||
if !hasMore {
|
||||
// No more items, we're done
|
||||
break
|
||||
@@ -567,7 +719,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// call/result pairs; only the in-progress message or tool
|
||||
// call is discarded. Sync the in-memory store to match.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
a.sendEvent(StepCancelledEvent{})
|
||||
return
|
||||
@@ -586,7 +738,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// executeStep runs a single agentic step by delegating to the SDK's
|
||||
// PromptResult() (or PromptResultWithFiles for multimodal), which handles
|
||||
// session persistence, hooks, extension events, and the generation loop.
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []fantasy.FilePart) (*kit.TurnResult, error) {
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
// Test hook: bypass SDK entirely.
|
||||
if a.opts.PromptFunc != nil {
|
||||
return a.opts.PromptFunc(ctx, prompt)
|
||||
@@ -598,9 +750,10 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -620,8 +773,9 @@ func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.M
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker.
|
||||
a.updateUsageFromTurnResult(result, prompt)
|
||||
// Update usage tracker. If per-step usage was already recorded from
|
||||
// StepUsageEvent callbacks, avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -645,9 +799,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe to SDK events for TUI rendering. The subscription is
|
||||
// temporary — it lives only for the duration of this step.
|
||||
unsub := a.subscribeSDKEvents(sendFn)
|
||||
// Subscribe to SDK events for TUI rendering and per-step usage updates.
|
||||
// The subscription is temporary — it lives only for the duration of this step.
|
||||
var sawStepUsage atomic.Bool
|
||||
unsub := a.subscribeSDKEvents(sendFn, &sawStepUsage)
|
||||
defer unsub()
|
||||
|
||||
// Show spinner while the agent works.
|
||||
@@ -680,8 +835,8 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
messages = append(messages, item.Prompt)
|
||||
}
|
||||
|
||||
// TODO: Handle file attachments in batch mode
|
||||
// For now, files are ignored in batch mode (rare edge case)
|
||||
// File attachments are not supported in batch mode; fall back to
|
||||
// processing only the first item that carries files.
|
||||
if hasFiles {
|
||||
// If files exist, fall back to processing just the first item with files
|
||||
for _, item := range items {
|
||||
@@ -702,8 +857,10 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
// Sync in-memory store with the SDK's authoritative conversation.
|
||||
a.store.Replace(result.Messages)
|
||||
|
||||
// Update usage tracker (using last item's prompt for tracking).
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt)
|
||||
// Update usage tracker (using last item's prompt for fallback estimation).
|
||||
// If per-step usage was already recorded from StepUsageEvent callbacks,
|
||||
// avoid double-counting totals.
|
||||
a.updateUsageFromTurnResult(result, items[len(items)-1].Prompt, sawStepUsage.Load())
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -720,9 +877,10 @@ func (a *App) sendEvent(msg tea.Msg) {
|
||||
}
|
||||
|
||||
// subscribeSDKEvents registers temporary SDK event subscribers that convert
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. Returns an
|
||||
// unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
// SDK events to tea.Msg events and dispatch them via sendFn. When stepUsageSeen
|
||||
// is provided, it is set to true after any non-zero StepUsageEvent is observed.
|
||||
// Returns an unsubscribe function that removes all listeners.
|
||||
func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Bool) func() {
|
||||
k := a.opts.Kit
|
||||
var unsubs []func()
|
||||
|
||||
@@ -747,6 +905,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
sendFn(StreamChunkEvent{Content: ev.Chunk})
|
||||
case kit.ReasoningDeltaEvent:
|
||||
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
|
||||
case kit.ReasoningCompleteEvent:
|
||||
sendFn(ReasoningCompleteEvent{})
|
||||
case kit.ToolOutputEvent:
|
||||
sendFn(ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
@@ -756,6 +916,24 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
})
|
||||
case kit.SteerConsumedEvent:
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
case kit.PasswordPromptEvent:
|
||||
// Convert SDK PasswordPromptEvent to app PasswordPromptEvent
|
||||
// The TUI will handle this and send the response back
|
||||
responseCh := make(chan PasswordPromptResponse, 1)
|
||||
sendFn(PasswordPromptEvent{
|
||||
Prompt: ev.Prompt,
|
||||
ResponseCh: responseCh,
|
||||
})
|
||||
// Wait for TUI response and forward to SDK
|
||||
resp := <-responseCh
|
||||
ev.ResponseCh <- kit.PasswordPromptResponse{
|
||||
Password: resp.Password,
|
||||
Cancelled: resp.Cancelled,
|
||||
}
|
||||
case kit.TurnEndEvent:
|
||||
a.handleTurnEnd(ev, sendFn)
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -766,6 +944,64 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg)) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// handleTurnEnd inspects a turn's final StopReason and surfaces actionable
|
||||
// feedback to the user when the turn ended in a state they can act on.
|
||||
//
|
||||
// Today the only surfaced case is FinishReasonLength — the model hit its
|
||||
// configured max_output_tokens budget and the reply was truncated. Without
|
||||
// this banner the TUI used to swallow the truncation silently, leading to
|
||||
// "ghost" cut-offs with no indication of why.
|
||||
//
|
||||
// Separated from subscribeSDKEvents so tests can exercise it directly via a
|
||||
// stubbed sendFn without standing up a full Kit.
|
||||
func (a *App) handleTurnEnd(ev kit.TurnEndEvent, sendFn func(tea.Msg)) {
|
||||
if sendFn == nil {
|
||||
return
|
||||
}
|
||||
if ev.StopReason != kit.FinishReasonLength {
|
||||
return
|
||||
}
|
||||
sendFn(ExtensionPrintEvent{
|
||||
Level: "info",
|
||||
Text: a.formatMaxTokensTruncatedMessage(),
|
||||
})
|
||||
}
|
||||
|
||||
// formatMaxTokensTruncatedMessage builds the user-facing explanation for a
|
||||
// truncated turn. It reports the active max_output_tokens budget and, when
|
||||
// known, the model's catalog output ceiling so the user can judge how much
|
||||
// headroom is available.
|
||||
func (a *App) formatMaxTokensTruncatedMessage() string {
|
||||
k := a.opts.Kit
|
||||
if k == nil {
|
||||
// Extremely early / test-stub case: still emit a useful generic hint.
|
||||
return "⚠ Response truncated: the model hit the configured max_output_tokens limit. " +
|
||||
"Raise it with --max-tokens N, KIT_MAX_TOKENS=N, or per-model " +
|
||||
"modelSettings[provider/model].maxTokens in config."
|
||||
}
|
||||
current := k.MaxTokens()
|
||||
ceiling := k.MaxOutputLimit()
|
||||
model := k.GetModelString()
|
||||
|
||||
msg := "⚠ Response truncated: "
|
||||
if model != "" {
|
||||
msg += fmt.Sprintf("%s hit the configured max_output_tokens limit", model)
|
||||
} else {
|
||||
msg += "the model hit the configured max_output_tokens limit"
|
||||
}
|
||||
if current > 0 {
|
||||
msg += fmt.Sprintf(" (%d)", current)
|
||||
}
|
||||
msg += "."
|
||||
if ceiling > 0 && current > 0 && ceiling > current {
|
||||
msg += fmt.Sprintf(" This model supports up to %d output tokens.", ceiling)
|
||||
}
|
||||
msg += "\n\nRaise it with --max-tokens N, KIT_MAX_TOKENS=N, " +
|
||||
"or per-model modelSettings[provider/model].maxTokens in your config. " +
|
||||
"Re-run the last prompt after raising it to get the full response."
|
||||
return msg
|
||||
}
|
||||
|
||||
// QuitFromExtension triggers a graceful shutdown. In interactive mode it
|
||||
// sends a tea.QuitMsg to the program so the TUI exits cleanly. In
|
||||
// non-interactive mode it cancels the root context, stopping any in-flight
|
||||
@@ -786,7 +1022,8 @@ func (a *App) QuitFromExtension() {
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
// appropriate renderer. In non-interactive mode it falls back to stderr with
|
||||
// a level prefix so errors are distinguishable from plain output.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
@@ -795,8 +1032,16 @@ func (a *App) PrintFromExtension(level, text string) {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
// Non-interactive fallback: write to stderr with a level prefix so that
|
||||
// errors and info messages are distinguishable from plain output.
|
||||
switch level {
|
||||
case "error":
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
|
||||
default:
|
||||
fmt.Println(text)
|
||||
}
|
||||
}
|
||||
|
||||
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
|
||||
@@ -824,12 +1069,73 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// NotifyWidgetUpdate sends a WidgetUpdateEvent to the TUI so it re-renders
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
// prompt templates and skills from their provider callbacks. Called by file
|
||||
// watchers when .md/.txt files change in prompt or skill directories.
|
||||
// In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyContentReload() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(ContentReloadEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
|
||||
// tool names and MCP tool count from provider callbacks. Called when background
|
||||
// MCP tool loading completes. In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyMCPToolsReady() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPToolsReadyEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
|
||||
// display a system message when a single MCP server finishes loading. Called
|
||||
// per server as background MCP tool loading progresses.
|
||||
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPServerLoadedEvent{
|
||||
ServerName: serverName,
|
||||
ToolCount: toolCount,
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -917,48 +1223,131 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
// Non-interactive fallback: render a simple framed block to stderr so
|
||||
// it is visually distinct from plain stdout output.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// recordStepUsage applies token/cost usage reported for a completed step.
|
||||
// Step usage events arrive even when a turn is later cancelled, so this keeps
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
hasUsage, ev.InputTokens, ev.OutputTokens, ev.CacheReadTokens, ev.CacheWriteTokens)
|
||||
}
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
if stepUsageSeen != nil {
|
||||
stepUsageSeen.Store(true)
|
||||
}
|
||||
if a.opts.UsageTracker == nil {
|
||||
return
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(ev.InputTokens),
|
||||
int(ev.OutputTokens),
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
// configured UsageTracker. Called once per turn after the turn completes.
|
||||
//
|
||||
// Cost/token accumulation uses TotalUsage (sum across all tool-calling steps in
|
||||
// the turn). Context-window fill uses FinalUsage.InputTokens only — that is the
|
||||
// number of tokens sent to the model on the last API call, which equals the
|
||||
// actual context window occupation (all accumulated messages + tool results).
|
||||
// OutputTokens are not added here because they are the response length, not
|
||||
// context fill.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string) {
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage from the API response.
|
||||
//
|
||||
// NOTE: We only use ACTUAL token counts from API responses for cost tracking.
|
||||
// Estimation is never used for costs - only API-reported tokens are accurate.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Debug logging for token tracking
|
||||
if a.opts.Debug {
|
||||
if result.TotalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult TotalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: TotalUsage=nil")
|
||||
}
|
||||
if result.FinalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult FinalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.FinalUsage.InputTokens, result.FinalUsage.OutputTokens,
|
||||
result.FinalUsage.CacheReadTokens, result.FinalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: FinalUsage=nil")
|
||||
}
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: sawStepUsage=%v", sawStepUsage)
|
||||
}
|
||||
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// If sawStepUsage is true, totals were already updated via StepUsageEvent.
|
||||
// Check any token field > 0 (not just InputTokens) because cached prompts
|
||||
// can result in InputTokens=0 while OutputTokens>0 (OpenAI-compatible behavior).
|
||||
hasTotalUsage := result.TotalUsage != nil &&
|
||||
(result.TotalUsage.InputTokens > 0 ||
|
||||
result.TotalUsage.OutputTokens > 0 ||
|
||||
result.TotalUsage.CacheReadTokens > 0 ||
|
||||
result.TotalUsage.CacheCreationTokens > 0)
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: hasTotalUsage=%v", hasTotalUsage)
|
||||
}
|
||||
if !sawStepUsage && hasTotalUsage {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling UpdateUsage input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens: the input token count of the last API call
|
||||
// equals the number of tokens currently occupying the context window.
|
||||
// Adding OutputTokens would overstate fill since the response is not part
|
||||
// of the context that was *sent* to the model.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
// Calculate context fill from the LAST API call's usage. The context
|
||||
// window is filled by everything sent to and received from the model:
|
||||
//
|
||||
// InputTokens — non-cached input (may be small with prompt caching)
|
||||
// CacheReadTokens — input tokens served from cache
|
||||
// CacheCreationTokens — input tokens written to cache this call
|
||||
// OutputTokens — assistant output (becomes input next turn)
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can drop to near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. We must sum all four to
|
||||
// get the true context window utilization.
|
||||
//
|
||||
// We use FinalUsage (last step only), NOT TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps — and each step re-sends the full
|
||||
// conversation, so TotalUsage massively overstates the actual window fill.
|
||||
if result.FinalUsage != nil {
|
||||
u := result.FinalUsage
|
||||
contextFill := int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheCreate=%d + Output=%d)",
|
||||
contextFill, u.InputTokens, u.CacheReadTokens, u.CacheCreationTokens, u.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -14,6 +16,47 @@ import (
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type usageUpdaterStub struct {
|
||||
mu sync.Mutex
|
||||
|
||||
updateCalls int
|
||||
estimateCalls int
|
||||
contextCalls int
|
||||
|
||||
lastUpdateInput int
|
||||
lastUpdateOutput int
|
||||
lastUpdateCacheRead int
|
||||
lastUpdateCacheWrite int
|
||||
lastContextTokens int
|
||||
lastEstimateInput string
|
||||
lastEstimateOutput string
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) UpdateUsage(inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.updateCalls++
|
||||
s.lastUpdateInput = inputTokens
|
||||
s.lastUpdateOutput = outputTokens
|
||||
s.lastUpdateCacheRead = cacheReadTokens
|
||||
s.lastUpdateCacheWrite = cacheWriteTokens
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.estimateCalls++
|
||||
s.lastEstimateInput = inputText
|
||||
s.lastEstimateOutput = outputText
|
||||
}
|
||||
|
||||
func (s *usageUpdaterStub) SetContextTokens(tokens int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.contextCalls++
|
||||
s.lastContextTokens = tokens
|
||||
}
|
||||
|
||||
// turnResult builds a minimal TurnResult with response text t.
|
||||
func turnResult(t string) *kit.TurnResult {
|
||||
return &kit.TurnResult{Response: t}
|
||||
@@ -489,3 +532,230 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
t.Fatalf("expected 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.recordStepUsage(kit.StepUsageEvent{
|
||||
InputTokens: 120,
|
||||
OutputTokens: 45,
|
||||
CacheReadTokens: 5,
|
||||
CacheWriteTokens: 2,
|
||||
}, nil)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 120 || usage.lastUpdateOutput != 45 || usage.lastUpdateCacheRead != 5 || usage.lastUpdateCacheWrite != 2 {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen ensures we avoid
|
||||
// double-counting totals once StepUsageEvent-based updates were already applied.
|
||||
func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 0 {
|
||||
t.Fatalf("expected no total usage update when sawStepUsage=true, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
// Context tokens should be InputTokens only (456)
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456 (InputTokens only), got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero verifies that usage
|
||||
// is recorded when InputTokens=0 but OutputTokens>0 (OpenAI-compatible cache behavior).
|
||||
func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call when InputTokens=0 but OutputTokens>0, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 0 || usage.lastUpdateOutput != 150 {
|
||||
t.Fatalf("expected input=0 output=150, got input=%d output=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput)
|
||||
}
|
||||
if usage.lastUpdateCacheRead != 500 {
|
||||
t.Fatalf("expected cache_read=500, got %d", usage.lastUpdateCacheRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories verifies that
|
||||
// context window fill uses all token categories from the final API call:
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3,
|
||||
OutputTokens: 5,
|
||||
CacheReadTokens: 0,
|
||||
CacheCreationTokens: 4317,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3, // Non-cached input (small with caching)
|
||||
OutputTokens: 5, // Assistant output
|
||||
CacheReadTokens: 0, // No cache reads on first call
|
||||
CacheCreationTokens: 4317, // System prompt + tools written to cache
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be Input + CacheRead + CacheCreate + Output = 4325
|
||||
expected := 3 + 0 + 4317 + 5
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != expected {
|
||||
t.Fatalf("expected context tokens=%d (all categories), got calls=%d tokens=%d",
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_LengthEmitsWarning verifies that when the SDK reports a
|
||||
// FinishReasonLength (max_output_tokens hit), the app surfaces a user-visible
|
||||
// ExtensionPrintEvent with Level="info" so the TUI can render a banner
|
||||
// instead of silently showing a truncated reply.
|
||||
func TestHandleTurnEnd_LengthEmitsWarning(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
var mu sync.Mutex
|
||||
var received []tea.Msg
|
||||
sendFn := func(m tea.Msg) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
received = append(received, m)
|
||||
}
|
||||
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, sendFn)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(received) != 1 {
|
||||
t.Fatalf("expected 1 event on length stop, got %d", len(received))
|
||||
}
|
||||
ev, ok := received[0].(ExtensionPrintEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected ExtensionPrintEvent, got %T", received[0])
|
||||
}
|
||||
if ev.Level != "info" {
|
||||
t.Errorf("expected Level=info, got %q", ev.Level)
|
||||
}
|
||||
if ev.Text == "" {
|
||||
t.Error("expected non-empty warning text")
|
||||
}
|
||||
if !strings.Contains(ev.Text, "max_output_tokens") {
|
||||
t.Errorf("warning text should mention max_output_tokens, got: %s", ev.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_NonLengthIgnored verifies that ordinary stop reasons
|
||||
// (stop, tool-calls, error, unknown, "") do not produce a warning banner.
|
||||
func TestHandleTurnEnd_NonLengthIgnored(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
reasons := []string{
|
||||
kit.FinishReasonStop,
|
||||
kit.FinishReasonToolCalls,
|
||||
kit.FinishReasonError,
|
||||
kit.FinishReasonContentFilter,
|
||||
kit.FinishReasonOther,
|
||||
kit.FinishReasonUnknown,
|
||||
"",
|
||||
}
|
||||
for _, r := range reasons {
|
||||
var called bool
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: r}, func(m tea.Msg) {
|
||||
called = true
|
||||
})
|
||||
if called {
|
||||
t.Errorf("stop reason %q unexpectedly emitted a warning", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleTurnEnd_NilSendFn guards against panics when no TUI listener is
|
||||
// attached (e.g. early init or headless teardown).
|
||||
func TestHandleTurnEnd_NilSendFn(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Should not panic with a nil sendFn.
|
||||
app.handleTurnEnd(kit.TurnEndEvent{StopReason: kit.FinishReasonLength}, nil)
|
||||
}
|
||||
|
||||
// TestFormatMaxTokensTruncatedMessage_NoKit verifies the fallback message
|
||||
// when Options.Kit is nil (test/stub path).
|
||||
func TestFormatMaxTokensTruncatedMessage_NoKit(t *testing.T) {
|
||||
app := New(Options{}, nil)
|
||||
defer app.Close()
|
||||
|
||||
msg := app.formatMaxTokensTruncatedMessage()
|
||||
if msg == "" {
|
||||
t.Fatal("expected non-empty fallback message")
|
||||
}
|
||||
for _, needle := range []string{"max_output_tokens", "--max-tokens", "KIT_MAX_TOKENS", "modelSettings"} {
|
||||
if !strings.Contains(msg, needle) {
|
||||
t.Errorf("fallback message missing %q:\n%s", needle, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+45
-3
@@ -1,6 +1,6 @@
|
||||
package app
|
||||
|
||||
import "charm.land/fantasy"
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
// StreamChunkEvent is sent by the app layer when a streaming text delta arrives
|
||||
// from the LLM. Each chunk contains an incremental portion of the response.
|
||||
@@ -16,6 +16,11 @@ type ReasoningChunkEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
// ReasoningCompleteEvent is sent when reasoning/thinking is finished, after
|
||||
// the last reasoning token has been processed. The TUI uses this to freeze
|
||||
// the reasoning duration counter.
|
||||
type ReasoningCompleteEvent struct{}
|
||||
|
||||
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
|
||||
// It carries the tool name and its arguments for display purposes.
|
||||
type ToolCallStartedEvent struct {
|
||||
@@ -74,6 +79,24 @@ type ToolCallContentEvent struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
// PasswordPromptEvent is sent when a sudo command needs a password.
|
||||
// The TUI should display a password prompt overlay and send the result back.
|
||||
type PasswordPromptEvent struct {
|
||||
// Prompt is the message to display to the user.
|
||||
Prompt string
|
||||
// ResponseCh receives the password from the TUI.
|
||||
// The TUI must send exactly one value.
|
||||
ResponseCh chan<- PasswordPromptResponse
|
||||
}
|
||||
|
||||
// PasswordPromptResponse carries the user's password input.
|
||||
type PasswordPromptResponse struct {
|
||||
// Password is the entered password.
|
||||
Password string
|
||||
// Cancelled is true if the user cancelled the prompt.
|
||||
Cancelled bool
|
||||
}
|
||||
|
||||
// ResponseCompleteEvent is sent when the LLM produces a final (non-streaming) response.
|
||||
// In streaming mode, this may be empty if all content was delivered via StreamChunkEvents.
|
||||
type ResponseCompleteEvent struct {
|
||||
@@ -118,8 +141,8 @@ type SpinnerEvent struct {
|
||||
// MessageCreatedEvent is sent when a new message is added to the message store.
|
||||
// This allows the TUI to stay in sync with the conversation history.
|
||||
type MessageCreatedEvent struct {
|
||||
// Message is the fantasy message that was added to the store.
|
||||
Message fantasy.Message
|
||||
// Message is the message that was added to the store.
|
||||
Message kit.LLMMessage
|
||||
}
|
||||
|
||||
// CompactCompleteEvent is sent when a /compact operation finishes successfully.
|
||||
@@ -162,6 +185,25 @@ type ModelChangedEvent struct {
|
||||
// from its WidgetProvider on the next render cycle.
|
||||
type WidgetUpdateEvent struct{}
|
||||
|
||||
// ContentReloadEvent is sent when prompt templates or skills are reloaded
|
||||
// from disk (e.g. by a file watcher detecting changes). The TUI refreshes
|
||||
// its autocomplete entries and internal state from the provider callbacks.
|
||||
type ContentReloadEvent struct{}
|
||||
|
||||
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
|
||||
// The TUI refreshes its tool names and MCP tool count from provider callbacks
|
||||
// so that /tools and the startup info bar reflect the loaded MCP tools.
|
||||
type MCPToolsReadyEvent struct{}
|
||||
|
||||
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
|
||||
// (successfully or with error). The TUI displays a system message so users
|
||||
// see real-time progress as each server initializes.
|
||||
type MCPServerLoadedEvent struct {
|
||||
ServerName string
|
||||
ToolCount int
|
||||
Error error // nil on success
|
||||
}
|
||||
|
||||
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
|
||||
// pre-fill the input editor with text. The TUI handles this by setting the
|
||||
// textarea content and moving the cursor to the end.
|
||||
|
||||
@@ -3,14 +3,14 @@ package app
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// MessageStore is a thread-safe in-memory store for the conversation history.
|
||||
// On-disk persistence is handled by the TreeManager at the app/SDK layer.
|
||||
type MessageStore struct {
|
||||
mu sync.RWMutex
|
||||
messages []fantasy.Message
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
@@ -20,14 +20,14 @@ func NewMessageStore() *MessageStore {
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []fantasy.Message) *MessageStore {
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
return &MessageStore{messages: cp}
|
||||
}
|
||||
|
||||
// Add appends a single message to the store.
|
||||
func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
func (s *MessageStore) Add(msg kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, msg)
|
||||
@@ -36,22 +36,22 @@ func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
// Replace replaces the entire message history with the given slice. This is
|
||||
// used after an agent step returns the full updated conversation (including
|
||||
// tool calls and results).
|
||||
func (s *MessageStore) Replace(msgs []fantasy.Message) {
|
||||
func (s *MessageStore) Replace(msgs []kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
s.messages = cp
|
||||
}
|
||||
|
||||
// GetAll returns a snapshot copy of the current message slice.
|
||||
// The returned slice is safe to modify without affecting the store.
|
||||
func (s *MessageStore) GetAll() []fantasy.Message {
|
||||
func (s *MessageStore) GetAll() []kit.LLMMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(s.messages))
|
||||
cp := make([]kit.LLMMessage, len(s.messages))
|
||||
copy(cp, s.messages)
|
||||
return cp
|
||||
}
|
||||
|
||||
@@ -3,17 +3,27 @@ package app
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// makeTextMsg builds a minimal fantasy.Message with a single TextPart.
|
||||
func makeTextMsg(role, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRole(role),
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
|
||||
// makeTextMsg builds a minimal kit.LLMMessage with the given role and text.
|
||||
func makeTextMsg(role, text string) kit.LLMMessage {
|
||||
return kit.LLMMessage{
|
||||
Role: kit.LLMMessageRole(role),
|
||||
Content: []kit.LLMMessagePart{kit.LLMTextPart{Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
// textOf extracts the plain text from an LLMMessage for assertions.
|
||||
func textOf(msg kit.LLMMessage) string {
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(kit.LLMTextPart); ok {
|
||||
return tp.Text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// NewMessageStore / NewMessageStoreWithMessages
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -29,7 +39,7 @@ func TestNewMessageStore_empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
msgs := []fantasy.Message{
|
||||
msgs := []kit.LLMMessage{
|
||||
makeTextMsg("user", "hello"),
|
||||
makeTextMsg("assistant", "hi"),
|
||||
}
|
||||
@@ -42,7 +52,7 @@ func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
// NewMessageStoreWithMessages must deep-copy the slice so that external
|
||||
// modifications don't affect the store.
|
||||
func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
msgs := []fantasy.Message{makeTextMsg("user", "hello")}
|
||||
msgs := []kit.LLMMessage{makeTextMsg("user", "hello")}
|
||||
s := NewMessageStoreWithMessages(msgs)
|
||||
|
||||
// Mutate the source slice.
|
||||
@@ -52,9 +62,8 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
tp, ok := got[0].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", tp.Text)
|
||||
if textOf(got[0]) != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,9 +89,8 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
}
|
||||
got := s.GetAll()
|
||||
for i, expected := range texts {
|
||||
tp, ok := got[i].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, tp.Text)
|
||||
if textOf(got[i]) != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, textOf(got[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,7 +103,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []fantasy.Message{
|
||||
replacement := []kit.LLMMessage{
|
||||
makeTextMsg("user", "new1"),
|
||||
makeTextMsg("assistant", "new2"),
|
||||
}
|
||||
@@ -105,25 +113,22 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
t.Fatalf("expected 2 messages after replace, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp0, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
tp1, _ := got[1].Content[0].(fantasy.TextPart)
|
||||
if tp0.Text != "new1" || tp1.Text != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", tp0.Text, tp1.Text)
|
||||
if textOf(got[0]) != "new1" || textOf(got[1]) != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", textOf(got[0]), textOf(got[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
replacement := []fantasy.Message{makeTextMsg("user", "original")}
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
replacement[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", tp.Text)
|
||||
if textOf(got[0]) != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,9 +145,8 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
got[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
internal := s.GetAll()
|
||||
tp, _ := internal[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", tp.Text)
|
||||
if textOf(internal[0]) != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", textOf(internal[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +183,8 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
t.Fatalf("expected 1 message after Clear+Add, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", tp.Text)
|
||||
if textOf(got[0]) != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ type UsageUpdater interface {
|
||||
// the provider does not return exact counts.
|
||||
EstimateAndUpdateUsage(inputText, outputText string)
|
||||
// SetContextTokens records the approximate current context window fill
|
||||
// level. This should be the final API call's input+output tokens (from
|
||||
// FinalResponse.Usage), NOT the aggregate TotalUsage.
|
||||
// level. This should be the sum of ALL token categories from the last
|
||||
// API call: InputTokens + CacheReadTokens + CacheCreationTokens +
|
||||
// OutputTokens. With Anthropic prompt caching, InputTokens can be
|
||||
// near-zero while CacheReadTokens holds the bulk of the context.
|
||||
SetContextTokens(tokens int)
|
||||
}
|
||||
|
||||
@@ -67,10 +69,6 @@ type Options struct {
|
||||
// Debug enables verbose debug logging.
|
||||
Debug bool
|
||||
|
||||
// CompactMode selects the compact renderer instead of the block renderer for
|
||||
// message formatting.
|
||||
CompactMode bool
|
||||
|
||||
// UsageTracker is an optional callback for recording token usage after each
|
||||
// agent step. When non-nil, the app layer calls UpdateUsage (or
|
||||
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
|
||||
|
||||
@@ -43,13 +43,30 @@ type OpenAICredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
func oauthTokenExpired(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= expiresAt
|
||||
}
|
||||
|
||||
// oauthTokenNeedsRefresh reports whether an OAuth token will expire within the
|
||||
// next 5 minutes, allowing proactive refresh before it becomes invalid.
|
||||
// Returns false for API key credentials or when no expiry is set.
|
||||
func oauthTokenNeedsRefresh(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (expiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -57,19 +74,13 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -77,10 +88,7 @@ func (c *OpenAICredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
|
||||
@@ -428,6 +428,10 @@ type PreviousCompaction struct {
|
||||
ModifiedFiles []string
|
||||
}
|
||||
|
||||
// StreamCallback is called for each chunk of text during streaming compaction.
|
||||
// Return a non-nil error to cancel the stream.
|
||||
type StreamCallback func(delta string) error
|
||||
|
||||
// Compact summarises older messages using the LLM, returning the compaction
|
||||
// result and a new message slice (summary message + preserved recent
|
||||
// messages).
|
||||
@@ -442,6 +446,8 @@ type PreviousCompaction struct {
|
||||
//
|
||||
// prev carries file tracking from a previous compaction for cumulative
|
||||
// tracking. Pass nil if there is no prior compaction.
|
||||
// onChunk is an optional callback for streaming summary text. Pass nil for
|
||||
// non-streaming compaction.
|
||||
func Compact(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -449,6 +455,7 @@ func Compact(
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
prev *PreviousCompaction,
|
||||
onChunk StreamCallback,
|
||||
) (*CompactionResult, []fantasy.Message, error) {
|
||||
opts.defaults()
|
||||
|
||||
@@ -487,9 +494,9 @@ func Compact(
|
||||
var err error
|
||||
|
||||
if IsSplitTurn(messages, cutPoint) {
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions)
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions, onChunk)
|
||||
} else {
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions)
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions, onChunk)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -527,15 +534,17 @@ func Compact(
|
||||
}
|
||||
|
||||
// compactNormal generates a summary for a clean turn-boundary cut.
|
||||
// If onChunk is provided, text deltas are streamed to it.
|
||||
func compactNormal(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
oldMessages []fantasy.Message,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
conversationText := serializeMessages(oldMessages)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions, onChunk)
|
||||
}
|
||||
|
||||
// compactSplitTurn handles the case where the cut point lands mid-turn.
|
||||
@@ -546,6 +555,7 @@ func compactNormal(
|
||||
//
|
||||
// The merged result preserves context from both the older history and the
|
||||
// beginning of the current long turn.
|
||||
// If onChunk is provided, both summaries and the separator are streamed.
|
||||
func compactSplitTurn(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -554,6 +564,7 @@ func compactSplitTurn(
|
||||
cutPoint int,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
// Find where the split turn starts.
|
||||
turnStart := findTurnStart(allMessages, cutPoint)
|
||||
@@ -573,12 +584,19 @@ func compactSplitTurn(
|
||||
// Generate history summary if there are complete turns before the split.
|
||||
if len(historyMessages) >= 2 {
|
||||
historySummary, err = generateSummary(ctx, model,
|
||||
serializeMessages(historyMessages), opts, "")
|
||||
serializeMessages(historyMessages), opts, "", onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn history summary failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the separator between history and turn prefix summaries.
|
||||
if onChunk != nil && historySummary != "" {
|
||||
if err := onChunk("\n\n---\n\n## Current Turn (in progress)\n\n"); err != nil {
|
||||
return "", fmt.Errorf("streaming separator failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate turn prefix summary.
|
||||
turnPrefixText := serializeMessages(turnPrefixMessages)
|
||||
turnPrefixPrompt := "The messages above are the BEGINNING of a long turn that was split. " +
|
||||
@@ -588,16 +606,10 @@ func compactSplitTurn(
|
||||
turnPrefixPrompt += "\n\nAdditional instructions: " + customInstructions
|
||||
}
|
||||
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: turnPrefixText + "\n\n" + turnPrefixPrompt,
|
||||
})
|
||||
turnPrefixSummary, err := generateSummary(ctx, model, turnPrefixText, opts, turnPrefixPrompt, onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn prefix summary failed: %w", err)
|
||||
}
|
||||
turnPrefixSummary := result.Response.Content.Text()
|
||||
|
||||
// Merge the two summaries.
|
||||
if historySummary != "" && turnPrefixSummary != "" {
|
||||
@@ -610,12 +622,14 @@ func compactSplitTurn(
|
||||
}
|
||||
|
||||
// generateSummary calls the LLM to produce a structured summary.
|
||||
// If onChunk is provided, the summary is streamed using Agent.Stream().
|
||||
func generateSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
conversationText string,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
userPrompt := opts.SummaryPrompt
|
||||
if userPrompt == "" {
|
||||
@@ -628,8 +642,31 @@ func generateSummary(
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
|
||||
prompt := conversationText + "\n\n" + userPrompt
|
||||
|
||||
// Use streaming if onChunk is provided.
|
||||
if onChunk != nil {
|
||||
var fullText strings.Builder
|
||||
_, err := summaryAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
OnTextDelta: func(_, delta string) error {
|
||||
if delta != "" {
|
||||
fullText.WriteString(delta)
|
||||
return onChunk(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation (streaming) failed: %w", err)
|
||||
}
|
||||
return fullText.String(), nil
|
||||
}
|
||||
|
||||
// Non-streaming path.
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: conversationText + "\n\n" + userPrompt,
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation failed: %w", err)
|
||||
|
||||
@@ -243,7 +243,7 @@ func TestCompact_TooFewMessages(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleUser, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -262,7 +262,7 @@ func TestCompact_WithinBudget(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleAssistant, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
+115
-17
@@ -22,6 +22,20 @@ type MCPServerConfig struct {
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
|
||||
// OAuth configuration for remote servers that don't support dynamic
|
||||
// client registration (e.g. GitHub). When OAuthClientID is set, it is
|
||||
// passed directly to the transport's OAuthConfig instead of relying on
|
||||
// dynamic registration.
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
|
||||
// InProcessServer holds a live *server.MCPServer for in-process transport.
|
||||
// When set (and Type is "inprocess"), the connection pool creates an
|
||||
// in-process client instead of spawning a subprocess or making HTTP calls.
|
||||
// This field is never serialized — it is only used programmatically via the SDK.
|
||||
InProcessServer any `json:"-" yaml:"-"`
|
||||
|
||||
// Legacy fields for backward compatibility
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
@@ -35,13 +49,16 @@ type MCPServerConfig struct {
|
||||
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
// First try to unmarshal as the new format
|
||||
type newFormat struct {
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -66,6 +83,9 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = newConfig.Headers
|
||||
s.AllowedTools = newConfig.AllowedTools
|
||||
s.ExcludedTools = newConfig.ExcludedTools
|
||||
s.OAuthClientID = newConfig.OAuthClientID
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,11 +177,28 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParams defines generation parameter defaults that can be attached
|
||||
// to individual models. These act as model-level defaults — CLI flags and
|
||||
// global config values take precedence when explicitly set.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
@@ -169,6 +206,11 @@ type CustomModelConfig struct {
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
|
||||
// Generation parameter defaults for this model.
|
||||
// These are applied when the user hasn't explicitly set the corresponding
|
||||
// CLI flag or global config value.
|
||||
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
@@ -191,18 +233,19 @@ type Config struct {
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
MaxSteps int `json:"max-steps,omitempty" yaml:"max-steps,omitempty"`
|
||||
Debug bool `json:"debug,omitempty" yaml:"debug,omitempty"`
|
||||
Compact bool `json:"compact,omitempty" yaml:"compact,omitempty"`
|
||||
SystemPrompt string `json:"system-prompt,omitempty" yaml:"system-prompt,omitempty"`
|
||||
ProviderAPIKey string `json:"provider-api-key,omitempty" yaml:"provider-api-key,omitempty"`
|
||||
ProviderURL string `json:"provider-url,omitempty" yaml:"provider-url,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
|
||||
// Thinking / extended reasoning
|
||||
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
|
||||
@@ -216,6 +259,12 @@ type Config struct {
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
|
||||
// Per-model generation parameter overrides. Keys are "provider/model" strings
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
|
||||
// settings act as model-level defaults — CLI flags and global config values
|
||||
// take precedence when explicitly set.
|
||||
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
@@ -234,11 +283,18 @@ func (s *MCPServerConfig) GetTransportType() string {
|
||||
return "stdio"
|
||||
case "remote":
|
||||
return "streamable"
|
||||
case "inprocess":
|
||||
return "inprocess"
|
||||
default:
|
||||
return s.Type
|
||||
}
|
||||
}
|
||||
|
||||
// Programmatic in-process server detection.
|
||||
if s.InProcessServer != nil {
|
||||
return "inprocess"
|
||||
}
|
||||
|
||||
// Backward compatibility: infer transport type
|
||||
if len(s.Command) > 0 {
|
||||
return "stdio"
|
||||
@@ -269,8 +325,12 @@ func (c *Config) Validate() error {
|
||||
if serverConfig.URL == "" {
|
||||
return fmt.Errorf("server %s: url is required for %s transport", serverName, transport)
|
||||
}
|
||||
case "inprocess":
|
||||
if serverConfig.InProcessServer == nil {
|
||||
return fmt.Errorf("server %s: InProcessServer is required for inprocess transport", serverName)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable", serverName, transport)
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable, inprocess", serverName, transport)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -364,16 +424,55 @@ mcpServers:
|
||||
# debug: false # Enable debug logging
|
||||
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
|
||||
|
||||
# Model generation parameters (all optional)
|
||||
# Model generation parameters (all optional, apply globally to all models)
|
||||
# max-tokens: 4096 # Maximum tokens in response
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
# top-k: 40 # Top K sampling
|
||||
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# Per-model generation parameter overrides (apply to specific models)
|
||||
# These act as model-level defaults — CLI flags and global settings above take precedence.
|
||||
# Keys are "provider/model" strings matching the model you use.
|
||||
# modelSettings:
|
||||
# anthropic/claude-sonnet-4-5-20250929:
|
||||
# temperature: 0.3
|
||||
# maxTokens: 8192
|
||||
# openai/gpt-4o:
|
||||
# temperature: 0.7
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# frequencyPenalty: 0.1
|
||||
# presencePenalty: 0.1
|
||||
# anthropic/claude-opus-4-6:
|
||||
# thinkingLevel: "high"
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
# Custom model definitions (under custom/ provider)
|
||||
# customModels:
|
||||
# my-local-llama:
|
||||
# name: "Local Llama 3"
|
||||
# baseUrl: "http://localhost:8080/v1"
|
||||
# family: "llama"
|
||||
# temperature: true
|
||||
# cost:
|
||||
# input: 0.0
|
||||
# output: 0.0
|
||||
# limit:
|
||||
# context: 131072
|
||||
# output: 8192
|
||||
# params: # Generation parameter defaults for this model
|
||||
# temperature: 0.8
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# systemPrompt: "You are a helpful local assistant."
|
||||
`
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
@@ -403,10 +502,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filepath.Join(home, absPath[2:])
|
||||
absPath = filepath.Join(home, absPath[2:])
|
||||
}
|
||||
if !filepath.IsAbs(absPath) {
|
||||
// base := GetConfigPath()
|
||||
base := configPath
|
||||
if base == "" {
|
||||
fmt.Fprintf(os.Stderr, "unable to build relative path to config.")
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMCPServerConfig_NewFormat(t *testing.T) {
|
||||
@@ -542,3 +544,86 @@ func TestEnsureConfigExistsWhenFileExists(t *testing.T) {
|
||||
t.Error("Existing config file was modified when it shouldn't have been")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_JSON(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://api.githubcopilot.com/mcp/",
|
||||
"oauthClientId": "Ov23liXXXXXXXXXXXXXX",
|
||||
"oauthClientSecret": "secret123",
|
||||
"oauthScopes": ["read:user", "repo"]
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.URL != "https://api.githubcopilot.com/mcp/" {
|
||||
t.Errorf("Expected URL, got %q", cfg.URL)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "secret123" {
|
||||
t.Errorf("Expected OAuthClientSecret 'secret123', got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_YAML(t *testing.T) {
|
||||
yamlData := `
|
||||
type: remote
|
||||
url: https://api.githubcopilot.com/mcp/
|
||||
oauthClientId: "Ov23liXXXXXXXXXXXXXX"
|
||||
oauthScopes:
|
||||
- read:user
|
||||
- repo
|
||||
`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal YAML: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
// Verify that omitting OAuth fields still works (backward compat).
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://example.com/mcp"
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.OAuthClientID != "" {
|
||||
t.Errorf("Expected empty OAuthClientID, got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "" {
|
||||
t.Errorf("Expected empty OAuthClientSecret, got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 0 {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
+181
-24
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,10 +19,18 @@ import (
|
||||
// It receives tool call ID, tool name, output chunk, and whether it's stderr.
|
||||
type ToolOutputCallback func(toolCallID, toolName, chunk string, isStderr bool)
|
||||
|
||||
// PasswordPromptCallback is the signature for password prompts.
|
||||
// It receives a prompt message and returns the password and whether it was cancelled.
|
||||
type PasswordPromptCallback func(prompt string) (password string, cancelled bool)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions.
|
||||
type contextKey string
|
||||
|
||||
const toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
const (
|
||||
toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
sudoPasswordKey contextKey = "sudoPassword"
|
||||
passwordPromptKey contextKey = "passwordPrompt"
|
||||
)
|
||||
|
||||
// ContextWithToolOutputCallback returns a new context with the tool output callback set.
|
||||
func ContextWithToolOutputCallback(ctx context.Context, callback ToolOutputCallback) context.Context {
|
||||
@@ -36,23 +45,39 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithPasswordPrompt returns a new context with the password prompt callback set.
|
||||
// This allows the TUI to show a modal password prompt when sudo needs a password.
|
||||
func ContextWithPasswordPrompt(ctx context.Context, callback PasswordPromptCallback) context.Context {
|
||||
return context.WithValue(ctx, passwordPromptKey, callback)
|
||||
}
|
||||
|
||||
// passwordPromptFromContext retrieves the password prompt callback from context.
|
||||
func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
if cb, ok := ctx.Value(passwordPromptKey).(PasswordPromptCallback); ok {
|
||||
return cb
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
return pw
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
var bannedCommands = []string{
|
||||
"alias ", "bg ", "bind ", "builtin ",
|
||||
"caller ", "command ", "compgen ",
|
||||
"complete ", "compopt ", "coproc ",
|
||||
"dirs ", "disown ", "enable ",
|
||||
"fc ", "fg ", "hash ", "help ",
|
||||
"history ", "jobs ", "kill ",
|
||||
"logout ", "mapfile ", "popd ",
|
||||
"pushd ", "readonly ", "select ",
|
||||
"set ", "shopt ", "source ",
|
||||
"suspend ", "times ", "trap ",
|
||||
"type ", "typeset ", "ulimit ",
|
||||
"umask ", "unalias ", "wait ",
|
||||
}
|
||||
// bannedCmdRe matches bash builtin commands that are not allowed for security reasons.
|
||||
var bannedCmdRe = regexp.MustCompile(`^(alias|bg|bind|builtin|caller|command|compgen|complete|compopt|coproc|dirs|disown|enable|fc|fg|hash|help|history|jobs|kill|logout|mapfile|popd|pushd|readonly|select|set|shopt|source|suspend|times|trap|type|typeset|ulimit|umask|unalias|wait)\s`)
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
@@ -84,6 +109,66 @@ func NewBashTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// sudoCommandRe matches sudo commands that need to be rewritten for -S mode.
|
||||
// It matches "sudo" as a word boundary, optionally preceded by environment variables.
|
||||
var sudoCommandRe = regexp.MustCompile(`(?i)(^|[&|;|]|\|\||&&)\s*(\w+=\S+\s+)?\bsudo\b`)
|
||||
|
||||
// truncateCommand truncates a long command for display.
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// rewriteSudoForStdin rewrites sudo commands to use -S -p ” for stdin password input.
|
||||
// It transforms: sudo cmd → sudo -S -p ” cmd
|
||||
func rewriteSudoForStdin(command string) string {
|
||||
// Find all matches and their positions
|
||||
matches := sudoCommandRe.FindAllStringIndex(command, -1)
|
||||
if matches == nil {
|
||||
return command
|
||||
}
|
||||
|
||||
// Build result from end to start to preserve indices
|
||||
result := command
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
match := matches[i]
|
||||
start, end := match[0], match[1]
|
||||
matchedText := result[start:end]
|
||||
|
||||
// Extract just the "sudo" part (after any prefix)
|
||||
sudoIdx := strings.Index(strings.ToLower(matchedText), "sudo")
|
||||
if sudoIdx == -1 {
|
||||
continue
|
||||
}
|
||||
prefix := matchedText[:sudoIdx]
|
||||
sudoPart := matchedText[sudoIdx:]
|
||||
|
||||
// Check if the text immediately after "sudo" in the result contains -S
|
||||
afterSudo := result[end:]
|
||||
if strings.HasPrefix(strings.TrimLeft(afterSudo, " \t"), "-S") {
|
||||
// Already has -S flag, skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert -S -p '' after "sudo"
|
||||
newSudo := strings.Replace(sudoPart, "sudo", "sudo -S -p ''", 1)
|
||||
result = result[:start] + prefix + newSudo + result[end:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
|
||||
// This is stored in tool response metadata to signal the TUI to prompt for password.
|
||||
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
|
||||
|
||||
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
|
||||
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
|
||||
return resp.Metadata == SudoPasswordRequiredMetadata
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
@@ -94,10 +179,8 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// Check for banned commands
|
||||
for _, banned := range bannedCommands {
|
||||
if strings.HasPrefix(args.Command, banned) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
if bannedCmdRe.MatchString(args.Command) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
|
||||
// Determine timeout
|
||||
@@ -110,7 +193,47 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", args.Command)
|
||||
// Check for sudo password in context or environment
|
||||
sudoPassword := sudoPasswordFromContext(ctx)
|
||||
if sudoPassword == "" {
|
||||
sudoPassword = os.Getenv("SUDO_PASSWORD")
|
||||
}
|
||||
command := args.Command
|
||||
|
||||
// If command contains sudo and we don't have a password, check if sudo needs one
|
||||
if sudoPassword == "" && sudoCommandRe.MatchString(command) {
|
||||
// Check if sudo credentials are cached using sudo -n (non-interactive)
|
||||
testCmd := exec.CommandContext(cmdCtx, "sudo", "-n", "true")
|
||||
testCmd.Dir = workDir
|
||||
if err := testCmd.Run(); err != nil {
|
||||
// Sudo needs a password - try to prompt via callback
|
||||
if promptCallback := passwordPromptFromContext(ctx); promptCallback != nil {
|
||||
pw, cancelled := promptCallback("Sudo password required for: " + truncateCommand(args.Command, 60))
|
||||
if cancelled {
|
||||
return fantasy.NewTextErrorResponse("sudo password prompt cancelled"), nil
|
||||
}
|
||||
if pw == "" {
|
||||
return fantasy.NewTextErrorResponse("no sudo password provided"), nil
|
||||
}
|
||||
sudoPassword = pw
|
||||
command = rewriteSudoForStdin(command)
|
||||
} else {
|
||||
// No callback available - return error with helpful message
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"This command requires sudo access. " +
|
||||
"Please run 'sudo -v' in your terminal first to cache credentials, " +
|
||||
"or set the SUDO_PASSWORD environment variable."), nil
|
||||
}
|
||||
}
|
||||
// Credentials are cached or password was provided, proceed
|
||||
}
|
||||
|
||||
// If we have a sudo password, rewrite the command to use sudo -S
|
||||
if sudoPassword != "" && sudoCommandRe.MatchString(command) {
|
||||
command = rewriteSudoForStdin(command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", command)
|
||||
if workDir != "" {
|
||||
cmd.Dir = workDir
|
||||
}
|
||||
@@ -128,18 +251,18 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
if outputCallback != nil {
|
||||
// Streaming mode: use pipes to capture output as it arrives
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback)
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback, sudoPassword)
|
||||
}
|
||||
|
||||
// Non-streaming mode: collect all output at once (original behavior)
|
||||
return executeBashBuffered(cmdCtx, call, cmd)
|
||||
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd) (fantasy.ToolResponse, error) {
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -149,10 +272,27 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe and write the password
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
@@ -194,7 +334,7 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
}
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback) (fantasy.ToolResponse, error) {
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -204,11 +344,28 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
@@ -127,3 +127,72 @@ func TestBash_EmptyCommand(t *testing.T) {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteSudoForStdin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple sudo",
|
||||
input: "sudo apt update",
|
||||
expected: "sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo with env var",
|
||||
input: "DEBIAN_FRONTEND=noninteractive sudo apt update",
|
||||
expected: "DEBIAN_FRONTEND=noninteractive sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo in pipeline",
|
||||
input: "echo test | sudo tee /etc/test.conf",
|
||||
expected: "echo test | sudo -S -p '' tee /etc/test.conf",
|
||||
},
|
||||
{
|
||||
name: "sudo after &&",
|
||||
input: "apt update && sudo apt upgrade",
|
||||
expected: "apt update && sudo -S -p '' apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "already has -S flag",
|
||||
input: "sudo -S apt update",
|
||||
expected: "sudo -S apt update",
|
||||
},
|
||||
{
|
||||
name: "no sudo",
|
||||
input: "apt update && apt upgrade",
|
||||
expected: "apt update && apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "sudo in string (should not match)",
|
||||
input: "echo 'use sudo carefully'",
|
||||
expected: "echo 'use sudo carefully'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := rewriteSudoForStdin(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("rewriteSudoForStdin(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
}
|
||||
|
||||
// Test without password
|
||||
ctx = context.Background()
|
||||
pw = sudoPasswordFromContext(ctx)
|
||||
if pw != "" {
|
||||
t.Errorf("expected empty password, got %q", pw)
|
||||
}
|
||||
}
|
||||
|
||||
+234
-44
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -13,19 +14,45 @@ import (
|
||||
udiff "github.com/aymanbagabas/go-udiff"
|
||||
)
|
||||
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
// Edit represents a single replacement in a multi-edit operation.
|
||||
type Edit struct {
|
||||
OldText string `json:"old_text"`
|
||||
NewText string `json:"new_text"`
|
||||
}
|
||||
|
||||
// editArgs holds the arguments for the edit tool.
|
||||
// Supports both single-edit mode (old_text/new_text) and multi-edit mode (edits array).
|
||||
type editArgs struct {
|
||||
Path string `json:"path"`
|
||||
OldText string `json:"old_text"` // Single-edit mode
|
||||
NewText string `json:"new_text"` // Single-edit mode
|
||||
Edits []Edit `json:"edits"` // Multi-edit mode
|
||||
}
|
||||
|
||||
// replacement represents a normalized edit ready for processing.
|
||||
type replacement struct {
|
||||
oldText string // normalized old text for matching
|
||||
newText string // normalized new text
|
||||
originalOld string // original old text for metadata
|
||||
originalNew string // original new text for metadata
|
||||
index int // index in the original edits array (for error messages)
|
||||
}
|
||||
|
||||
// matchedReplacement represents a replacement with its match location.
|
||||
type matchedReplacement struct {
|
||||
replacement
|
||||
start int // start index in normalized content
|
||||
end int // end index in normalized content
|
||||
usedFuzzyMatch bool // true if fuzzy matching was used
|
||||
}
|
||||
|
||||
// NewEditTool creates the edit core tool.
|
||||
func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
cfg := ApplyOptions(opts)
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "edit",
|
||||
Description: "Edit a file by replacing exact text. The old_text must match exactly (including whitespace). Use this for precise, surgical edits. Fails if old_text is not found or matches multiple locations.",
|
||||
Description: "Edit a file by replacing exact text. Supports single edit via old_text/new_text, or multiple edits via the edits array. All edits in the array are matched against the original file content (non-incremental) and must be non-overlapping.",
|
||||
Parameters: map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
@@ -33,14 +60,32 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
},
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace (must match exactly)",
|
||||
"description": "Exact text to find and replace (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text to replace the old text with",
|
||||
"description": "New text to replace the old text with (single-edit mode). Must not be used with 'edits' array.",
|
||||
},
|
||||
"edits": map[string]any{
|
||||
"type": "array",
|
||||
"description": "Array of edits for multi-region replacement. Each edit must have unique, non-overlapping old_text. All matches are against the original file content.",
|
||||
"items": map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"old_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Exact text to find and replace for this edit",
|
||||
},
|
||||
"new_text": map[string]any{
|
||||
"type": "string",
|
||||
"description": "New text for this edit",
|
||||
},
|
||||
},
|
||||
"required": []string{"old_text", "new_text"},
|
||||
},
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "old_text", "new_text"},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
handler: func(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
return executeEdit(ctx, call, cfg.WorkDir)
|
||||
@@ -51,7 +96,7 @@ func NewEditTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args editArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
return fantasy.NewTextErrorResponse("path, old_text, and new_text parameters are required"), nil
|
||||
return fantasy.NewTextErrorResponse("failed to parse arguments: " + err.Error()), nil
|
||||
}
|
||||
if args.Path == "" {
|
||||
return fantasy.NewTextErrorResponse("path parameter is required"), nil
|
||||
@@ -69,56 +114,201 @@ func executeEdit(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
content := string(contentBytes)
|
||||
|
||||
// Normalize line endings for matching
|
||||
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
normalizedOld := strings.ReplaceAll(args.OldText, "\r\n", "\n")
|
||||
|
||||
// Try exact match first
|
||||
count := strings.Count(normalized, normalizedOld)
|
||||
|
||||
// If no exact match, try fuzzy matching
|
||||
if count == 0 {
|
||||
if idx, matchLen := fuzzyMatch(normalized, normalizedOld); idx >= 0 {
|
||||
// Apply fuzzy match — the matched text is the original content slice
|
||||
matchedText := normalized[idx : idx+matchLen]
|
||||
newContent := normalized[:idx] + args.NewText + normalized[idx+matchLen:]
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, matchedText, args.NewText)), nil
|
||||
}
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("old_text not found in %s", args.Path)), nil
|
||||
// Normalize and validate input
|
||||
replacements, err := normalizeEditInput(args)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("found %d matches for old_text in %s. Provide more context to identify the correct match.", count, args.Path)), nil
|
||||
// Apply all edits
|
||||
newContent, applied, err := applyEdits(content, replacements)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(err.Error()), nil
|
||||
}
|
||||
|
||||
// Apply the edit
|
||||
newContent := strings.Replace(normalized, normalizedOld, args.NewText, 1)
|
||||
|
||||
// Write the file
|
||||
if err := os.WriteFile(absPath, []byte(newContent), 0644); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to write file: %v", err)), nil
|
||||
}
|
||||
|
||||
diff := generateDiff(absPath, normalized, newContent)
|
||||
resp := fantasy.NewTextResponse(fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff))
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, normalizedOld, args.NewText)), nil
|
||||
// Generate diff
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
diff := generateDiff(absPath, normalizedContent, newContent)
|
||||
|
||||
// Build response with fuzzy match indication
|
||||
fuzzyCount := 0
|
||||
for _, m := range applied {
|
||||
if m.usedFuzzyMatch {
|
||||
fuzzyCount++
|
||||
}
|
||||
}
|
||||
|
||||
var msg string
|
||||
if len(applied) == 1 {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied edit (fuzzy match) to %s\n%s", args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied edit to %s\n%s", args.Path, diff)
|
||||
}
|
||||
} else {
|
||||
if fuzzyCount > 0 {
|
||||
msg = fmt.Sprintf("Applied %d edits (%d fuzzy) to %s\n%s", len(applied), fuzzyCount, args.Path, diff)
|
||||
} else {
|
||||
msg = fmt.Sprintf("Applied %d edits to %s\n%s", len(applied), args.Path, diff)
|
||||
}
|
||||
}
|
||||
|
||||
resp := fantasy.NewTextResponse(msg)
|
||||
return fantasy.WithResponseMetadata(resp, editDiffMeta(absPath, applied)), nil
|
||||
}
|
||||
|
||||
// normalizeEditInput validates and normalizes the edit input.
|
||||
// Returns error if both single-edit and multi-edit modes are used.
|
||||
func normalizeEditInput(args editArgs) ([]replacement, error) {
|
||||
singleMode := args.OldText != "" || args.NewText != ""
|
||||
multiMode := len(args.Edits) > 0
|
||||
|
||||
if singleMode && multiMode {
|
||||
return nil, fmt.Errorf("cannot use old_text/new_text together with edits array")
|
||||
}
|
||||
|
||||
if !singleMode && !multiMode {
|
||||
return nil, fmt.Errorf("must provide either old_text/new_text or edits array")
|
||||
}
|
||||
|
||||
if singleMode {
|
||||
if args.OldText == "" {
|
||||
return nil, fmt.Errorf("old_text is required when using single-edit mode")
|
||||
}
|
||||
if args.NewText == "" {
|
||||
return nil, fmt.Errorf("new_text is required when using single-edit mode")
|
||||
}
|
||||
return []replacement{{
|
||||
oldText: strings.ReplaceAll(args.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(args.NewText, "\r\n", "\n"),
|
||||
originalOld: args.OldText,
|
||||
originalNew: args.NewText,
|
||||
index: 0,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
// Multi-edit mode
|
||||
var reps []replacement
|
||||
for i, edit := range args.Edits {
|
||||
if edit.OldText == "" {
|
||||
return nil, fmt.Errorf("edits[%d].old_text is required", i)
|
||||
}
|
||||
reps = append(reps, replacement{
|
||||
oldText: strings.ReplaceAll(edit.OldText, "\r\n", "\n"),
|
||||
newText: strings.ReplaceAll(edit.NewText, "\r\n", "\n"),
|
||||
originalOld: edit.OldText,
|
||||
originalNew: edit.NewText,
|
||||
index: i,
|
||||
})
|
||||
}
|
||||
return reps, nil
|
||||
}
|
||||
|
||||
// applyEdits applies multiple replacements to the content.
|
||||
// All matches are against the original content (non-incremental).
|
||||
// Returns the new content, the applied matches, and any error.
|
||||
func applyEdits(content string, edits []replacement) (string, []matchedReplacement, error) {
|
||||
normalizedContent := strings.ReplaceAll(content, "\r\n", "\n")
|
||||
|
||||
// Find all matches
|
||||
var matched []matchedReplacement
|
||||
for _, edit := range edits {
|
||||
m, err := findMatch(normalizedContent, edit)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
matched = append(matched, *m)
|
||||
}
|
||||
|
||||
// Sort by position
|
||||
sort.Slice(matched, func(i, j int) bool {
|
||||
return matched[i].start < matched[j].start
|
||||
})
|
||||
|
||||
// Check for overlaps
|
||||
for i := 1; i < len(matched); i++ {
|
||||
if matched[i-1].end > matched[i].start {
|
||||
return "", nil, fmt.Errorf("edits[%d] and edits[%d] overlap; merge them into a single edit",
|
||||
matched[i-1].index, matched[i].index)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply edits in reverse order (end to start) to maintain stable offsets
|
||||
result := normalizedContent
|
||||
for i := len(matched) - 1; i >= 0; i-- {
|
||||
m := matched[i]
|
||||
result = result[:m.start] + m.newText + result[m.end:]
|
||||
}
|
||||
|
||||
return result, matched, nil
|
||||
}
|
||||
|
||||
// findMatch finds a unique match for the edit in the content.
|
||||
// Returns error if not found or ambiguous.
|
||||
func findMatch(content string, edit replacement) (*matchedReplacement, error) {
|
||||
// Try exact match first
|
||||
count := strings.Count(content, edit.oldText)
|
||||
|
||||
if count == 0 {
|
||||
// Try fuzzy match
|
||||
idx, matchLen := fuzzyMatch(content, edit.oldText)
|
||||
if idx < 0 {
|
||||
return nil, fmt.Errorf("edits[%d]: could not find old_text in file. The text must match exactly (including whitespace)", edit.index)
|
||||
}
|
||||
// Use the matched text from content for the replacement
|
||||
matchedText := content[idx : idx+matchLen]
|
||||
return &matchedReplacement{
|
||||
replacement: replacement{
|
||||
oldText: matchedText,
|
||||
newText: edit.newText,
|
||||
originalOld: edit.originalOld,
|
||||
originalNew: edit.originalNew,
|
||||
index: edit.index,
|
||||
},
|
||||
start: idx,
|
||||
end: idx + matchLen,
|
||||
usedFuzzyMatch: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
return nil, fmt.Errorf("found %d matches for edits[%d].old_text; each old_text must be unique, provide more context to identify the correct match", count, edit.index)
|
||||
}
|
||||
|
||||
// Single exact match
|
||||
idx := strings.Index(content, edit.oldText)
|
||||
return &matchedReplacement{
|
||||
replacement: edit,
|
||||
start: idx,
|
||||
end: idx + len(edit.oldText),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// editDiffMeta builds the structured metadata attached to edit tool responses.
|
||||
func editDiffMeta(path, oldText, newText string) map[string]any {
|
||||
func editDiffMeta(path string, applied []matchedReplacement) map[string]any {
|
||||
var diffBlocks []map[string]any
|
||||
totalAdditions, totalDeletions := 0, 0
|
||||
|
||||
for _, m := range applied {
|
||||
diffBlocks = append(diffBlocks, map[string]any{
|
||||
"old_text": m.originalOld,
|
||||
"new_text": m.originalNew,
|
||||
})
|
||||
totalAdditions += strings.Count(m.originalNew, "\n") + 1
|
||||
totalDeletions += strings.Count(m.originalOld, "\n") + 1
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"file_diffs": []map[string]any{{
|
||||
"path": path,
|
||||
"additions": strings.Count(newText, "\n") + 1,
|
||||
"deletions": strings.Count(oldText, "\n") + 1,
|
||||
"diff_blocks": []map[string]any{{
|
||||
"old_text": oldText,
|
||||
"new_text": newText,
|
||||
}},
|
||||
"path": path,
|
||||
"additions": totalAdditions,
|
||||
"deletions": totalDeletions,
|
||||
"diff_blocks": diffBlocks,
|
||||
}},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -715,3 +715,315 @@ func TestExecuteEdit_MetadataContainsFileDiffs(t *testing.T) {
|
||||
t.Fatal("file_diffs should be a non-empty array")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multi-edit tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Basic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "multi.txt")
|
||||
writeFileOrFail(t, path, "line1\nline2\nline3\nline4\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "line1", NewText: "LINE1"},
|
||||
{OldText: "line3", NewText: "LINE3"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "LINE1") {
|
||||
t.Error("first edit not applied: missing LINE1")
|
||||
}
|
||||
if !strings.Contains(gotStr, "LINE3") {
|
||||
t.Error("second edit not applied: missing LINE3")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line2") {
|
||||
t.Error("line2 was modified but should be untouched")
|
||||
}
|
||||
if !strings.Contains(gotStr, "line4") {
|
||||
t.Error("line4 was modified but should be untouched")
|
||||
}
|
||||
|
||||
// Check response mentions multiple edits
|
||||
if !strings.Contains(resp.Content, "2 edits") {
|
||||
t.Errorf("response should mention '2 edits', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NonIncrementalMatching(t *testing.T) {
|
||||
// All edits are matched against the original content, not incrementally
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "noninc.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
want := "AAA\nBBB\nccc\n"
|
||||
if gotStr != want {
|
||||
t.Errorf("got %q, want %q", gotStr, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_OverlapDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "overlap.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "hello world", NewText: "GOODBYE"}, // Overlaps with first edit
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for overlapping edits")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "overlap") {
|
||||
t.Errorf("expected 'overlap' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_DuplicateDetection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "dup.txt")
|
||||
writeFileOrFail(t, path, "hello\nworld\nhello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "hello", NewText: "HELLO"},
|
||||
{OldText: "world", NewText: "WORLD"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for ambiguous old_text (duplicate matches)")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "unique") {
|
||||
t.Errorf("expected 'unique' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello\nworld\nhello\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_NotFound(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "notfound.txt")
|
||||
writeFileOrFail(t, path, "hello world\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "nonexistent", NewText: "REPLACEMENT"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for not found")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "edits[0]") {
|
||||
t.Errorf("expected 'edits[0]' in error, got: %s", resp.Content)
|
||||
}
|
||||
|
||||
// File should be untouched
|
||||
got, _ := os.ReadFile(path)
|
||||
if string(got) != "hello world\n" {
|
||||
t.Error("file was modified despite error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_EmptyArray(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "empty.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error for empty edits array")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_MixedWithSingleMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
writeFileOrFail(t, path, "hello\n")
|
||||
|
||||
input, _ := json.Marshal(map[string]any{
|
||||
"path": path,
|
||||
"old_text": "hello",
|
||||
"new_text": "HELLO",
|
||||
"edits": []Edit{
|
||||
{OldText: "hello", NewText: "HI"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if !resp.IsError {
|
||||
t.Error("expected error when mixing single and multi-edit modes")
|
||||
}
|
||||
if !strings.Contains(resp.Content, "cannot use") {
|
||||
t.Errorf("expected 'cannot use' in error, got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_FuzzyMatch(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "fuzzy_multi.txt")
|
||||
// File has trailing whitespace
|
||||
original := "func foo() { \n\treturn 1 \n}\nfunc bar() { \n\treturn 2 \n}\n"
|
||||
writeFileOrFail(t, path, original)
|
||||
|
||||
// Search without trailing whitespace (common LLM behavior)
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "func foo() {\n\treturn 1\n}", NewText: "func foo() {\n\treturn 10\n}"},
|
||||
{OldText: "func bar() {\n\treturn 2\n}", NewText: "func bar() {\n\treturn 20\n}"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("executeEdit error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
got, _ := os.ReadFile(path)
|
||||
gotStr := string(got)
|
||||
|
||||
if !strings.Contains(gotStr, "return 10") {
|
||||
t.Error("first edit not applied")
|
||||
}
|
||||
if !strings.Contains(gotStr, "return 20") {
|
||||
t.Error("second edit not applied")
|
||||
}
|
||||
|
||||
// Response should mention fuzzy match
|
||||
if !strings.Contains(resp.Content, "fuzzy") {
|
||||
t.Errorf("response should mention 'fuzzy', got: %s", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteEdit_MultiEdit_Metadata(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "meta_multi.txt")
|
||||
writeFileOrFail(t, path, "aaa\nbbb\nccc\n")
|
||||
|
||||
input, _ := json.Marshal(editArgs{
|
||||
Path: path,
|
||||
Edits: []Edit{
|
||||
{OldText: "aaa", NewText: "AAA"},
|
||||
{OldText: "bbb", NewText: "BBB"},
|
||||
},
|
||||
})
|
||||
|
||||
resp, err := executeEdit(t.Context(), fantasy.ToolCall{Input: string(input)}, dir)
|
||||
if err != nil {
|
||||
t.Fatalf("error: %v", err)
|
||||
}
|
||||
if resp.IsError {
|
||||
t.Fatalf("tool returned error: %s", resp.Content)
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
if err := json.Unmarshal([]byte(resp.Metadata), &meta); err != nil {
|
||||
t.Fatalf("metadata is not valid JSON: %v", err)
|
||||
}
|
||||
|
||||
diffs, ok := meta["file_diffs"].([]any)
|
||||
if !ok || len(diffs) == 0 {
|
||||
t.Fatal("metadata missing file_diffs")
|
||||
}
|
||||
|
||||
firstDiff, ok := diffs[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("first diff is not an object")
|
||||
}
|
||||
|
||||
// Check that diff_blocks contains both edits
|
||||
diffBlocks, ok := firstDiff["diff_blocks"].([]any)
|
||||
if !ok || len(diffBlocks) != 2 {
|
||||
t.Fatalf("expected 2 diff_blocks, got %d", len(diffBlocks))
|
||||
}
|
||||
|
||||
// Verify each block has old_text and new_text
|
||||
for i, block := range diffBlocks {
|
||||
b, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("diff_block[%d] is not an object", i)
|
||||
}
|
||||
if _, ok := b["old_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing old_text", i)
|
||||
}
|
||||
if _, ok := b["new_text"]; !ok {
|
||||
t.Fatalf("diff_block[%d] missing new_text", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+1
-20
@@ -67,7 +67,7 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return readDirectory(absPath)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("'%s' is a directory, not a file. Use the ls tool to list directory contents.", args.Path)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
@@ -116,25 +116,6 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
func readDirectory(absPath string) (fantasy.ToolResponse, error) {
|
||||
entries, err := os.ReadDir(absPath)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to read directory: %v", err)), nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() {
|
||||
name += "/"
|
||||
}
|
||||
result.WriteString(name + "\n")
|
||||
}
|
||||
|
||||
tr := truncateHead(result.String(), 500, defaultMaxBytes)
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
// resolvePathWithWorkDir resolves a path to an absolute path relative to the
|
||||
// given workDir. If workDir is empty, os.Getwd() is used.
|
||||
func resolvePathWithWorkDir(path, workDir string) (string, error) {
|
||||
|
||||
+31
-38
@@ -28,14 +28,14 @@ type SubagentSpawnResult struct {
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
// The toolCallID parameter is the LLM-assigned ID of the spawn_subagent
|
||||
// The toolCallID parameter is the LLM-assigned ID of the subagent
|
||||
// tool call, enabling the parent to correlate subagent events.
|
||||
type SubagentSpawnFunc func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
// subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
@@ -59,11 +59,11 @@ type subagentArgs struct {
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
// NewSubagentTool creates the subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Name: "subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
@@ -130,13 +130,22 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
), fmt.Errorf("no subagent spawner in context")
|
||||
}
|
||||
|
||||
// Detach from the parent's deadline so the subagent gets its own
|
||||
// independent timeout (applied downstream in Kit.Subagent). The parent
|
||||
// context may carry a tight deadline from the LLM generation loop or
|
||||
// other tool timeouts that would prematurely kill the subagent.
|
||||
// We preserve context values (spawner, etc.) and propagate parent
|
||||
// cancellation (e.g. user hits Ctrl-C) without inheriting the deadline.
|
||||
spawnCtx := detachedWithCancel(ctx)
|
||||
// Build a clean context for the subagent that inherits values (e.g. the
|
||||
// spawner callback) but is completely detached from the parent's
|
||||
// deadline AND cancellation. The subagent gets its own independent
|
||||
// timeout (applied downstream in Kit.Subagent).
|
||||
//
|
||||
// Why full detachment instead of propagating parent cancellation?
|
||||
// The parent context may already be done (deadline exceeded or
|
||||
// cancelled) by the time this tool handler executes — for example when
|
||||
// the generation loop context carries a deadline, when the user
|
||||
// double-ESC cancels mid-turn, or when parallel tool execution
|
||||
// encounters a race between stream completion and tool dispatch. Using
|
||||
// context.WithoutCancel (Go 1.21+) ensures the subagent always starts
|
||||
// cleanly with a fresh timeout, following the pattern used by crush for
|
||||
// shutdown-resilient child work. The subagent's own timeout
|
||||
// (defaultSubagentTimeout / user-specified) provides the safety net.
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: ctx})
|
||||
|
||||
// Spawn in-process subagent.
|
||||
result, err := spawner(spawnCtx, call.ID, args.Task, args.Model, args.SystemPrompt, timeout)
|
||||
@@ -173,37 +182,21 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context detachment
|
||||
// Context helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// detachedContext wraps a parent context, preserving its values but removing
|
||||
// its deadline and cancellation. This allows the subagent to have its own
|
||||
// independent timeout while still accessing context-stored values (e.g. the
|
||||
// subagent spawner function).
|
||||
type detachedContext struct {
|
||||
// valuesContext preserves a parent context's values (e.g. the subagent
|
||||
// spawner callback) while stripping its deadline and cancellation. Combined
|
||||
// with context.WithoutCancel() this gives the subagent a completely clean
|
||||
// context that only inherits value-based dependencies.
|
||||
type valuesContext struct {
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (d detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (d detachedContext) Done() <-chan struct{} { return nil }
|
||||
func (d detachedContext) Err() error { return nil }
|
||||
func (d detachedContext) Value(key any) any { return d.parent.Value(key) }
|
||||
|
||||
// detachedWithCancel creates a new context that inherits values from the
|
||||
// parent but has no deadline. Cancellation of the parent is propagated: when
|
||||
// the parent is cancelled the returned context is also cancelled, but the
|
||||
// parent's deadline does not apply to the child.
|
||||
func detachedWithCancel(parent context.Context) context.Context {
|
||||
child, cancel := context.WithCancel(detachedContext{parent: parent})
|
||||
go func() {
|
||||
select {
|
||||
case <-parent.Done():
|
||||
cancel()
|
||||
case <-child.Done():
|
||||
}
|
||||
}()
|
||||
return child
|
||||
}
|
||||
func (v valuesContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (v valuesContext) Done() <-chan struct{} { return nil }
|
||||
func (v valuesContext) Err() error { return nil }
|
||||
func (v valuesContext) Value(key any) any { return v.parent.Value(key) }
|
||||
|
||||
// truncateResponse limits the response length to avoid overwhelming context windows.
|
||||
func truncateResponse(s string, maxLen int) string {
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValuesContext_StripsDeadlineAndCancellation(t *testing.T) {
|
||||
// Parent with a tight deadline.
|
||||
parent, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(5 * time.Millisecond) // Let deadline expire.
|
||||
|
||||
if parent.Err() == nil {
|
||||
t.Fatal("expected parent to be expired")
|
||||
}
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
if _, ok := vc.Deadline(); ok {
|
||||
t.Error("valuesContext should report no deadline")
|
||||
}
|
||||
if vc.Done() != nil {
|
||||
t.Error("valuesContext.Done() should return nil")
|
||||
}
|
||||
if vc.Err() != nil {
|
||||
t.Errorf("valuesContext.Err() should be nil, got %v", vc.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesContext_PreservesValues(t *testing.T) {
|
||||
type testKey struct{}
|
||||
parent := context.WithValue(context.Background(), testKey{}, "hello")
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
got, ok := vc.Value(testKey{}).(string)
|
||||
if !ok || got != "hello" {
|
||||
t.Errorf("expected value 'hello', got %q (ok=%v)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesCancelledParent(t *testing.T) {
|
||||
// Simulate the exact scenario from the bug: the parent generation
|
||||
// context is already cancelled when the subagent tool handler runs.
|
||||
parent, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancelled before detach.
|
||||
|
||||
// This is what executeSubagent now does:
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
// The spawn context must be alive.
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive, got err: %v", spawnCtx.Err())
|
||||
}
|
||||
|
||||
// Adding a timeout should produce a working context.
|
||||
tCtx, tCancel := context.WithTimeout(spawnCtx, 5*time.Second)
|
||||
defer tCancel()
|
||||
|
||||
if tCtx.Err() != nil {
|
||||
t.Fatalf("timeout context should be alive, got err: %v", tCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesDeadlineExceededParent(t *testing.T) {
|
||||
// Simulate: parent had a deadline that already expired.
|
||||
parent, pCancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer pCancel()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if parent.Err() != context.DeadlineExceeded {
|
||||
t.Fatalf("expected parent deadline exceeded, got: %v", parent.Err())
|
||||
}
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive after deadline-exceeded parent, got: %v", spawnCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_PreservesSpawnerValue(t *testing.T) {
|
||||
// Verify the subagent spawner callback survives context detachment.
|
||||
called := false
|
||||
spawner := SubagentSpawnFunc(func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error) {
|
||||
called = true
|
||||
return &SubagentSpawnResult{Response: "ok"}, nil
|
||||
})
|
||||
|
||||
parent := WithSubagentSpawner(context.Background(), spawner)
|
||||
// Cancel the parent.
|
||||
parentCtx, cancel := context.WithCancel(parent)
|
||||
cancel()
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parentCtx})
|
||||
|
||||
// Should be able to retrieve the spawner from the detached context.
|
||||
recovered := getSubagentSpawner(spawnCtx)
|
||||
if recovered == nil {
|
||||
t.Fatal("spawner should be recoverable from detached context")
|
||||
}
|
||||
|
||||
result, err := recovered(spawnCtx, "tc1", "test task", "", "", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("spawner call failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("spawner was not called")
|
||||
}
|
||||
if result.Response != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", result.Response)
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// SubagentTools returns all core tools except subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
|
||||
+344
-6
@@ -77,6 +77,64 @@ type Context struct {
|
||||
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
|
||||
CancelAndSend func(string)
|
||||
|
||||
// Abort cancels the current agent turn (if running) and clears the
|
||||
// message queue. Unlike CancelAndSend, no new message is injected —
|
||||
// the agent simply stops. Safe to call when idle (no-op).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.Abort() // stop whatever the agent is doing
|
||||
Abort func()
|
||||
|
||||
// IsIdle returns true when the agent is not processing a turn.
|
||||
// Extensions can use this to decide whether to dispatch immediately
|
||||
// or queue work for later.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if ctx.IsIdle() {
|
||||
// ctx.SendMessage("start new task")
|
||||
// }
|
||||
IsIdle func() bool
|
||||
|
||||
// Compact triggers context compaction, summarising older messages to
|
||||
// free context window space. Returns an error if compaction cannot
|
||||
// start (e.g. agent is busy or app is closed). The actual compaction
|
||||
// runs asynchronously; use OnComplete/OnError callbacks in
|
||||
// CompactConfig to observe the result.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.Compact(ext.CompactConfig{
|
||||
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
|
||||
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
|
||||
// })
|
||||
Compact func(CompactConfig) error
|
||||
|
||||
// SendMultimodalMessage injects a message with file attachments (images,
|
||||
// documents) into the conversation and triggers a new agent turn. Files
|
||||
// are described by FilePart structs containing the raw bytes, filename,
|
||||
// and MIME type. If the agent is busy the message is queued.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := os.ReadFile("photo.jpg")
|
||||
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
|
||||
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
|
||||
// })
|
||||
SendMultimodalMessage func(text string, files []FilePart)
|
||||
|
||||
// GetSessionUsage returns aggregated token usage and cost statistics
|
||||
// for the current session. This includes total input/output tokens,
|
||||
// cache read/write tokens, total cost, and request count.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// usage := ctx.GetSessionUsage()
|
||||
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
|
||||
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
|
||||
GetSessionUsage func() SessionUsage
|
||||
|
||||
// SetWidget places or updates a persistent widget in the TUI. Widgets
|
||||
// remain visible across agent turns until explicitly removed. The
|
||||
// widget is identified by WidgetConfig.ID; calling SetWidget with the
|
||||
@@ -572,6 +630,102 @@ type Context struct {
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found.
|
||||
GetTreeNode func(entryID string) *TreeNode
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf.
|
||||
// Each node contains full metadata (unlike GetMessages which flattens).
|
||||
GetCurrentBranch func() []TreeNode
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
GetChildren func(entryID string) []string
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Equivalent to SDK's Branch() but for extensions.
|
||||
NavigateTo func(entryID string) TreeNavigationResult
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string (empty if success).
|
||||
SummarizeBranch func(fromID, toID string) string
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// This is the "fresh context" primitive for context window management.
|
||||
CollapseBranch func(fromID, toID, summary string) TreeNavigationResult
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// LoadSkill loads a single skill file from path.
|
||||
// Parses YAML frontmatter, returns skill with content ready for injection.
|
||||
LoadSkill func(path string) (*Skill, string)
|
||||
|
||||
// LoadSkillsFromDir discovers and loads all skills from a directory.
|
||||
LoadSkillsFromDir func(dir string) SkillLoadResult
|
||||
|
||||
// DiscoverSkills finds skills in standard locations.
|
||||
// Checks ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
DiscoverSkills func() SkillLoadResult
|
||||
|
||||
// InjectSkillAsContext sends a skill's content as a system message.
|
||||
// Looks up skill by name from discovered skills.
|
||||
InjectSkillAsContext func(skillName string) string
|
||||
|
||||
// InjectRawSkillAsContext loads and immediately injects a skill file.
|
||||
InjectRawSkillAsContext func(path string) string
|
||||
|
||||
// GetAvailableSkills returns all currently loaded/discovered skills.
|
||||
GetAvailableSkills func() []Skill
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
ParseTemplate func(name, content string) PromptTemplate
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
RenderTemplate func(tpl PromptTemplate, vars map[string]string) string
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
ParseArguments func(input string, pattern ArgumentPattern) ParseResult
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
SimpleParseArguments func(input string, count int) []string
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
EvaluateModelConditional func(condition string) bool
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
RenderWithModelConditionals func(content string) string
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
ResolveModelChain func(preferences []string) ModelResolutionResult
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, uses current model.
|
||||
GetModelCapabilities func(model string) (ModelCapabilities, string)
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid.
|
||||
CheckModelAvailable func(model string) bool
|
||||
|
||||
// GetCurrentProvider returns just the provider part of current model.
|
||||
GetCurrentProvider func() string
|
||||
|
||||
// GetCurrentModelID returns just the model ID part of current model.
|
||||
GetCurrentModelID func() string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -598,6 +752,148 @@ type SessionMessage struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree navigation types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TreeNode represents a node in the session tree for navigation.
|
||||
// Extensions use this to traverse conversation history and implement
|
||||
// features like "fresh context" loops and branch summarization.
|
||||
type TreeNode struct {
|
||||
// ID is the unique entry identifier.
|
||||
ID string
|
||||
// ParentID links this entry to its parent (empty if root).
|
||||
ParentID string
|
||||
// Type is the entry type: "message", "branch_summary", "model_change", "extension_data", "tool_execution".
|
||||
Type string
|
||||
// Role is the message role for message entries: "user", "assistant", "system", "tool".
|
||||
Role string
|
||||
// Content is the text content or summary.
|
||||
Content string
|
||||
// Model is the model that generated this (for assistant messages).
|
||||
Model string
|
||||
// Provider is the provider used.
|
||||
Provider string
|
||||
// Timestamp is the RFC3339-formatted creation time.
|
||||
Timestamp string
|
||||
// Children is the list of child entry IDs for tree traversal.
|
||||
Children []string
|
||||
}
|
||||
|
||||
// TreeNavigationResult reports success or failure of tree operations.
|
||||
type TreeNavigationResult struct {
|
||||
// Success is true if the operation completed.
|
||||
Success bool
|
||||
// Error describes what went wrong (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Skill represents a loaded skill file with parsed YAML frontmatter.
|
||||
type Skill struct {
|
||||
// Name is the human-readable identifier.
|
||||
Name string
|
||||
// Description summarizes what this skill provides.
|
||||
Description string
|
||||
// Content is the markdown body (frontmatter stripped).
|
||||
Content string
|
||||
// Path is the absolute filesystem path.
|
||||
Path string
|
||||
// Tags are optional labels for categorization.
|
||||
Tags []string
|
||||
// When controls automatic inclusion: "always", "on-demand", or file-glob.
|
||||
When string
|
||||
}
|
||||
|
||||
// SkillLoadResult reports skills loaded from a directory.
|
||||
type SkillLoadResult struct {
|
||||
// Skills is the list of loaded skills.
|
||||
Skills []Skill
|
||||
// Error describes loading failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template parsing types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// PromptTemplate represents a parsed template with variable placeholders.
|
||||
type PromptTemplate struct {
|
||||
// Name is the template identifier.
|
||||
Name string
|
||||
// Content is the original template content.
|
||||
Content string
|
||||
// Variables are the extracted {{variable}} names.
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// ArgumentPattern defines how to parse command arguments.
|
||||
type ArgumentPattern struct {
|
||||
// Positional names for $1, $2, etc.
|
||||
Positional []string
|
||||
// Rest is the variable name for $@ (all remaining).
|
||||
Rest string
|
||||
// Flags maps flag names to variable names (e.g., "--loop" -> "loop").
|
||||
Flags map[string]string
|
||||
}
|
||||
|
||||
// ParseResult reports argument parsing outcome.
|
||||
type ParseResult struct {
|
||||
// Vars maps variable names to values for positional args.
|
||||
Vars map[string]string
|
||||
// Flags maps flag names to values.
|
||||
Flags map[string]string
|
||||
// Rest is remaining unparsed text.
|
||||
Rest string
|
||||
// Error describes parsing failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ModelConditional represents an <if-model> block for evaluation.
|
||||
type ModelConditional struct {
|
||||
// Condition is the model pattern (e.g., "claude-*", "anthropic/*").
|
||||
Condition string
|
||||
// Content is rendered if condition matches.
|
||||
Content string
|
||||
// Else is rendered if condition doesn't match.
|
||||
Else string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model resolution types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ModelCapabilities describes what a model supports.
|
||||
type ModelCapabilities struct {
|
||||
// Provider is the provider ID (e.g., "anthropic").
|
||||
Provider string
|
||||
// ModelID is the model identifier (e.g., "claude-sonnet-4-20250929").
|
||||
ModelID string
|
||||
// ContextLimit is the maximum context window in tokens.
|
||||
ContextLimit int
|
||||
// OutputLimit is the maximum output tokens.
|
||||
OutputLimit int
|
||||
// Reasoning indicates if the model supports reasoning/thinking.
|
||||
Reasoning bool
|
||||
// Streaming indicates if the model supports streaming.
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// ModelResolutionResult reports model chain resolution outcome.
|
||||
type ModelResolutionResult struct {
|
||||
// Model is the selected model in "provider/model" format.
|
||||
Model string
|
||||
// Capabilities describes the selected model.
|
||||
Capabilities ModelCapabilities
|
||||
// Attempted lists models tried before success.
|
||||
Attempted []string
|
||||
// Error describes resolution failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExtensionEntry represents persisted extension data stored in the session.
|
||||
// Extensions use AppendEntry to save custom state and GetEntries to retrieve
|
||||
// it on session resume.
|
||||
@@ -622,7 +918,7 @@ type ExtensionEntry struct {
|
||||
type ContextMessage struct {
|
||||
// Index is the position of this message in the original context array
|
||||
// (0-based). When returning messages from a ContextPrepareResult,
|
||||
// messages with Index >= 0 reuse the original fantasy.Message at that
|
||||
// messages with Index >= 0 reuse the original LLM message at that
|
||||
// position (preserving tool calls, reasoning, and other complex parts).
|
||||
// Set Index to -1 for newly injected messages (created from Role + Content).
|
||||
Index int
|
||||
@@ -699,6 +995,48 @@ type StatusBarEntry struct {
|
||||
Priority int
|
||||
}
|
||||
|
||||
// CompactConfig configures a programmatic context compaction request.
|
||||
type CompactConfig struct {
|
||||
// CustomInstructions is optional text appended to the summary prompt
|
||||
// (e.g. "Focus on the API design decisions"). Empty uses the default.
|
||||
CustomInstructions string
|
||||
// OnComplete is called when compaction finishes successfully.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnComplete func()
|
||||
// OnError is called when compaction fails. The argument is the error message.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnError func(errMsg string)
|
||||
}
|
||||
|
||||
// FilePart describes a file attachment for multimodal messages. Extensions
|
||||
// use this with SendMultimodalMessage to attach images or documents.
|
||||
type FilePart struct {
|
||||
// Filename is the name of the file (e.g. "photo.jpg").
|
||||
Filename string
|
||||
// Data is the raw file content.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// SessionUsage contains aggregated token usage and cost statistics for
|
||||
// the current session. Extensions use this with GetSessionUsage() to
|
||||
// report usage information.
|
||||
type SessionUsage struct {
|
||||
// TotalInputTokens is the sum of input tokens across all requests.
|
||||
TotalInputTokens int
|
||||
// TotalOutputTokens is the sum of output tokens across all requests.
|
||||
TotalOutputTokens int
|
||||
// TotalCacheReadTokens is the sum of cache read tokens.
|
||||
TotalCacheReadTokens int
|
||||
// TotalCacheWriteTokens is the sum of cache write tokens.
|
||||
TotalCacheWriteTokens int
|
||||
// TotalCost is the total cost in USD across all requests.
|
||||
TotalCost float64
|
||||
// RequestCount is the number of LLM requests made in this session.
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
@@ -784,7 +1122,7 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a spawn_subagent tool
|
||||
// OnSubagentStart registers a handler that fires when a subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
@@ -799,7 +1137,7 @@ func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a spawn_subagent call
|
||||
// OnSubagentEnd registers a handler that fires when a subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
@@ -1808,9 +2146,9 @@ func (BeforeCompactResult) isResult() {}
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStartEvent fires when a subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the spawn_subagent tool call.
|
||||
// ToolCallID is the LLM-assigned ID of the subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
@@ -1850,7 +2188,7 @@ type SubagentChunkEvent struct {
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a spawn_subagent tool call completes.
|
||||
// SubagentEndEvent fires when a subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
|
||||
@@ -72,7 +72,7 @@ const (
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStart fires when a subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
@@ -80,7 +80,7 @@ const (
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a spawn_subagent tool call completes (success
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
return i.install(source, scope, nil)
|
||||
}
|
||||
|
||||
// install is the internal implementation that supports optional include paths.
|
||||
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Include: includePaths,
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
// Update manifest timestamp, preserving existing fields like Include
|
||||
existing, _ := i.loadManifest(scope)
|
||||
var include []string
|
||||
var installed time.Time
|
||||
if existing != nil {
|
||||
for _, p := range existing.Packages {
|
||||
if p.Host+"/"+p.Path == source.Identity() {
|
||||
include = p.Include
|
||||
installed = p.Installed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if installed.IsZero() {
|
||||
installed = time.Now()
|
||||
}
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Installed: installed,
|
||||
Updated: time.Now(),
|
||||
Include: include,
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return i.install(source, scope, includePaths)
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
|
||||
@@ -34,59 +34,64 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Warn("skipping extension", "path", p, "err", err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Debug("loaded extension", "path", p,
|
||||
"handlers", countHandlers(ext),
|
||||
"tools", len(ext.Tools),
|
||||
"commands", len(ext.Commands),
|
||||
"tool_renderers", len(ext.ToolRenderers))
|
||||
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// pathSet is a thread-safe helper for deduplicating and ordering file paths.
|
||||
type pathSet struct {
|
||||
m map[string]bool
|
||||
list []string
|
||||
}
|
||||
|
||||
func newPathSet() *pathSet {
|
||||
return &pathSet{m: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (ps *pathSet) add(p string) bool {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ps.m[abs] {
|
||||
return false
|
||||
}
|
||||
ps.m[abs] = true
|
||||
ps.list = append(ps.list, abs)
|
||||
return true
|
||||
}
|
||||
|
||||
// discoverExtensionPaths returns deduplicated paths to extension files in
|
||||
// load-order (global first, then project-local, then explicit).
|
||||
func discoverExtensionPaths(extraPaths []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var paths []string
|
||||
|
||||
add := func(p string) {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
seen[abs] = true
|
||||
paths = append(paths, abs)
|
||||
}
|
||||
ps := newPathSet()
|
||||
|
||||
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
||||
globalDir := globalExtensionsDir()
|
||||
for _, p := range findExtensionsInDir(globalDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Global installed git packages: $XDG_DATA_HOME/kit/git/
|
||||
globalGitDir := globalGitInstallRoot()
|
||||
for _, p := range findExtensionsInGitPackages(globalGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local installed git packages: .kit/git/
|
||||
projectGitDir := filepath.Join(".kit", "git")
|
||||
for _, p := range findExtensionsInGitPackages(projectGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
@@ -97,14 +102,14 @@ func discoverExtensionPaths(extraPaths []string) []string {
|
||||
}
|
||||
if info.IsDir() {
|
||||
for _, found := range findExtensionsInDir(p) {
|
||||
add(found)
|
||||
ps.add(found)
|
||||
}
|
||||
} else if strings.HasSuffix(p, ".go") {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
return ps.list
|
||||
}
|
||||
|
||||
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
|
||||
@@ -123,7 +128,7 @@ func findExtensionsInDir(dir string) []string {
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
@@ -180,9 +185,13 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
@@ -192,13 +201,6 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -217,7 +219,7 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+362
-13
@@ -1,21 +1,93 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// reentrantMu — a per-extension mutex that allows the same goroutine to
|
||||
// re-enter (e.g. handler → ctx.EmitCustomEvent → handler in same extension).
|
||||
// Different goroutines are serialized, preventing concurrent state mutation.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type reentrantMu struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
owner int64 // goroutine ID that holds the lock, or 0
|
||||
depth int // re-entrancy depth
|
||||
}
|
||||
|
||||
// initReentrantMu initializes the reentrant mutex in-place. Must be called
|
||||
// after the struct is at its final memory location (not before copying).
|
||||
func (r *reentrantMu) init() {
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
}
|
||||
|
||||
// lock acquires the mutex. If the calling goroutine already holds it, the
|
||||
// call succeeds immediately (re-entrant). Every call to lock must be paired
|
||||
// with a call to unlock.
|
||||
func (r *reentrantMu) lock() {
|
||||
gid := goroutineID()
|
||||
r.mu.Lock()
|
||||
if r.owner == gid {
|
||||
// Re-entrant: same goroutine already holds the lock.
|
||||
r.depth++
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Wait for the current owner to release.
|
||||
for r.owner != 0 {
|
||||
r.cond.Wait() // releases mu, blocks, re-acquires mu on wake
|
||||
}
|
||||
r.owner = gid
|
||||
r.depth = 1
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// unlock releases the mutex (or decrements re-entrancy depth).
|
||||
func (r *reentrantMu) unlock() {
|
||||
r.mu.Lock()
|
||||
r.depth--
|
||||
if r.depth == 0 {
|
||||
r.owner = 0
|
||||
r.cond.Signal()
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// goroutineID extracts the current goroutine's ID from runtime.Stack output.
|
||||
// This is a well-known technique used by Go testing infrastructure.
|
||||
func goroutineID() int64 {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
// Stack output starts with "goroutine NNN ["
|
||||
s := buf[:n]
|
||||
s = s[len("goroutine "):]
|
||||
s = s[:bytes.IndexByte(s, ' ')]
|
||||
id, _ := strconv.ParseInt(string(s), 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
// Runner manages loaded extensions and dispatches events to their handlers
|
||||
// sequentially. Handlers execute in extension
|
||||
// load order; for cancellable events the first blocking result wins.
|
||||
//
|
||||
// Each extension has a dedicated reentrant mutex so that handlers for the
|
||||
// same extension are serialized (preventing data races on shared package-level
|
||||
// state), while handlers for different extensions may execute concurrently.
|
||||
type Runner struct {
|
||||
extensions []LoadedExtension
|
||||
extMu []reentrantMu // per-extension reentrant mutex, indexed by extension position
|
||||
ctx Context
|
||||
widgets map[string]WidgetConfig // keyed by widget ID
|
||||
statusEntries map[string]StatusBarEntry // keyed by status key
|
||||
@@ -52,15 +124,284 @@ type LoadedExtension struct {
|
||||
|
||||
// NewRunner creates a Runner from a set of loaded extensions.
|
||||
func NewRunner(exts []LoadedExtension) *Runner {
|
||||
return &Runner{extensions: exts}
|
||||
mus := make([]reentrantMu, len(exts))
|
||||
for i := range mus {
|
||||
mus[i].init()
|
||||
}
|
||||
return &Runner{extensions: exts, extMu: mus}
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
// passed to every handler invocation. Thread-safe.
|
||||
// passed to every handler invocation. Nil function fields are replaced with
|
||||
// safe no-ops so extension handlers never panic on a missing callback.
|
||||
// Thread-safe.
|
||||
func (r *Runner) SetContext(ctx Context) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ctx = ctx
|
||||
r.ctx = normalizeContext(ctx)
|
||||
}
|
||||
|
||||
// normalizeContext replaces nil function fields in ctx with no-op stubs so
|
||||
// that extension handlers can call any ctx method without a nil-function panic.
|
||||
func normalizeContext(ctx Context) Context {
|
||||
if ctx.Print == nil {
|
||||
ctx.Print = func(string) {}
|
||||
}
|
||||
if ctx.PrintInfo == nil {
|
||||
ctx.PrintInfo = func(string) {}
|
||||
}
|
||||
if ctx.PrintError == nil {
|
||||
ctx.PrintError = func(string) {}
|
||||
}
|
||||
if ctx.PrintBlock == nil {
|
||||
ctx.PrintBlock = func(PrintBlockOpts) {}
|
||||
}
|
||||
if ctx.SendMessage == nil {
|
||||
ctx.SendMessage = func(string) {}
|
||||
}
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.Abort == nil {
|
||||
ctx.Abort = func() {}
|
||||
}
|
||||
if ctx.IsIdle == nil {
|
||||
ctx.IsIdle = func() bool { return true }
|
||||
}
|
||||
if ctx.Compact == nil {
|
||||
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
|
||||
}
|
||||
if ctx.SendMultimodalMessage == nil {
|
||||
ctx.SendMultimodalMessage = func(string, []FilePart) {}
|
||||
}
|
||||
if ctx.GetSessionUsage == nil {
|
||||
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
if ctx.RemoveWidget == nil {
|
||||
ctx.RemoveWidget = func(string) {}
|
||||
}
|
||||
if ctx.SetHeader == nil {
|
||||
ctx.SetHeader = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveHeader == nil {
|
||||
ctx.RemoveHeader = func() {}
|
||||
}
|
||||
if ctx.SetFooter == nil {
|
||||
ctx.SetFooter = func(HeaderFooterConfig) {}
|
||||
}
|
||||
if ctx.RemoveFooter == nil {
|
||||
ctx.RemoveFooter = func() {}
|
||||
}
|
||||
if ctx.PromptSelect == nil {
|
||||
ctx.PromptSelect = func(PromptSelectConfig) PromptSelectResult {
|
||||
return PromptSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptConfirm == nil {
|
||||
ctx.PromptConfirm = func(PromptConfirmConfig) PromptConfirmResult {
|
||||
return PromptConfirmResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptInput == nil {
|
||||
ctx.PromptInput = func(PromptInputConfig) PromptInputResult {
|
||||
return PromptInputResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.PromptMultiSelect == nil {
|
||||
ctx.PromptMultiSelect = func(PromptMultiSelectConfig) PromptMultiSelectResult {
|
||||
return PromptMultiSelectResult{Cancelled: true}
|
||||
}
|
||||
}
|
||||
if ctx.ShowOverlay == nil {
|
||||
ctx.ShowOverlay = func(OverlayConfig) OverlayResult {
|
||||
return OverlayResult{Cancelled: true, Index: -1}
|
||||
}
|
||||
}
|
||||
if ctx.SetEditor == nil {
|
||||
ctx.SetEditor = func(EditorConfig) {}
|
||||
}
|
||||
if ctx.ResetEditor == nil {
|
||||
ctx.ResetEditor = func() {}
|
||||
}
|
||||
if ctx.SetEditorText == nil {
|
||||
ctx.SetEditorText = func(string) {}
|
||||
}
|
||||
if ctx.SetUIVisibility == nil {
|
||||
ctx.SetUIVisibility = func(UIVisibility) {}
|
||||
}
|
||||
if ctx.SetStatus == nil {
|
||||
ctx.SetStatus = func(string, string, int) {}
|
||||
}
|
||||
if ctx.RemoveStatus == nil {
|
||||
ctx.RemoveStatus = func(string) {}
|
||||
}
|
||||
if ctx.GetContextStats == nil {
|
||||
ctx.GetContextStats = func() ContextStats { return ContextStats{} }
|
||||
}
|
||||
if ctx.GetMessages == nil {
|
||||
ctx.GetMessages = func() []SessionMessage { return nil }
|
||||
}
|
||||
if ctx.GetSessionPath == nil {
|
||||
ctx.GetSessionPath = func() string { return "" }
|
||||
}
|
||||
if ctx.AppendEntry == nil {
|
||||
ctx.AppendEntry = func(string, string) (string, error) { return "", nil }
|
||||
}
|
||||
if ctx.GetEntries == nil {
|
||||
ctx.GetEntries = func(string) []ExtensionEntry { return nil }
|
||||
}
|
||||
if ctx.GetOption == nil {
|
||||
ctx.GetOption = func(string) string { return "" }
|
||||
}
|
||||
if ctx.SetOption == nil {
|
||||
ctx.SetOption = func(string, string) {}
|
||||
}
|
||||
if ctx.SetModel == nil {
|
||||
ctx.SetModel = func(string) error { return nil }
|
||||
}
|
||||
if ctx.GetAvailableModels == nil {
|
||||
ctx.GetAvailableModels = func() []ModelInfoEntry { return nil }
|
||||
}
|
||||
if ctx.EmitCustomEvent == nil {
|
||||
ctx.EmitCustomEvent = func(string, string) {}
|
||||
}
|
||||
if ctx.GetAllTools == nil {
|
||||
ctx.GetAllTools = func() []ToolInfo { return nil }
|
||||
}
|
||||
if ctx.SetActiveTools == nil {
|
||||
ctx.SetActiveTools = func([]string) {}
|
||||
}
|
||||
if ctx.Exit == nil {
|
||||
ctx.Exit = func() {}
|
||||
}
|
||||
if ctx.Complete == nil {
|
||||
ctx.Complete = func(CompleteRequest) (CompleteResponse, error) {
|
||||
return CompleteResponse{}, nil
|
||||
}
|
||||
}
|
||||
if ctx.SuspendTUI == nil {
|
||||
ctx.SuspendTUI = func(callback func()) error { callback(); return nil }
|
||||
}
|
||||
if ctx.RenderMessage == nil {
|
||||
ctx.RenderMessage = func(string, string) {}
|
||||
}
|
||||
if ctx.RegisterTheme == nil {
|
||||
ctx.RegisterTheme = func(string, ThemeColorConfig) {}
|
||||
}
|
||||
if ctx.SetTheme == nil {
|
||||
ctx.SetTheme = func(string) error { return nil }
|
||||
}
|
||||
if ctx.ListThemes == nil {
|
||||
ctx.ListThemes = func() []string { return nil }
|
||||
}
|
||||
if ctx.ReloadExtensions == nil {
|
||||
ctx.ReloadExtensions = func() error { return nil }
|
||||
}
|
||||
if ctx.SpawnSubagent == nil {
|
||||
ctx.SpawnSubagent = func(SubagentConfig) (*SubagentHandle, *SubagentResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.GetTreeNode == nil {
|
||||
ctx.GetTreeNode = func(string) *TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetCurrentBranch == nil {
|
||||
ctx.GetCurrentBranch = func() []TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetChildren == nil {
|
||||
ctx.GetChildren = func(string) []string { return nil }
|
||||
}
|
||||
if ctx.NavigateTo == nil {
|
||||
ctx.NavigateTo = func(string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.SummarizeBranch == nil {
|
||||
ctx.SummarizeBranch = func(string, string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if ctx.CollapseBranch == nil {
|
||||
ctx.CollapseBranch = func(string, string, string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.LoadSkill == nil {
|
||||
ctx.LoadSkill = func(string) (*Skill, string) { return nil, "" }
|
||||
}
|
||||
if ctx.LoadSkillsFromDir == nil {
|
||||
ctx.LoadSkillsFromDir = func(string) SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.DiscoverSkills == nil {
|
||||
ctx.DiscoverSkills = func() SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.InjectSkillAsContext == nil {
|
||||
ctx.InjectSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.InjectRawSkillAsContext == nil {
|
||||
ctx.InjectRawSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.GetAvailableSkills == nil {
|
||||
ctx.GetAvailableSkills = func() []Skill { return nil }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ParseTemplate == nil {
|
||||
ctx.ParseTemplate = func(string, string) PromptTemplate { return PromptTemplate{} }
|
||||
}
|
||||
if ctx.RenderTemplate == nil {
|
||||
ctx.RenderTemplate = func(PromptTemplate, map[string]string) string { return "" }
|
||||
}
|
||||
if ctx.ParseArguments == nil {
|
||||
ctx.ParseArguments = func(string, ArgumentPattern) ParseResult { return ParseResult{} }
|
||||
}
|
||||
if ctx.SimpleParseArguments == nil {
|
||||
ctx.SimpleParseArguments = func(string, int) []string { return nil }
|
||||
}
|
||||
if ctx.EvaluateModelConditional == nil {
|
||||
ctx.EvaluateModelConditional = func(string) bool { return false }
|
||||
}
|
||||
if ctx.RenderWithModelConditionals == nil {
|
||||
ctx.RenderWithModelConditionals = func(string) string { return "" }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ResolveModelChain == nil {
|
||||
ctx.ResolveModelChain = func([]string) ModelResolutionResult {
|
||||
return ModelResolutionResult{Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.GetModelCapabilities == nil {
|
||||
ctx.GetModelCapabilities = func(string) (ModelCapabilities, string) {
|
||||
return ModelCapabilities{}, "not implemented"
|
||||
}
|
||||
}
|
||||
if ctx.CheckModelAvailable == nil {
|
||||
ctx.CheckModelAvailable = func(string) bool { return false }
|
||||
}
|
||||
if ctx.GetCurrentProvider == nil {
|
||||
ctx.GetCurrentProvider = func() string { return "" }
|
||||
}
|
||||
if ctx.GetCurrentModelID == nil {
|
||||
ctx.GetCurrentModelID = func() string { return "" }
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// GetContext returns a snapshot of the current runtime context. Thread-safe.
|
||||
@@ -102,13 +443,15 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
for i := range r.extensions {
|
||||
ext := &r.extensions[i]
|
||||
handlers := ext.Handlers[event.Type()]
|
||||
if len(handlers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r.extMu[i].lock()
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
log.Warn("extension handler error",
|
||||
"path", ext.Path,
|
||||
"event", event.Type(),
|
||||
"err", err)
|
||||
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
@@ -117,6 +460,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
|
||||
// Check for blocking/short-circuit results.
|
||||
if isBlocking(result) {
|
||||
r.extMu[i].unlock()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -124,6 +468,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
// the caller is responsible for applying the modifications.
|
||||
accumulated = result
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
return accumulated, nil
|
||||
}
|
||||
@@ -442,9 +787,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
safeInvoke := func(h func(string)) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Warn("custom event handler panicked",
|
||||
"event", name,
|
||||
"err", fmt.Sprintf("%v", rec))
|
||||
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
|
||||
}
|
||||
}()
|
||||
h(data)
|
||||
@@ -452,11 +795,17 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
|
||||
// Extension-registered handlers first (in load order).
|
||||
for i := range r.extensions {
|
||||
for _, h := range r.extensions[i].CustomEventHandlers[name] {
|
||||
extHandlers := r.extensions[i].CustomEventHandlers[name]
|
||||
if len(extHandlers) == 0 {
|
||||
continue
|
||||
}
|
||||
r.extMu[i].lock()
|
||||
for _, h := range extHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
// Then dynamic subscriptions.
|
||||
// Then dynamic subscriptions (not extension-scoped, no per-ext lock).
|
||||
for _, h := range dynamicHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -571,3 +572,142 @@ func TestRunner_ContextPrintNilSafe(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitSameExtension(t *testing.T) {
|
||||
// Verify that concurrent Emit calls for the same extension are serialized
|
||||
// and don't cause data races on shared handler state.
|
||||
var counter int
|
||||
ext := makeHandlerExt("shared-state.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// Read-modify-write: racy without serialization.
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
SubagentChunk: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
_, _ = r.Emit(SubagentChunkEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if counter != goroutines*iterations*2 {
|
||||
t.Errorf("expected counter=%d, got %d (race detected)", goroutines*iterations*2, counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitDifferentExtensions(t *testing.T) {
|
||||
// Two extensions with independent state should not block each other
|
||||
// and should both run correctly under concurrent Emit calls.
|
||||
var counter1, counter2 int
|
||||
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter1
|
||||
counter1 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter2
|
||||
counter2 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext1, ext2)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
expected := goroutines * iterations
|
||||
if counter1 != expected {
|
||||
t.Errorf("ext1 counter: expected %d, got %d", expected, counter1)
|
||||
}
|
||||
if counter2 != expected {
|
||||
t.Errorf("ext2 counter: expected %d, got %d", expected, counter2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ReentrantEmitCustomEvent(t *testing.T) {
|
||||
// Verify that a handler can call EmitCustomEvent (which dispatches to
|
||||
// the same extension's custom event handlers) without deadlocking.
|
||||
var order []string
|
||||
ext := LoadedExtension{
|
||||
Path: "reentrant.go",
|
||||
Handlers: map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
// This triggers EmitCustomEvent for the same extension
|
||||
// via a direct runner call (simulating ctx.EmitCustomEvent).
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
CustomEventHandlers: map[string][]func(string){
|
||||
"test-event": {
|
||||
func(data string) {
|
||||
order = append(order, "custom:"+data)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := makeRunner(ext)
|
||||
|
||||
// Wire up the handler to call EmitCustomEvent re-entrantly.
|
||||
ext.Handlers[SessionStart] = []HandlerFunc{
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
r.EmitCustomEvent("test-event", "hello")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
r.extensions[0] = ext
|
||||
// Rebuild mutexes after modifying extensions slice.
|
||||
r.extMu = make([]reentrantMu, len(r.extensions))
|
||||
for i := range r.extMu {
|
||||
r.extMu[i].init()
|
||||
}
|
||||
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 2 || order[0] != "session_start" || order[1] != "custom:hello" {
|
||||
t.Errorf("expected [session_start, custom:hello], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,10 +173,10 @@ type subagentJSONOutput struct {
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter uint64
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := atomic.AddUint64(&subagentCounter, 1)
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
|
||||
// Session types
|
||||
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
|
||||
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
|
||||
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
|
||||
|
||||
// Option types
|
||||
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
|
||||
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
|
||||
// LLM completion types
|
||||
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
|
||||
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
|
||||
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
|
||||
"FilePart": reflect.ValueOf((*FilePart)(nil)),
|
||||
|
||||
// Status bar types
|
||||
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
|
||||
@@ -128,6 +131,24 @@ func Symbols() interp.Exports {
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Tree navigation types
|
||||
"TreeNode": reflect.ValueOf((*TreeNode)(nil)),
|
||||
"TreeNavigationResult": reflect.ValueOf((*TreeNavigationResult)(nil)),
|
||||
|
||||
// Skill types
|
||||
"Skill": reflect.ValueOf((*Skill)(nil)),
|
||||
"SkillLoadResult": reflect.ValueOf((*SkillLoadResult)(nil)),
|
||||
|
||||
// Template parsing types
|
||||
"PromptTemplate": reflect.ValueOf((*PromptTemplate)(nil)),
|
||||
"ArgumentPattern": reflect.ValueOf((*ArgumentPattern)(nil)),
|
||||
"ParseResult": reflect.ValueOf((*ParseResult)(nil)),
|
||||
"ModelConditional": reflect.ValueOf((*ModelConditional)(nil)),
|
||||
|
||||
// Model resolution types
|
||||
"ModelCapabilities": reflect.ValueOf((*ModelCapabilities)(nil)),
|
||||
"ModelResolutionResult": reflect.ValueOf((*ModelResolutionResult)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
// This includes the global extensions directory and the project-local
|
||||
// .kit/extensions/ directory (if they exist). Explicit -e paths that
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatcher_ReloadsOnGoFileChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write an initial extension file.
|
||||
extFile := filepath.Join(dir, "test.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Modify the file.
|
||||
time.Sleep(50 * time.Millisecond) // let watcher settle
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for debounce (300ms) + margin.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_IgnoresNonGoFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Write a non-.go file.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
txtFile := filepath.Join(dir, "notes.txt")
|
||||
if err := os.WriteFile(txtFile, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait past the debounce window.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 0 {
|
||||
t.Errorf("expected 0 reloads for .txt file, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_Debounces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
extFile := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Rapid-fire writes (simulating editor save: write temp, rename, etc.).
|
||||
for range 5 {
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for debounce to fire.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 debounced reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_Deduplicates(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dirs := WatchedDirs([]string{dir, dir})
|
||||
|
||||
count := 0
|
||||
for _, d := range dirs {
|
||||
abs, _ := filepath.Abs(dir)
|
||||
if d == abs {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected directory to appear once, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_FileParent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(file, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dirs := WatchedDirs([]string{file})
|
||||
|
||||
abs, _ := filepath.Abs(dir)
|
||||
found := false
|
||||
for _, d := range dirs {
|
||||
if d == abs {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected parent dir %s in watched dirs %v", abs, dirs)
|
||||
}
|
||||
}
|
||||
@@ -28,11 +28,11 @@ func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantas
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
|
||||
// into fantasy.AgentTool implementations so the LLM can invoke them.
|
||||
// ExtensionToolsAsLLMTools converts ToolDef values registered by extensions
|
||||
// into LLM agent tool implementations so the LLM can invoke them.
|
||||
// The runner is optional; if provided, ToolContext.OnProgress routes
|
||||
// progress messages through the runner's Print function.
|
||||
func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
tools := make([]fantasy.AgentTool, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
tools = append(tools, &extensionTool{def: def, runner: runner})
|
||||
@@ -42,14 +42,14 @@ func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
@@ -154,7 +154,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
|
||||
// extensionTool — wraps a ToolDef into an LLM agent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type extensionTool struct {
|
||||
@@ -182,7 +182,7 @@ func (t *extensionTool) Info() fantasy.ToolInfo {
|
||||
info.Parameters = props
|
||||
} else {
|
||||
// Schema doesn't have "properties" — use as-is (may be
|
||||
// a flat property map already matching fantasy's format).
|
||||
// a flat property map already matching the expected format).
|
||||
info.Parameters = schema
|
||||
}
|
||||
// Extract required fields if present.
|
||||
|
||||
@@ -192,7 +192,7 @@ func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
func TestExtensionToolsAsLLMTools(t *testing.T) {
|
||||
defs := []ToolDef{
|
||||
{
|
||||
Name: "greet",
|
||||
@@ -202,7 +202,7 @@ func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func TestExtensionTool_Error(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
@@ -259,7 +259,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
}
|
||||
|
||||
// Without runner, OnProgress is a no-op.
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -285,7 +285,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools2 := ExtensionToolsAsFantasy(defs2, runner)
|
||||
tools2 := ExtensionToolsAsLLMTools(defs2, runner)
|
||||
_, err = tools2[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -306,7 +306,7 @@ func TestExtensionTool_ExecuteWithContextPriority(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -330,7 +330,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
_, _ = tools[0].Run(ctx, fantasy.ToolCall{Input: ""})
|
||||
if !sawCancelled {
|
||||
t.Error("expected IsCancelled=true for cancelled context")
|
||||
@@ -339,7 +339,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
|
||||
func TestExtensionTool_ProviderOptions(t *testing.T) {
|
||||
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
|
||||
// Initially nil.
|
||||
opts := tools[0].ProviderOptions()
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
// Package fences provides utilities for detecting markdown code regions
|
||||
// (fenced code blocks and inline code spans) and applying transformations
|
||||
// only to text outside those regions.
|
||||
//
|
||||
// This prevents special tokens like $1, $@, or @file from being interpreted
|
||||
// when they appear inside ``` fences, ~~~ fences, or `inline` code spans.
|
||||
package fences
|
||||
|
||||
import "strings"
|
||||
|
||||
// Ranges returns byte ranges [start, end) of fenced code blocks in content.
|
||||
// Recognises both backtick (```) and tilde (~~~) fences, with optional
|
||||
// leading indentation (up to 3 spaces) and optional info strings.
|
||||
// An unclosed fence extends to the end of content.
|
||||
func Ranges(content string) [][2]int {
|
||||
var result [][2]int
|
||||
var inFence bool
|
||||
var fenceChar byte
|
||||
var fenceCount int
|
||||
var fenceStart int
|
||||
|
||||
pos := 0
|
||||
for pos < len(content) {
|
||||
// Find the end of the current line.
|
||||
lineEnd := strings.IndexByte(content[pos:], '\n')
|
||||
var line string
|
||||
var nextPos int
|
||||
if lineEnd < 0 {
|
||||
line = content[pos:]
|
||||
nextPos = len(content)
|
||||
} else {
|
||||
line = content[pos : pos+lineEnd]
|
||||
nextPos = pos + lineEnd + 1
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(line, " ")
|
||||
indent := len(line) - len(trimmed)
|
||||
|
||||
if !inFence {
|
||||
if indent <= 3 {
|
||||
if ch, n := parseFenceOpen(trimmed); n > 0 {
|
||||
inFence = true
|
||||
fenceChar = ch
|
||||
fenceCount = n
|
||||
fenceStart = pos
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if indent <= 3 && isFenceClose(trimmed, fenceChar, fenceCount) {
|
||||
result = append(result, [2]int{fenceStart, nextPos})
|
||||
inFence = false
|
||||
}
|
||||
}
|
||||
|
||||
pos = nextPos
|
||||
}
|
||||
|
||||
// Unclosed fence extends to end of content.
|
||||
if inFence {
|
||||
result = append(result, [2]int{fenceStart, len(content)})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReplaceOutside applies fn to each text segment that is outside fenced code
|
||||
// blocks and inline code spans, leaving code content unchanged. This is the
|
||||
// primary entry point for callers that need to do regex replacement only on
|
||||
// non-code text.
|
||||
func ReplaceOutside(content string, fn func(string) string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return replaceOutsideInline(content, fn)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
// Within non-fenced segments, also skip inline code spans.
|
||||
b.WriteString(replaceOutsideInline(content[pos:r[0]], fn))
|
||||
}
|
||||
// Preserve fenced content verbatim.
|
||||
b.WriteString(content[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(content) {
|
||||
b.WriteString(replaceOutsideInline(content[pos:], fn))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// StripCode returns content with fenced code blocks and inline code spans
|
||||
// removed. Useful for detection/matching where only non-code text matters.
|
||||
func StripCode(content string) string {
|
||||
// First strip fenced blocks.
|
||||
stripped := StripFenced(content)
|
||||
// Then strip inline code spans from what remains.
|
||||
return stripInlineCode(stripped)
|
||||
}
|
||||
|
||||
// StripFenced returns content with fenced code block regions removed.
|
||||
// Useful for detection/matching where only non-fenced text matters.
|
||||
// NOTE: this does NOT strip inline code spans; use StripCode for both.
|
||||
func StripFenced(content string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(content[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(content[pos:])
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// parseFenceOpen checks whether trimmed (leading spaces already removed)
|
||||
// starts a fenced code block. Returns the fence character and count, or
|
||||
// (0, 0) if it is not a fence opener.
|
||||
func parseFenceOpen(trimmed string) (byte, int) {
|
||||
if len(trimmed) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
ch := trimmed[0]
|
||||
if ch != '`' && ch != '~' {
|
||||
return 0, 0
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == ch {
|
||||
count++
|
||||
}
|
||||
if count < 3 {
|
||||
return 0, 0
|
||||
}
|
||||
// Per CommonMark: backtick fences cannot have backticks in the info string.
|
||||
if ch == '`' && strings.ContainsRune(trimmed[count:], '`') {
|
||||
return 0, 0
|
||||
}
|
||||
return ch, count
|
||||
}
|
||||
|
||||
// isFenceClose checks whether trimmed is a closing fence matching fenceChar
|
||||
// with at least minCount characters. A closing fence line contains only the
|
||||
// fence characters and optional trailing spaces.
|
||||
func isFenceClose(trimmed string, fenceChar byte, minCount int) bool {
|
||||
if len(trimmed) == 0 || trimmed[0] != fenceChar {
|
||||
return false
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == fenceChar {
|
||||
count++
|
||||
}
|
||||
if count < minCount {
|
||||
return false
|
||||
}
|
||||
// Closing fence must contain only fence chars (and optional trailing spaces).
|
||||
return strings.TrimRight(trimmed[count:], " ") == ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Inline code span handling
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// inlineCodeRanges returns byte ranges [start, end) of inline code spans
|
||||
// in segment. Per CommonMark, a code span opens with N backticks and closes
|
||||
// with exactly N backticks.
|
||||
func inlineCodeRanges(s string) [][2]int {
|
||||
var result [][2]int
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
if s[i] != '`' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
// Count opening backticks.
|
||||
start := i
|
||||
n := 0
|
||||
for i < len(s) && s[i] == '`' {
|
||||
n++
|
||||
i++
|
||||
}
|
||||
// Scan for a closing run of exactly n backticks.
|
||||
for j := i; j < len(s); {
|
||||
if s[j] != '`' {
|
||||
j++
|
||||
continue
|
||||
}
|
||||
m := 0
|
||||
for j < len(s) && s[j] == '`' {
|
||||
m++
|
||||
j++
|
||||
}
|
||||
if m == n {
|
||||
result = append(result, [2]int{start, j})
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no closing run was found, i is already past the opening
|
||||
// backticks so the outer loop advances naturally.
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// replaceOutsideInline applies fn only to text outside inline code spans.
|
||||
func replaceOutsideInline(segment string, fn func(string) string) string {
|
||||
ranges := inlineCodeRanges(segment)
|
||||
if len(ranges) == 0 {
|
||||
return fn(segment)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(segment))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
b.WriteString(fn(segment[pos:r[0]]))
|
||||
}
|
||||
b.WriteString(segment[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(segment) {
|
||||
b.WriteString(fn(segment[pos:]))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// stripInlineCode removes inline code spans from s.
|
||||
func stripInlineCode(s string) string {
|
||||
ranges := inlineCodeRanges(s)
|
||||
if len(ranges) == 0 {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(s[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(s[pos:])
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package fences
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want [][2]int
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world\nno code here",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single backtick fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "single tilde fence",
|
||||
content: "before\n~~~\ncode\n~~~\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "fence with info string",
|
||||
content: "before\n```go\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: [][2]int{{2, 12}, {14, 24}},
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\ncode\nmore code",
|
||||
want: [][2]int{{7, 25}},
|
||||
},
|
||||
{
|
||||
name: "longer closing fence",
|
||||
content: "before\n```\ncode\n`````\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "shorter closing fence ignored",
|
||||
content: "before\n`````\ncode\n```\nmore\n`````\nafter",
|
||||
want: [][2]int{{7, 33}},
|
||||
},
|
||||
{
|
||||
name: "indented fence up to 3 spaces",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: [][2]int{{7, 26}},
|
||||
},
|
||||
{
|
||||
name: "4 space indent is not a fence",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "backtick in info string rejects open",
|
||||
// The ```foo`bar line is not a valid opener (backtick in info).
|
||||
// The standalone ``` becomes an opener with no close.
|
||||
content: "before\n```foo`bar\ncode\n```\nafter",
|
||||
want: [][2]int{{23, 32}},
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "fence only",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "fence at end without trailing newline",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "tilde fence does not close with backticks",
|
||||
content: "~~~\ncode\n```\nmore\n~~~\nafter",
|
||||
want: [][2]int{{0, 22}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Ranges(tt.content)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("Ranges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("Ranges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world",
|
||||
want: "HELLO WORLD",
|
||||
},
|
||||
{
|
||||
name: "text around fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: "BEFORE\n```\ncode\n```\nAFTER",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "aaa\n```\nxxx\n```\nbbb\n~~~\nyyy\n~~~\nccc",
|
||||
want: "AAA\n```\nxxx\n```\nBBB\n~~~\nyyy\n~~~\nCCC",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence preserves code",
|
||||
content: "before\n```\ncode",
|
||||
want: "BEFORE\n```\ncode",
|
||||
},
|
||||
{
|
||||
name: "only fenced content",
|
||||
content: "```\ncode\n```",
|
||||
want: "```\ncode\n```",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripFenced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips fenced code",
|
||||
content: "before $1\n```\n$2 inside\n```\nafter $3",
|
||||
want: "before $1\nafter $3",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: "a\nb\nc",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\n$1 inside",
|
||||
want: "before\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripFenced(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripFenced() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineCodeRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want [][2]int
|
||||
}{
|
||||
{"no backticks", "hello world", nil},
|
||||
{"single backtick span", "use `$1` here", [][2]int{{4, 8}}},
|
||||
{"double backtick span", "use ``$1`` here", [][2]int{{4, 10}}},
|
||||
{"multiple spans", "`$1` and `$2`", [][2]int{{0, 4}, {9, 13}}},
|
||||
{"unmatched backtick", "use `$1 here", nil},
|
||||
{"mismatched backtick counts", "use ``$1` here", nil},
|
||||
{"empty inline content", "use `` `` here", [][2]int{{4, 9}}},
|
||||
{"backticks inside double", "use ``foo`bar`` here", [][2]int{{4, 15}}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := inlineCodeRanges(tt.s)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("inlineCodeRanges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("inlineCodeRanges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside_InlineCode(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "inline code preserved",
|
||||
content: "use `code` here",
|
||||
want: "USE `code` HERE",
|
||||
},
|
||||
{
|
||||
name: "double backtick inline code",
|
||||
content: "use ``co`de`` here",
|
||||
want: "USE ``co`de`` HERE",
|
||||
},
|
||||
{
|
||||
name: "mixed fenced and inline",
|
||||
content: "before `x` mid\n```\nfenced\n```\nafter `y` end",
|
||||
want: "BEFORE `x` MID\n```\nfenced\n```\nAFTER `y` END",
|
||||
},
|
||||
{
|
||||
name: "only inline code",
|
||||
content: "`code`",
|
||||
want: "`code`",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no code",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips inline code",
|
||||
content: "use `$1` and `$2` for positional args",
|
||||
want: "use and for positional args",
|
||||
},
|
||||
{
|
||||
name: "strips fenced and inline",
|
||||
content: "before `$1`\n```\n$2 inside\n```\nafter",
|
||||
want: "before \nafter",
|
||||
},
|
||||
{
|
||||
name: "real world prompt template",
|
||||
content: "Use $@ for all args.\n`$1`, `$2` for positional.\n```bash\necho $1\n```\n",
|
||||
want: "Use $@ for all args.\n, for positional.\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripCode(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripCode() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+104
-24
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ExtraTools are additional tools added alongside core, MCP, and extension
|
||||
// tools. They do not replace the defaults — they extend them.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
@@ -40,6 +44,34 @@ type AgentSetupOptions struct {
|
||||
// wrapping. Used by the SDK hook system. Both wrappers compose:
|
||||
// extension wrapper runs first (inner), then this wrapper (outer).
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
Debug bool
|
||||
// NoExtensions skips extension loading. When false, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
NoExtensions bool
|
||||
// MaxSteps overrides the agent step limit. 0 means use viper value.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
|
||||
// is also set.
|
||||
StreamingEnabled bool
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -54,15 +86,17 @@ type AgentSetupResult struct {
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
@@ -72,9 +106,6 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
@@ -82,21 +113,66 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
}
|
||||
|
||||
// SetupAgent creates an agent from the current viper state + the provided
|
||||
// options. It wraps BuildProviderConfig and agent.CreateAgent.
|
||||
func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, error) {
|
||||
modelConfig, systemPrompt, err := BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
var bufferedLogger *tools.BufferedDebugLogger
|
||||
if viper.GetBool("debug") {
|
||||
if debugEnabled {
|
||||
if opts.UseBufferedLogger {
|
||||
bufferedLogger = tools.NewBufferedDebugLogger(true)
|
||||
debugLogger = bufferedLogger
|
||||
@@ -108,7 +184,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
// Load extensions unless --no-extensions is set.
|
||||
var extRunner *extensions.Runner
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !viper.GetBool("no-extensions") {
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
if extErr != nil {
|
||||
@@ -137,18 +213,22 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
}
|
||||
|
||||
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streamingEnabled,
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
@@ -187,7 +267,7 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
}
|
||||
|
||||
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools(), runner)
|
||||
extTools := extensions.ExtensionToolsAsLLMTools(runner.RegisteredTools(), runner)
|
||||
|
||||
return runner, extensionCreationOpts{
|
||||
toolWrapper: wrapper,
|
||||
|
||||
+20
-10
@@ -4,12 +4,18 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
@@ -115,9 +121,9 @@ const (
|
||||
)
|
||||
|
||||
// Message is a single conversation message containing a heterogeneous slice
|
||||
// of ContentPart blocks. This design (borrowed from crush) enables a single
|
||||
// assistant message to carry text, reasoning, and multiple tool calls as
|
||||
// discrete, typed blocks rather than flattening everything into strings.
|
||||
// of ContentPart blocks. This design enables a single assistant message to
|
||||
// carry text, reasoning, and multiple tool calls as discrete, typed blocks
|
||||
// rather than flattening everything into strings.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role MessageRole `json:"role"`
|
||||
@@ -312,13 +318,13 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// --- Fantasy bridge ---
|
||||
// --- LLM bridge ---
|
||||
|
||||
// ToFantasyMessages converts a Message to one or more fantasy.Message values.
|
||||
// An assistant message with tool calls produces a single fantasy message with
|
||||
// ToLLMMessages converts a Message to one or more LLM message values.
|
||||
// An assistant message with tool calls produces a single message with
|
||||
// mixed TextPart and ToolCallPart content. Tool-role messages produce
|
||||
// ToolResultPart entries.
|
||||
func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
func (m *Message) ToLLMMessages() []fantasy.Message {
|
||||
switch m.Role {
|
||||
case RoleAssistant:
|
||||
var parts []fantasy.MessagePart
|
||||
@@ -416,9 +422,9 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
|
||||
// FromFantasyMessage converts a fantasy.Message into our Message type,
|
||||
// FromLLMMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
func FromLLMMessage(msg fantasy.Message) Message {
|
||||
m := Message{
|
||||
Role: MessageRole(msg.Role),
|
||||
Parts: make([]ContentPart, 0),
|
||||
@@ -430,7 +436,11 @@ func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
switch p := part.(type) {
|
||||
case fantasy.TextPart:
|
||||
if p.Text != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: p.Text})
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(p.Text, "")
|
||||
if cleanedText != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: cleanedText})
|
||||
}
|
||||
}
|
||||
case fantasy.ToolCallPart:
|
||||
m.Parts = append(m.Parts, ToolCall{
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/openai"
|
||||
)
|
||||
|
||||
// buildCacheProviderOptions returns caching options for supported models.
|
||||
// Caching is enabled by default for all supported models to reduce costs.
|
||||
// Set KIT_DISABLE_CACHE=1 or ProviderConfig.DisableCaching=true to opt out.
|
||||
func buildCacheProviderOptions(modelInfo *ModelInfo, config *ProviderConfig) fantasy.ProviderOptions {
|
||||
// Check explicit opt-out via config
|
||||
if config.DisableCaching {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check global opt-out via environment
|
||||
if os.Getenv("KIT_DISABLE_CACHE") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if model supports caching
|
||||
if modelInfo == nil || !modelInfo.SupportsCaching() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch modelInfo.CacheType() {
|
||||
case "anthropic-ephemeral":
|
||||
// Provider-level Anthropic caching disabled - use message-level caching instead.
|
||||
return nil
|
||||
case "openai-prompt-cache":
|
||||
return buildOpenAICacheOptions(config, modelInfo.ID)
|
||||
case "google-cached-content":
|
||||
// Google caching not yet implemented.
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildOpenAICacheOptions enables prompt caching for OpenAI models.
|
||||
// Uses a deterministic cache key based on system prompt and model ID.
|
||||
func buildOpenAICacheOptions(config *ProviderConfig, modelID string) fantasy.ProviderOptions {
|
||||
cacheKey := generateCacheKey(config.SystemPrompt, modelID)
|
||||
|
||||
return fantasy.ProviderOptions{
|
||||
openai.Name: &openai.ProviderOptions{
|
||||
PromptCacheKey: &cacheKey,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateCacheKey creates a deterministic cache key from system prompt and model.
|
||||
// This ensures the same system prompt + model combination gets cache hits.
|
||||
func generateCacheKey(systemPrompt, modelID string) string {
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = "default"
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
h.Write([]byte(systemPrompt))
|
||||
h.Write([]byte(modelID))
|
||||
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected bool
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", true},
|
||||
{"Claude 4 model", "claude-4-opus", true},
|
||||
{"GPT model", "gpt-4", true},
|
||||
{"GPT-5 model", "gpt-5", true},
|
||||
{"O1 model", "o1", true},
|
||||
{"O3 model", "o3", true},
|
||||
{"O4 model", "o4-mini", true},
|
||||
{"Codex model", "codex", true},
|
||||
{"Gemini model", "gemini-2.5-pro", true},
|
||||
{"Gemini 1.5 model", "gemini-1.5-flash", true},
|
||||
{"Llama model", "llama-3", false},
|
||||
{"Unknown model", "unknown", false},
|
||||
{"Empty family", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.SupportsCaching(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.SupportsCaching() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInfo_CacheType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected string
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", "anthropic-ephemeral"},
|
||||
{"GPT model", "gpt-4", "openai-prompt-cache"},
|
||||
{"O1 model", "o1", "openai-prompt-cache"},
|
||||
{"Gemini model", "gemini-2.5-pro", "google-cached-content"},
|
||||
{"Unknown model", "llama-3", ""},
|
||||
{"Empty family", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.CacheType(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.CacheType() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCacheKey(t *testing.T) {
|
||||
key1 := generateCacheKey("system prompt", "model-id")
|
||||
key2 := generateCacheKey("system prompt", "model-id")
|
||||
if key1 != key2 {
|
||||
t.Errorf("generateCacheKey should be deterministic: got %q and %q", key1, key2)
|
||||
}
|
||||
|
||||
key3 := generateCacheKey("different prompt", "model-id")
|
||||
if key1 == key3 {
|
||||
t.Errorf("generateCacheKey should produce different keys for different inputs")
|
||||
}
|
||||
|
||||
key4 := generateCacheKey("", "model-id")
|
||||
key5 := generateCacheKey("default", "model-id")
|
||||
if key4 != key5 {
|
||||
t.Errorf("generateCacheKey should treat empty prompt as 'default'")
|
||||
}
|
||||
|
||||
if len(key1) < 4 || key1[:4] != "kit-" {
|
||||
t.Errorf("generateCacheKey should produce keys with 'kit-' prefix, got %q", key1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Disabled(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: true}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when DisableCaching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_EnvironmentVariable(t *testing.T) {
|
||||
_ = os.Setenv("KIT_DISABLE_CACHE", "1")
|
||||
defer func() { _ = os.Unsetenv("KIT_DISABLE_CACHE") }()
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when KIT_DISABLE_CACHE is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_UnsupportedModel(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "llama-3", ID: "llama-3-70b"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil for unsupported model families")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_NilModelInfo(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
|
||||
if opts := buildCacheProviderOptions(nil, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when modelInfo is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Anthropic(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
// Provider-level Anthropic caching is disabled; message-level caching is used instead
|
||||
if opts != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_OpenAI(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test system prompt",
|
||||
}
|
||||
modelInfo := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
if opts == nil {
|
||||
t.Fatalf("buildCacheProviderOptions should return options for OpenAI models")
|
||||
}
|
||||
|
||||
if _, ok := opts["openai"]; !ok {
|
||||
t.Errorf("buildCacheProviderOptions should include 'openai' key for GPT models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
// Anthropic uses message-level caching; provider-level returns nil
|
||||
config1 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
modelInfo1 := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
opts1 := buildCacheProviderOptions(modelInfo1, config1)
|
||||
if opts1 != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
|
||||
// OpenAI provider-level caching works with thinking enabled
|
||||
config2 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingMedium,
|
||||
}
|
||||
modelInfo2 := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
opts2 := buildCacheProviderOptions(modelInfo2, config2)
|
||||
if opts2 == nil {
|
||||
t.Errorf("OpenAI caching should work with thinking enabled")
|
||||
}
|
||||
|
||||
// OpenAI caching also works with thinking disabled
|
||||
config3 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
opts3 := buildCacheProviderOptions(modelInfo2, config3)
|
||||
if opts3 == nil {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
+236
-9
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -31,12 +33,14 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
info := ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
Cost: Cost{
|
||||
Input: cfg.Cost.Input,
|
||||
Output: cfg.Cost.Output,
|
||||
@@ -46,19 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert custom model generation params if any are set.
|
||||
if p := convertGenerationParams(cfg.Params); p != nil {
|
||||
info.Params = p
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]*GenerationParams, len(settings))
|
||||
for modelKey, cfg := range settings {
|
||||
if p := convertGenerationParams(cfg); p != nil {
|
||||
result[modelKey] = p
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
|
||||
// Returns nil if no parameters are set.
|
||||
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
|
||||
p := &GenerationParams{}
|
||||
any := false
|
||||
|
||||
if cfg.MaxTokens != nil {
|
||||
p.MaxTokens = cfg.MaxTokens
|
||||
any = true
|
||||
}
|
||||
if cfg.Temperature != nil {
|
||||
p.Temperature = cfg.Temperature
|
||||
any = true
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
p.TopP = cfg.TopP
|
||||
any = true
|
||||
}
|
||||
if cfg.TopK != nil {
|
||||
p.TopK = cfg.TopK
|
||||
any = true
|
||||
}
|
||||
if cfg.FrequencyPenalty != nil {
|
||||
p.FrequencyPenalty = cfg.FrequencyPenalty
|
||||
any = true
|
||||
}
|
||||
if cfg.PresencePenalty != nil {
|
||||
p.PresencePenalty = cfg.PresencePenalty
|
||||
any = true
|
||||
}
|
||||
if len(cfg.StopSequences) > 0 {
|
||||
p.StopSequences = cfg.StopSequences
|
||||
any = true
|
||||
}
|
||||
if cfg.ThinkingLevel != "" {
|
||||
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
|
||||
any = true
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
p.SystemPrompt = cfg.SystemPrompt
|
||||
any = true
|
||||
}
|
||||
|
||||
if !any {
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ApplyModelSettings merges per-model generation parameter defaults from the
|
||||
// registry into a ProviderConfig. Model-level params are only applied for
|
||||
// fields where the user has not explicitly set a value (i.e., the
|
||||
// corresponding viper key is not set via CLI flag or global config).
|
||||
//
|
||||
// The lookup order is:
|
||||
// 1. modelSettings["provider/model"] from config (highest model-level priority)
|
||||
// 2. ModelInfo.Params from custom model definitions
|
||||
//
|
||||
// Both are overridden by explicit CLI flags / global config values.
|
||||
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ModelInfo.Params (from custom model definitions).
|
||||
if params == nil && modelInfo != nil && modelInfo.Params != nil {
|
||||
params = modelInfo.Params
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply each parameter only when the user hasn't explicitly set it.
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
// Resolve file paths: if the value points to an existing file, read it.
|
||||
// We check config.SystemPrompt == "" rather than isExplicitlySet because
|
||||
// viper.BindPFlag causes IsSet to return true even for unset flags.
|
||||
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSystemPromptValue resolves a system prompt value that may be either
|
||||
// inline text or a file path. If the value is a path to an existing file,
|
||||
// its contents are read and returned. Otherwise the string is returned as-is.
|
||||
// This mirrors config.LoadSystemPrompt but lives in the models package to
|
||||
// avoid circular dependencies.
|
||||
func LoadSystemPromptValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
if info, err := os.Stat(input); err == nil && !info.IsDir() {
|
||||
content, err := os.ReadFile(input)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
|
||||
return input
|
||||
}
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
// as a default). For generation params, the global config keys use
|
||||
// hyphenated names (e.g. "max-tokens", "top-p").
|
||||
//
|
||||
// Since viper merges all sources, IsSet returns true even for config
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
// These are stored on ModelInfo and applied during provider creation.
|
||||
// Nil pointer fields mean "no model-level default" — the global config
|
||||
// or CLI flag value (if any) will be used instead.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
ThinkingLevel ThinkingLevel
|
||||
SystemPrompt string // Per-model system prompt (inline text or file path)
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
|
||||
// parameter defaults. Used in both customModels[].params and modelSettings[].
|
||||
type GenerationParamsConfig struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestConvertGenerationParams(t *testing.T) {
|
||||
t.Run("empty config returns nil", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("temperature only", func(t *testing.T) {
|
||||
temp := float32(0.7)
|
||||
cfg := GenerationParamsConfig{Temperature: &temp}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP != nil {
|
||||
t.Errorf("expected nil TopP, got %v", p.TopP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all params set", func(t *testing.T) {
|
||||
maxTokens := 8192
|
||||
temp := float32(0.5)
|
||||
topP := float32(0.9)
|
||||
topK := int32(50)
|
||||
freqPenalty := float32(0.1)
|
||||
presPenalty := float32(0.2)
|
||||
cfg := GenerationParamsConfig{
|
||||
MaxTokens: &maxTokens,
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &freqPenalty,
|
||||
PresencePenalty: &presPenalty,
|
||||
StopSequences: []string{"STOP"},
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
|
||||
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP == nil || *p.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", p.TopP)
|
||||
}
|
||||
if p.TopK == nil || *p.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", p.TopK)
|
||||
}
|
||||
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
|
||||
}
|
||||
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
|
||||
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
|
||||
}
|
||||
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level parsing", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingMedium {
|
||||
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
t.Run("system prompt only", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.SystemPrompt != "You are helpful." {
|
||||
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithParams(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
topP := float32(0.95)
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
Temperature: true,
|
||||
Params: GenerationParamsConfig{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
},
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params == nil {
|
||||
t.Fatal("expected non-nil Params")
|
||||
}
|
||||
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
|
||||
}
|
||||
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params != nil {
|
||||
t.Errorf("expected nil Params, got %+v", info.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyModelSettings(t *testing.T) {
|
||||
// Save and restore viper state.
|
||||
originalViper := viper.AllSettings()
|
||||
defer func() {
|
||||
viper.Reset()
|
||||
for k, v := range originalViper {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("applies model params when not explicitly set", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
maxTokens := 4096
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
MaxTokens: &maxTokens,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature == nil || *config.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", config.TopK)
|
||||
}
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit viper values take precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.Set("temperature", 0.3)
|
||||
|
||||
temp := float32(0.8)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
explicitTemp := float32(0.3)
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
Temperature: &explicitTemp,
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Temperature should NOT be overridden because it's explicitly set in viper
|
||||
if config.Temperature == nil || *config.Temperature != 0.3 {
|
||||
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model info is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
ApplyModelSettings(config, nil)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model info without params is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{ID: "test-model"}
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Set up modelSettings in viper (simulating config file)
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"temperature": 0.5,
|
||||
"topK": 30,
|
||||
},
|
||||
})
|
||||
|
||||
// ModelInfo has different params
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// modelSettings should win over ModelInfo.Params
|
||||
if config.Temperature == nil || *config.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 30 {
|
||||
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop sequences applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
ThinkingLevel: ThinkingHigh,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "You are a coding assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "You are a coding assistant." {
|
||||
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "Model-specific prompt",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
SystemPrompt: "Global prompt",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Global system prompt should NOT be overridden because config
|
||||
// already has a non-empty SystemPrompt.
|
||||
if config.SystemPrompt != "Global prompt" {
|
||||
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt from file path", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Create a temp file with a system prompt
|
||||
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "Prompt from file" {
|
||||
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"systemPrompt": "From modelSettings",
|
||||
},
|
||||
})
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "From custom model",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "From modelSettings" {
|
||||
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -48,10 +48,10 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToFantasyProvider maps npm package names from models.dev to fantasy
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToFantasyProvider = map[string]string{
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
|
||||
+161
-53
@@ -22,9 +22,9 @@ import (
|
||||
"charm.land/fantasy/providers/openaicompat"
|
||||
"charm.land/fantasy/providers/openrouter"
|
||||
"charm.land/fantasy/providers/vercel"
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/ui/progress"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -142,19 +142,28 @@ func ParseThinkingLevel(s string) ThinkingLevel {
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers.
|
||||
type ProviderConfig struct {
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
|
||||
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
|
||||
// for long operations like Ollama model pulls. The returned io.ReadCloser
|
||||
// must be closed when done. When nil, the raw reader is consumed directly
|
||||
// with no progress UI.
|
||||
ProgressReaderFunc func(io.Reader) io.ReadCloser
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -237,30 +246,69 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Apply per-model generation parameter defaults. Model-level params are
|
||||
// only applied for fields where the user hasn't explicitly set a value
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Auto-raise MaxTokens toward the model's known output ceiling when the
|
||||
// user hasn't explicitly set --max-tokens and no per-model override
|
||||
// applied. Runs after ApplyModelSettings so explicit modelSettings win.
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
return createOpenAIProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
return createGoogleProvider(ctx, config, modelName)
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
return createOllamaProvider(ctx, config, modelName)
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
return createAzureProvider(ctx, config, modelName)
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
return createVertexAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
return createOpenRouterProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
return createBedrockProvider(ctx, config, modelName)
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
return createVercelProvider(ctx, config, modelName)
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
case "custom":
|
||||
return createCustomProvider(ctx, config, modelName)
|
||||
result, createErr = createCustomProvider(ctx, config, modelName)
|
||||
default:
|
||||
return autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
result, createErr = autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
}
|
||||
|
||||
if createErr != nil {
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
// AUTOMATICALLY ENABLE CACHING for supported models (unless disabled).
|
||||
// This works for BOTH native and auto-routed providers by detecting
|
||||
// the model family from the model metadata.
|
||||
if cacheOpts := buildCacheProviderOptions(modelInfo, config); cacheOpts != nil {
|
||||
if result.ProviderOptions == nil {
|
||||
result.ProviderOptions = cacheOpts
|
||||
} else {
|
||||
// Merge cache options with existing provider options.
|
||||
// Only add cache options for providers that don't already have
|
||||
// options set, to avoid type conflicts (e.g., Anthropic has
|
||||
// different types for regular options vs cache control options).
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
@@ -280,14 +328,14 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
llmProvider = "openaicompat"
|
||||
}
|
||||
|
||||
switch fantasyProvider {
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
@@ -301,7 +349,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -446,6 +494,37 @@ func validateModelConfig(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
// defaultRightSizeCap bounds auto-raised MaxTokens so that we don't silently
|
||||
// allocate enormous output budgets for models with very high ceilings (e.g.
|
||||
// Devstral at 262144, Mistral at 128000). Users who genuinely want more can
|
||||
// pass --max-tokens explicitly or set modelSettings[...].maxTokens in config.
|
||||
const defaultRightSizeCap = 32768
|
||||
|
||||
// rightSizeMaxTokens raises config.MaxTokens toward the model's known output
|
||||
// ceiling when:
|
||||
// - the user has not explicitly set --max-tokens (or the KIT_MAX_TOKENS env
|
||||
// var, or the top-level max-tokens key in config.yaml), AND
|
||||
// - no per-model override already bumped MaxTokens (ApplyModelSettings runs
|
||||
// before this function), AND
|
||||
// - modelInfo.Limit.Output is known and larger than the current MaxTokens.
|
||||
//
|
||||
// The raised value is capped at defaultRightSizeCap to keep accidental
|
||||
// allocations reasonable on very-large-output models. This prevents the
|
||||
// common "ghost" where the agent's reply is silently truncated at the 8192
|
||||
// default even though the selected model supports 64k or 262k output tokens.
|
||||
func rightSizeMaxTokens(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
if modelInfo == nil || modelInfo.Limit.Output <= 0 {
|
||||
return
|
||||
}
|
||||
if isExplicitlySet("max-tokens") {
|
||||
return
|
||||
}
|
||||
target := min(modelInfo.Limit.Output, defaultRightSizeCap)
|
||||
if config.MaxTokens < target {
|
||||
config.MaxTokens = target
|
||||
}
|
||||
}
|
||||
|
||||
// clearConflictingAnthropicSamplingParams ensures that temperature and top_p are
|
||||
// not both sent to the Anthropic API, which rejects requests containing both.
|
||||
// When both are set (typically from defaults), top_p is cleared so that
|
||||
@@ -493,13 +572,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
|
||||
return new(openai.ReasoningEffortLow)
|
||||
case ThinkingMedium:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
|
||||
return new(openai.ReasoningEffortMedium)
|
||||
case ThinkingHigh:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
|
||||
return new(openai.ReasoningEffortHigh)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -510,10 +589,15 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
|
||||
// SendReasoning to true and configures the thinking budget. For thinking-off
|
||||
// or non-reasoning models the returned map is nil.
|
||||
//
|
||||
// NOTE: With message-level caching, thinking and caching can work together.
|
||||
// Message-level cache control (ProviderCacheControlOptions) doesn't conflict
|
||||
// with provider-level thinking options (ProviderOptions).
|
||||
//
|
||||
// Anthropic requires max_tokens > thinking.budget_tokens. If the configured
|
||||
// MaxTokens is too low, it is bumped to budget + 4096 to leave room for the
|
||||
// actual response.
|
||||
func buildAnthropicProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
// Thinking is OFF by default. If user hasn't explicitly enabled it, return nil.
|
||||
if config.ThinkingLevel == "" || config.ThinkingLevel == ThinkingOff {
|
||||
return nil
|
||||
}
|
||||
@@ -963,12 +1047,29 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
|
||||
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
|
||||
return openai.DefaultToPrompt(prompt, systemPrompt, user)
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
// Resolve base URL: per-model override > global provider-url flag/config
|
||||
registry := GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel("custom", modelName)
|
||||
|
||||
baseURL := config.ProviderURL
|
||||
if modelInfo != nil && modelInfo.BaseURL != "" {
|
||||
baseURL = modelInfo.BaseURL
|
||||
}
|
||||
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url or a baseUrl in the model config")
|
||||
}
|
||||
|
||||
apiKey := config.ProviderAPIKey
|
||||
if modelInfo != nil && modelInfo.APIKey != "" {
|
||||
apiKey = modelInfo.APIKey
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("CUSTOM_API_KEY")
|
||||
}
|
||||
@@ -977,16 +1078,21 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
opts = append(opts, openaicompat.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openaicompat.WithAPIKey(apiKey))
|
||||
opts = append(opts, openaicompat.WithName("custom"))
|
||||
// <think> tag extraction is handled transparently at the agent layer,
|
||||
// so no provider-level hooks are needed here.
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithName("custom"))
|
||||
opts = append(opts, openai.WithLanguageModelOptions(
|
||||
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
|
||||
))
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
opts = append(opts, openai.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
}
|
||||
@@ -1063,7 +1169,7 @@ func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string,
|
||||
// Phase 1: Check if model exists locally
|
||||
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
|
||||
// Phase 2: Pull model if not found
|
||||
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
|
||||
if err := pullOllamaModel(ctx, client, baseURL, modelName, config.ProgressReaderFunc); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
|
||||
}
|
||||
}
|
||||
@@ -1106,6 +1212,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
|
||||
if config.TopK != nil {
|
||||
options["top_k"] = int(*config.TopK)
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
options["frequency_penalty"] = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
options["presence_penalty"] = *config.PresencePenalty
|
||||
}
|
||||
if len(config.StopSequences) > 0 {
|
||||
options["stop"] = config.StopSequences
|
||||
}
|
||||
@@ -1146,11 +1258,7 @@ func checkOllamaModelExists(client *http.Client, baseURL, modelName string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
|
||||
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
|
||||
}
|
||||
|
||||
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
|
||||
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string, progressFn func(io.Reader) io.ReadCloser) error {
|
||||
reqBody := map[string]string{"name": modelName}
|
||||
jsonBody, _ := json.Marshal(reqBody)
|
||||
|
||||
@@ -1174,10 +1282,10 @@ func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseU
|
||||
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
if showProgress {
|
||||
progressReader := progress.NewProgressReader(resp.Body)
|
||||
defer func() { _ = progressReader.Close() }()
|
||||
_, err = io.ReadAll(progressReader)
|
||||
if progressFn != nil {
|
||||
pr := progressFn(resp.Body)
|
||||
defer func() { _ = pr.Close() }()
|
||||
_, err = io.ReadAll(pr)
|
||||
} else {
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
+114
-9
@@ -17,12 +17,58 @@ var embeddedModelsJSON []byte
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Family string // Model family (e.g., "claude", "gpt", "gemini")
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
|
||||
// Params holds per-model generation parameter defaults. These are applied
|
||||
// when the user hasn't explicitly set the corresponding CLI flag or global
|
||||
// config value. Nil pointer fields mean "no model-level default".
|
||||
Params *GenerationParams
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
// This enables automatic cost savings for supported models regardless of provider.
|
||||
func (m *ModelInfo) SupportsCaching() bool {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CacheType returns the appropriate cache mechanism for this model family.
|
||||
// Returns empty string if caching is not supported.
|
||||
func (m *ModelInfo) CacheType() string {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return "anthropic-ephemeral"
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return "openai-prompt-cache"
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return "google-cached-content"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -86,6 +132,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
Family: dm.Family,
|
||||
Attachment: dm.Attachment,
|
||||
Reasoning: dm.Reasoning,
|
||||
Temperature: dm.Temperature,
|
||||
@@ -194,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// LookupModelForSettings is a convenience function that parses a
|
||||
// "provider/model" string and looks up the ModelInfo in the global registry.
|
||||
// Returns nil when the model string is invalid or the model is unknown.
|
||||
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
||||
func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return GetGlobalRegistry().LookupModel(provider, modelName)
|
||||
}
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
@@ -308,27 +367,27 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// GetLLMProviders returns provider IDs that have LLM support,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
func (r *ModelsRegistry) GetLLMProviders() []string {
|
||||
var providers []string
|
||||
for providerID, info := range r.providers {
|
||||
if isProviderFantasySupported(providerID, &info) {
|
||||
if isProviderLLMSupported(providerID, &info) {
|
||||
providers = append(providers, providerID)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// isProviderFantasySupported checks if a provider can be used with fantasy.
|
||||
func isProviderFantasySupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama and custom are always supported (model names are user-defined).
|
||||
if providerID == "ollama" || providerID == "custom" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to a fantasy provider
|
||||
if _, ok := npmToFantasyProvider[info.NPM]; ok {
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -355,6 +414,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
return &info
|
||||
}
|
||||
|
||||
// ValidateModelString checks whether a model string is well-formed and refers
|
||||
// to a known provider. It returns a user-friendly error with suggestions when
|
||||
// the model or provider is unrecognised. Passing validation does not guarantee
|
||||
// that API authentication will succeed — it only catches obvious mistakes
|
||||
// (typos, missing provider prefix, non-existent provider names) early so that
|
||||
// callers such as subagent spawning can return fast feedback.
|
||||
//
|
||||
// Unknown models under a known provider are allowed (the provider API is the
|
||||
// authority), but a completely unknown provider is rejected.
|
||||
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ollama and custom are always valid — model names are user-defined.
|
||||
if provider == "ollama" || provider == "custom" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the provider exists in the registry.
|
||||
providerInfo := r.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
known := r.GetSupportedProviders()
|
||||
return fmt.Errorf(
|
||||
"unknown provider %q in model string %q. Known providers: %s",
|
||||
provider, modelString, strings.Join(known, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// Provider exists — check if the model is known. An unknown model is
|
||||
// only a warning (the provider API decides), but we surface suggestions
|
||||
// so the caller can self-correct.
|
||||
if r.LookupModel(provider, modelName) == nil {
|
||||
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
return fmt.Errorf(
|
||||
"model %q not found for provider %s. Did you mean one of: %s",
|
||||
modelName, provider, strings.Join(suggestions, ", "),
|
||||
)
|
||||
}
|
||||
// No suggestions — let it through; the provider API is the authority.
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateModelString(t *testing.T) {
|
||||
registry := GetGlobalRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantErr bool
|
||||
errSubstr string // expected substring in error message (empty = don't check)
|
||||
}{
|
||||
{
|
||||
name: "valid anthropic model",
|
||||
model: "anthropic/claude-sonnet-4-6",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider prefix",
|
||||
model: "claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown provider",
|
||||
model: "fakeprovider/some-model",
|
||||
wantErr: true,
|
||||
errSubstr: "unknown provider",
|
||||
},
|
||||
{
|
||||
name: "ollama always valid",
|
||||
model: "ollama/llama3",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "custom always valid",
|
||||
model: "custom/my-fine-tune",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty provider",
|
||||
model: "/claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
model: "anthropic/",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown model under known provider (no suggestions)",
|
||||
model: "anthropic/totally-unknown-xyz-999",
|
||||
wantErr: false, // no suggestions → passes through
|
||||
},
|
||||
{
|
||||
name: "typo model under known provider with suggestions",
|
||||
model: "anthropic/claude-sonet", // misspelled "sonnet"
|
||||
wantErr: true,
|
||||
errSubstr: "Did you mean",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := registry.ValidateModelString(tt.model)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("ValidateModelString(%q) = nil, want error", tt.model)
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err)
|
||||
}
|
||||
if tt.errSubstr != "" && err != nil {
|
||||
if !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("ValidateModelString(%q) error = %q, want substring %q",
|
||||
tt.model, err.Error(), tt.errSubstr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,148 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/pflag"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// bindMaxTokensFlag wires a fresh pflag-backed "max-tokens" key into viper so
|
||||
// isExplicitlySet behaves the same way it does in production. Returns a
|
||||
// cleanup function that removes the binding so sibling tests see a clean
|
||||
// state.
|
||||
func bindMaxTokensFlag(t *testing.T, args []string) func() {
|
||||
t.Helper()
|
||||
fs := pflag.NewFlagSet("test", pflag.ContinueOnError)
|
||||
fs.Int("max-tokens", 8192, "")
|
||||
if err := viper.BindPFlag("max-tokens", fs.Lookup("max-tokens")); err != nil {
|
||||
t.Fatalf("BindPFlag: %v", err)
|
||||
}
|
||||
if err := fs.Parse(args); err != nil {
|
||||
t.Fatalf("fs.Parse: %v", err)
|
||||
}
|
||||
return func() {
|
||||
viper.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_RaisesWhenBelowCeiling(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil) // no args → flag.Changed = false
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "claude-sonnet-4-5",
|
||||
Limit: Limit{Context: 200000, Output: 64000},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 32768 {
|
||||
t.Errorf("expected MaxTokens raised to defaultRightSizeCap (32768), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_CapsAtDefaultRightSizeCap(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Mistral Devstral has 262144 output — we should still cap at 32768.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "devstral-medium-latest",
|
||||
Limit: Limit{Context: 262144, Output: 262144},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != defaultRightSizeCap {
|
||||
t.Errorf("expected MaxTokens capped at %d, got %d", defaultRightSizeCap, config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_UsesExactOutputWhenBelowCap(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 4096}
|
||||
// Model with output limit smaller than the cap.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "gpt-4",
|
||||
Limit: Limit{Context: 8192, Output: 8192},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens raised to model output ceiling (8192), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_DoesNotLowerCurrentValue(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
// User (via per-model settings, applied earlier) already bumped MaxTokens
|
||||
// above the cap — we must not clobber their choice.
|
||||
config := &ProviderConfig{MaxTokens: 100000}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "devstral-medium-latest",
|
||||
Limit: Limit{Context: 262144, Output: 262144},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 100000 {
|
||||
t.Errorf("expected MaxTokens preserved at 100000, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_RespectsExplicitFlag(t *testing.T) {
|
||||
// Simulate `--max-tokens 4096` on the command line.
|
||||
cleanup := bindMaxTokensFlag(t, []string{"--max-tokens", "4096"})
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 4096}
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "claude-sonnet-4-5",
|
||||
Limit: Limit{Context: 200000, Output: 64000},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected explicit --max-tokens to be preserved (4096), got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_NilModelInfo(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Custom model / Ollama / unknown provider → no model info.
|
||||
rightSizeMaxTokens(config, nil)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens unchanged with nil modelInfo, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRightSizeMaxTokens_ZeroOutputLimit(t *testing.T) {
|
||||
cleanup := bindMaxTokensFlag(t, nil)
|
||||
defer cleanup()
|
||||
|
||||
config := &ProviderConfig{MaxTokens: 8192}
|
||||
// Model present in catalog but with no known output limit.
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "unknown-model",
|
||||
Limit: Limit{Context: 0, Output: 0},
|
||||
}
|
||||
|
||||
rightSizeMaxTokens(config, modelInfo)
|
||||
|
||||
if config.MaxTokens != 8192 {
|
||||
t.Errorf("expected MaxTokens unchanged with zero output limit, got %d", config.MaxTokens)
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,10 @@ package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
// LoadOptions configures how templates are discovered and loaded.
|
||||
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
|
||||
})
|
||||
log.Debug("template collision",
|
||||
"name", tpl.Name,
|
||||
"dropped", tpl.FilePath,
|
||||
"kept", existing.FilePath)
|
||||
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
|
||||
} else {
|
||||
tpl.Source = source
|
||||
seen[tpl.Name] = tpl
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/fences"
|
||||
)
|
||||
|
||||
// PromptTemplate is a named prompt template with shell-style argument placeholders.
|
||||
// It supports Pi-style $1, $2, $@, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
// It supports Pi-style $1, $2, $@, $+, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
type PromptTemplate struct {
|
||||
// Name is the human-readable identifier for this template.
|
||||
Name string
|
||||
@@ -120,19 +122,28 @@ func ParseCommandArgs(input string) []string {
|
||||
|
||||
// argPlaceholder matches shell-style argument placeholders:
|
||||
// - $1, $2, etc. - positional arguments
|
||||
// - $@ - all arguments
|
||||
// - $@ - all arguments (zero or more)
|
||||
// - $+ - all arguments (one or more required)
|
||||
// - $ARGUMENTS - all arguments (alias for $@)
|
||||
// - ${@:N} - arguments from N onwards
|
||||
// - ${@:N:L} - L arguments starting from N
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$ARGUMENTS`)
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$\+|\$ARGUMENTS`)
|
||||
|
||||
// SubstituteArgs replaces argument placeholders in content with values from args.
|
||||
// Supported placeholders:
|
||||
// - $N, ${N} - the Nth argument (1-indexed)
|
||||
// - $@, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - $@, $+, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - ${@:N} - arguments from index N onwards (0-indexed)
|
||||
// - ${@:N:L} - L arguments starting from index N (0-indexed)
|
||||
func SubstituteArgs(content string, args []string) string {
|
||||
return fences.ReplaceOutside(content, func(segment string) string {
|
||||
return substituteArgsInSegment(segment, args)
|
||||
})
|
||||
}
|
||||
|
||||
// substituteArgsInSegment performs argument substitution on a single text
|
||||
// segment that is known to be outside fenced code blocks.
|
||||
func substituteArgsInSegment(content string, args []string) string {
|
||||
return argPlaceholder.ReplaceAllStringFunc(content, func(match string) string {
|
||||
// Check for ${N} or ${N:M} format
|
||||
if strings.HasPrefix(match, "${") && strings.Contains(match, "}") {
|
||||
@@ -191,8 +202,8 @@ func SubstituteArgs(content string, args []string) string {
|
||||
if strings.HasPrefix(match, "$") && !strings.HasPrefix(match, "${") {
|
||||
suffix := match[1:]
|
||||
|
||||
// $@ or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "ARGUMENTS" {
|
||||
// $@, $+, or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "+" || suffix == "ARGUMENTS" {
|
||||
return strings.Join(args, " ")
|
||||
}
|
||||
|
||||
@@ -266,6 +277,48 @@ func joinArgsRange(args []string, start, length int) string {
|
||||
return strings.Join(args[start:end], " ")
|
||||
}
|
||||
|
||||
// HasArgPlaceholders reports whether the template content contains any
|
||||
// argument placeholders ($1, $@, $ARGUMENTS, ${@:...}, etc.).
|
||||
// Placeholders inside fenced code blocks and inline code spans are ignored.
|
||||
func (t *PromptTemplate) HasArgPlaceholders() bool {
|
||||
return argPlaceholder.MatchString(fences.StripCode(t.Content))
|
||||
}
|
||||
|
||||
// RequiredArgs returns the number of positional arguments the template
|
||||
// expects. This is determined by the highest $N or ${N} placeholder found
|
||||
// in the content (1-indexed, so $2 means 2 args required). The $+
|
||||
// placeholder (required variadic) ensures at least 1. Optional wildcards
|
||||
// ($@, $ARGUMENTS) do not contribute to the count.
|
||||
func (t *PromptTemplate) RequiredArgs() int {
|
||||
content := fences.StripCode(t.Content)
|
||||
maxN := 0
|
||||
hasRequiredVariadic := strings.Contains(content, "$+")
|
||||
for _, match := range argPlaceholder.FindAllStringSubmatch(content, -1) {
|
||||
// Group 1: ${N} format — the N value.
|
||||
if match[1] != "" {
|
||||
if n, err := strconv.Atoi(match[1]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 2: ${N:M} format — the N value (start index).
|
||||
if match[2] != "" {
|
||||
if n, err := strconv.Atoi(match[2]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 6: $N format (no braces) — the N value.
|
||||
if match[6] != "" {
|
||||
if n, err := strconv.Atoi(match[6]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasRequiredVariadic && maxN < 1 {
|
||||
maxN = 1
|
||||
}
|
||||
return maxN
|
||||
}
|
||||
|
||||
// Expand substitutes arguments into the template content and returns the result.
|
||||
// It first parses args from the input string, then substitutes them into the template.
|
||||
func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
|
||||
@@ -129,6 +129,48 @@ func TestSubstituteArgs(t *testing.T) {
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "$1 inside code block preserved",
|
||||
content: "Use $1 here\n```bash\necho $1\n```\ndone",
|
||||
args: []string{"foo"},
|
||||
expected: "Use foo here\n```bash\necho $1\n```\ndone",
|
||||
},
|
||||
{
|
||||
name: "$@ inside code block preserved",
|
||||
content: "Run $@\n```\necho $@\n```\n",
|
||||
args: []string{"a", "b"},
|
||||
expected: "Run a b\n```\necho $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "all placeholders inside code block",
|
||||
content: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
args: []string{"x"},
|
||||
expected: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "$1 inside inline code preserved",
|
||||
content: "Use `$1` here and $1 outside",
|
||||
args: []string{"foo"},
|
||||
expected: "Use `$1` here and foo outside",
|
||||
},
|
||||
{
|
||||
name: "$+ required variadic",
|
||||
content: "Args: $+",
|
||||
args: []string{"a", "b", "c"},
|
||||
expected: "Args: a b c",
|
||||
},
|
||||
{
|
||||
name: "$+ with empty args",
|
||||
content: "Args: $+",
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "all placeholders in inline code",
|
||||
content: "Use `$1` and `$@` for args",
|
||||
args: []string{"x"},
|
||||
expected: "Use `$1` and `$@` for args",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -213,3 +255,78 @@ func TestPromptTemplateExpand(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasArgPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want bool
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt with no args", false},
|
||||
{"$1 placeholder", "Create a $1 component", true},
|
||||
{"$@ placeholder", "Run with args: $@", true},
|
||||
{"$ARGUMENTS placeholder", "Features: $ARGUMENTS", true},
|
||||
{"${1} placeholder", "Name: ${1}", true},
|
||||
{"${ARGUMENTS} placeholder", "All: ${ARGUMENTS}", true},
|
||||
{"${@:1} placeholder", "Rest: ${@:1}", true},
|
||||
{"${@:1:2} placeholder", "Slice: ${@:1:2}", true},
|
||||
{"dollar in text", "Cost is one hundred dollars", false},
|
||||
{"empty content", "", false},
|
||||
{"$1 inside code block only", "Prompt\n```\necho $1\n```\n", false},
|
||||
{"$1 outside and inside code block", "Use $1 here\n```\necho $1\n```\n", true},
|
||||
{"$@ inside code block only", "Prompt\n```bash\necho $@\n```\n", false},
|
||||
{"$+ placeholder", "Run with args: $+", true},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", false},
|
||||
{"$1 inside inline code only", "Use `$1` for positional args", false},
|
||||
{"$1 outside and in inline code", "Create $1 (see `$1` syntax)", true},
|
||||
{"$@ outside $1 in inline code", "Run $@ with `$1` syntax", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.HasArgPlaceholders(); got != tt.want {
|
||||
t.Errorf("HasArgPlaceholders() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequiredArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want int
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt", 0},
|
||||
{"$1 only", "Create a $1 component", 1},
|
||||
{"$1 and $2", "Create $1 with $2", 2},
|
||||
{"$3 skipping $2", "Use $1 and $3", 3},
|
||||
{"${1} braced", "Name: ${1}", 1},
|
||||
{"${2} braced", "Name: ${1} Desc: ${2}", 2},
|
||||
{"$@ only", "Run with: $@", 0},
|
||||
{"$ARGUMENTS only", "Features: $ARGUMENTS", 0},
|
||||
{"${ARGUMENTS} only", "All: ${ARGUMENTS}", 0},
|
||||
{"$1 and $@", "Create $1 with extras: $@", 1},
|
||||
{"${@:1} slice only", "Rest: ${@:1}", 0},
|
||||
{"${@:1:2} slice only", "Slice: ${@:1:2}", 0},
|
||||
{"mixed $1 $2 and $@", "Create $1 named $2: $@", 2},
|
||||
{"empty content", "", 0},
|
||||
{"$2 inside code block only", "Prompt\n```\n$1 $2\n```\n", 0},
|
||||
{"$1 outside $2 inside code block", "Use $1\n```\n$2 inside\n```\n", 1},
|
||||
{"$+ only", "Run with: $+", 1},
|
||||
{"$+ and $2", "Create $2 with: $+", 2},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", 0},
|
||||
{"$1 and $2 in inline code only", "Use `$1` and `$2` for args", 0},
|
||||
{"$1 outside $2 in inline code", "Create $1 (see `$2`)", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.RequiredArgs(); got != tt.want {
|
||||
t.Errorf("RequiredArgs() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestCompactionCreatesNewLeaf verifies that after compaction, the compaction
|
||||
// entry has no parent (creating a new root), and BuildContext returns only
|
||||
// the summary and kept messages, not the old compacted messages.
|
||||
func TestCompactionCreatesNewLeaf(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add some messages: M1, M2 (old, will be compacted), M3, M4 (kept)
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1 - old"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2 - old"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Verify initial state - all messages should be in context
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages before compaction, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs before compaction, got %d", len(entryIDs))
|
||||
}
|
||||
|
||||
// Now add a compaction entry, simulating that M3 is the first kept entry
|
||||
summary := "Summary of old messages"
|
||||
compactionID, err := tm.AppendCompaction(summary, id3, 1000, 500, 2, []string{}, []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append compaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify the compaction entry has no parent (empty ParentID)
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// Verify the leaf is now the compaction entry
|
||||
if tm.GetLeafID() != compactionID {
|
||||
t.Errorf("leaf should be compaction entry %q, got %q", compactionID, tm.GetLeafID())
|
||||
}
|
||||
|
||||
// Now BuildContext should return: [summary] + [M3, M4]
|
||||
messages, _, _ = tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages after compaction (summary + 2 kept), got %d", len(messages))
|
||||
}
|
||||
|
||||
// First message should be the summary
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\n"+summary {
|
||||
t.Errorf("unexpected summary text: %s", summaryText)
|
||||
}
|
||||
|
||||
// Second message should be M3 (kept)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3), got %s", messages[1].Role)
|
||||
}
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
|
||||
// Third message should be M4 (kept)
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - kept" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
|
||||
// Verify GetContextEntryIDs returns correct IDs
|
||||
entryIDs = tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 3 {
|
||||
t.Fatalf("expected 3 entry IDs after compaction (empty for summary + 2 kept), got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First entry ID should be empty (summary has no entry)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Second and third should be id3 and id4 (the kept messages)
|
||||
if entryIDs[1] != id3 {
|
||||
t.Errorf("second entry ID should be %q (M3), got %q", id3, entryIDs[1])
|
||||
}
|
||||
if entryIDs[2] != id4 {
|
||||
t.Errorf("third entry ID should be %q (M4), got %q", id4, entryIDs[2])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNewMessagesAfterCompaction verifies that messages appended
|
||||
// after compaction are correctly included in the context.
|
||||
func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add initial messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
|
||||
// Compact, keeping only M3
|
||||
_, _ = tm.AppendCompaction("Summary", id3, 1000, 500, 2, []string{}, []string{})
|
||||
|
||||
// Add a new message after compaction
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Message 1" {
|
||||
t.Errorf("Message 1 (compacted) should not be in context at index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNoKeptMessages verifies compaction when all messages are compacted.
|
||||
func TestCompactionWithNoKeptMessages(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add messages that will all be compacted
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
|
||||
if _, err := tm.AppendMessage(msg1); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
if _, err := tm.AppendMessage(msg2); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
|
||||
// Compact with no kept messages (empty firstKeptEntryID)
|
||||
summary := "All messages summarized"
|
||||
compactionID, _ := tm.AppendCompaction(summary, "", 1000, 100, 2, []string{}, []string{})
|
||||
|
||||
// Verify the compaction entry has no parent
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// BuildContext should return only the summary
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message (summary only), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleCompactions verifies that multiple compactions work correctly.
|
||||
func TestMultipleCompactions(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// First batch of messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - User"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - Assistant"}}}
|
||||
id1, _ := tm.AppendMessage(msg1)
|
||||
id2, _ := tm.AppendMessage(msg2)
|
||||
|
||||
// First compaction
|
||||
_, _ = tm.AppendCompaction("Summary 1", id1, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
// Second batch
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - User"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - Assistant"}}}
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Second compaction (compacting the first compaction + batch 2)
|
||||
// Note: id3 is the first kept entry, so id3 and id4 should be preserved
|
||||
compactionID2, _ := tm.AppendCompaction("Summary 2", id3, 1000, 500, 3, []string{}, []string{})
|
||||
|
||||
// Verify second compaction has no parent
|
||||
compactionEntry2 := tm.GetEntry(compactionID2).(*CompactionEntry)
|
||||
if compactionEntry2.ParentID != "" {
|
||||
t.Errorf("second compaction entry should have no parent, got %q", compactionEntry2.ParentID)
|
||||
}
|
||||
|
||||
// Add final message
|
||||
msg5 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Final message"}}}
|
||||
id5, _ := tm.AppendMessage(msg5)
|
||||
|
||||
// BuildContext should include:
|
||||
// - Summary 2 (from second compaction)
|
||||
// - msg5 (final message)
|
||||
// - msg3, msg4 (kept from second compaction)
|
||||
// But NOT Summary 1 or msg1, msg2 (they're before the first kept entry of compaction 2)
|
||||
messages, _, _ := tm.BuildContext()
|
||||
|
||||
// Should have: Summary 2 + msg5 + msg3 + msg4 = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages (Summary 2 + msg5 + msg3 + msg4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// First should be Summary 2
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system (Summary 2), got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\nSummary 2" {
|
||||
t.Errorf("unexpected summary: %s", summaryText)
|
||||
}
|
||||
|
||||
// Verify msg5 is included
|
||||
foundFinal := false
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Final message" {
|
||||
foundFinal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundFinal {
|
||||
t.Error("Final message (msg5) should be in context")
|
||||
}
|
||||
|
||||
// Verify msg1, msg2 are NOT included (compacted by first compaction, then second)
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser || msg.Role == fantasy.MessageRoleAssistant {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Batch 1 - User" || text == "Batch 1 - Assistant" {
|
||||
t.Errorf("Batch 1 messages should not be in context, found: %s", text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs, got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First should be empty (summary)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Check that id5 is in the list
|
||||
if !slices.Contains(entryIDs, id5) {
|
||||
t.Errorf("id5 (final message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id3 and id4 ARE in the list (they were kept)
|
||||
foundID3, foundID4 := false, false
|
||||
for _, id := range entryIDs {
|
||||
if id == id3 {
|
||||
foundID3 = true
|
||||
}
|
||||
if id == id4 {
|
||||
foundID4 = true
|
||||
}
|
||||
}
|
||||
if !foundID3 {
|
||||
t.Errorf("id3 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
if !foundID4 {
|
||||
t.Errorf("id4 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id1 and id2 are NOT in the list (they were compacted away)
|
||||
for _, id := range entryIDs {
|
||||
if id == id1 || id == id2 {
|
||||
t.Errorf("id1 or id2 (compacted) should not be in entry IDs, found %q in %v", id, entryIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
EntryTypeSessionInfo EntryType = "session_info"
|
||||
EntryTypeExtensionData EntryType = "extension_data"
|
||||
EntryTypeCompaction EntryType = "compaction"
|
||||
EntryTypeSystemPrompt EntryType = "system_prompt"
|
||||
)
|
||||
|
||||
// CurrentVersion is the session format version for JSONL tree sessions.
|
||||
@@ -117,6 +118,19 @@ type CompactionEntry struct {
|
||||
ModifiedFiles []string `json:"modified_files,omitempty"`
|
||||
}
|
||||
|
||||
// SystemPromptEntry records the system prompt and model used for the session.
|
||||
// This is primarily for sharing/debugging to see what instructions were
|
||||
// active during the conversation. It does NOT participate in the tree
|
||||
// structure (no ParentID) and is not used when building LLM context.
|
||||
type SystemPromptEntry struct {
|
||||
Type EntryType `json:"type"` // always "system_prompt"
|
||||
ID string `json:"id"` // unique entry ID
|
||||
Timestamp time.Time `json:"timestamp"` // when captured
|
||||
Content string `json:"content"` // the system prompt text
|
||||
Model string `json:"model"` // the model used (e.g., "claude-sonnet-4-5")
|
||||
Provider string `json:"provider"` // the provider used (e.g., "anthropic")
|
||||
}
|
||||
|
||||
// GenerateEntryID creates a unique entry identifier (16 hex chars).
|
||||
func GenerateEntryID() string {
|
||||
bytes := make([]byte, 8)
|
||||
@@ -217,6 +231,18 @@ func NewCompactionEntry(parentID, summary, firstKeptEntryID string, tokensBefore
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemPromptEntry creates a SystemPromptEntry.
|
||||
func NewSystemPromptEntry(content, model, provider string) *SystemPromptEntry {
|
||||
return &SystemPromptEntry{
|
||||
Type: EntryTypeSystemPrompt,
|
||||
ID: GenerateEntryID(),
|
||||
Timestamp: time.Now(),
|
||||
Content: content,
|
||||
Model: model,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// --- JSONL marshaling helpers ---
|
||||
|
||||
// MarshalEntry serializes any entry to a JSON line (no trailing newline).
|
||||
@@ -295,6 +321,13 @@ func UnmarshalEntry(data []byte) (any, error) {
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
case EntryTypeSystemPrompt:
|
||||
var e SystemPromptEntry
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal system_prompt entry: %w", err)
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown entry type: %q", env.Type)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSystemPromptEntry(t *testing.T) {
|
||||
// Test creation
|
||||
content := "You are a helpful coding assistant."
|
||||
model := "claude-sonnet-4-5"
|
||||
provider := "anthropic"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
if entry.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Expected type %q, got %q", EntryTypeSystemPrompt, entry.Type)
|
||||
}
|
||||
|
||||
if entry.Content != content {
|
||||
t.Errorf("Expected content %q, got %q", content, entry.Content)
|
||||
}
|
||||
|
||||
if entry.Model != model {
|
||||
t.Errorf("Expected model %q, got %q", model, entry.Model)
|
||||
}
|
||||
|
||||
if entry.Provider != provider {
|
||||
t.Errorf("Expected provider %q, got %q", provider, entry.Provider)
|
||||
}
|
||||
|
||||
if entry.ID == "" {
|
||||
t.Error("Expected non-empty ID")
|
||||
}
|
||||
|
||||
// Test marshaling
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Test unmarshaling
|
||||
unmarshaled, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
sysPrompt, ok := unmarshaled.(*SystemPromptEntry)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *SystemPromptEntry, got %T", unmarshaled)
|
||||
}
|
||||
|
||||
if sysPrompt.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Unmarshaled: expected type %q, got %q", EntryTypeSystemPrompt, sysPrompt.Type)
|
||||
}
|
||||
|
||||
if sysPrompt.Content != content {
|
||||
t.Errorf("Unmarshaled: expected content %q, got %q", content, sysPrompt.Content)
|
||||
}
|
||||
|
||||
if sysPrompt.Model != model {
|
||||
t.Errorf("Unmarshaled: expected model %q, got %q", model, sysPrompt.Model)
|
||||
}
|
||||
|
||||
if sysPrompt.Provider != provider {
|
||||
t.Errorf("Unmarshaled: expected provider %q, got %q", provider, sysPrompt.Provider)
|
||||
}
|
||||
|
||||
if sysPrompt.ID != entry.ID {
|
||||
t.Errorf("Unmarshaled: expected ID %q, got %q", entry.ID, sysPrompt.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemPromptEntryJSONStructure(t *testing.T) {
|
||||
content := "Test system prompt content"
|
||||
model := "gpt-4o"
|
||||
provider := "openai"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("Failed to unmarshal to raw map: %v", err)
|
||||
}
|
||||
|
||||
if raw["type"] != "system_prompt" {
|
||||
t.Errorf("Expected type 'system_prompt', got %v", raw["type"])
|
||||
}
|
||||
|
||||
if raw["content"] != content {
|
||||
t.Errorf("Expected content %q, got %v", content, raw["content"])
|
||||
}
|
||||
|
||||
if raw["model"] != model {
|
||||
t.Errorf("Expected model %q, got %v", model, raw["model"])
|
||||
}
|
||||
|
||||
if raw["provider"] != provider {
|
||||
t.Errorf("Expected provider %q, got %v", provider, raw["provider"])
|
||||
}
|
||||
|
||||
if raw["id"] == "" || raw["id"] == nil {
|
||||
t.Error("Expected non-empty id field")
|
||||
}
|
||||
|
||||
if raw["timestamp"] == "" || raw["timestamp"] == nil {
|
||||
t.Error("Expected non-empty timestamp field")
|
||||
}
|
||||
}
|
||||
@@ -114,6 +114,187 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
|
||||
return tm, nil
|
||||
}
|
||||
|
||||
// ForkToNewSession creates a new session file containing the history up to and
|
||||
// including the target entry ID. This matches Pi's /fork behavior: it creates
|
||||
// a completely new session file with a parent_session reference, copying all
|
||||
// entries from the root to the target point.
|
||||
func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManager, error) {
|
||||
tm.mu.RLock()
|
||||
defer tm.mu.RUnlock()
|
||||
|
||||
// Get the branch from root to target (root-to-leaf order).
|
||||
branch := tm.getBranchLocked(targetID)
|
||||
if len(branch) == 0 {
|
||||
return nil, fmt.Errorf("target entry %q not found", targetID)
|
||||
}
|
||||
|
||||
// Create a new session file.
|
||||
newTm, err := CreateTreeSession(cwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the parent session reference in the header.
|
||||
newTm.header.ParentSession = tm.filePath
|
||||
newTm.header.ParentSessionID = tm.header.ID
|
||||
|
||||
// Rewrite the header with the parent reference.
|
||||
// We need to close and recreate the file to rewrite the header.
|
||||
if err := newTm.file.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close new session file: %w", err)
|
||||
}
|
||||
|
||||
// Recreate the file and write the updated header.
|
||||
f, err := os.Create(newTm.filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to recreate session file: %w", err)
|
||||
}
|
||||
newTm.file = f
|
||||
|
||||
if err := newTm.writeEntry(&newTm.header); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to write session header: %w", err)
|
||||
}
|
||||
|
||||
// Copy entries from the branch to the new session.
|
||||
// We need to remap IDs since the new session is independent.
|
||||
idMap := make(map[string]string) // old ID -> new ID
|
||||
var prevNewID string
|
||||
|
||||
for _, entry := range branch {
|
||||
oldID := tm.EntryID(entry)
|
||||
newID := GenerateEntryID()
|
||||
idMap[oldID] = newID
|
||||
|
||||
// Create a copy of the entry with the new ID and remapped parent.
|
||||
var newEntry any
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
newEntry = &MessageEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeMessage,
|
||||
ID: newID,
|
||||
ParentID: prevNewID, // Chain sequentially in new session
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Role: e.Role,
|
||||
Parts: e.Parts,
|
||||
Model: e.Model,
|
||||
Provider: e.Provider,
|
||||
}
|
||||
// Copy label if present.
|
||||
if label, ok := tm.labels[oldID]; ok {
|
||||
newTm.labels[newID] = label
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
newEntry = &ModelChangeEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeModelChange,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Provider: e.Provider,
|
||||
ModelID: e.ModelID,
|
||||
}
|
||||
|
||||
case *LabelEntry:
|
||||
// Remap the target ID if it's in our copied branch.
|
||||
newTargetID := e.TargetID
|
||||
if mapped, ok := idMap[e.TargetID]; ok {
|
||||
newTargetID = mapped
|
||||
}
|
||||
newEntry = &LabelEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeLabel,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
TargetID: newTargetID,
|
||||
Label: e.Label,
|
||||
}
|
||||
|
||||
case *SessionInfoEntry:
|
||||
newEntry = &SessionInfoEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeSessionInfo,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Name: e.Name,
|
||||
}
|
||||
newTm.sessionName = e.Name
|
||||
|
||||
case *ExtensionDataEntry:
|
||||
newEntry = &ExtensionDataEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeExtensionData,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
ExtType: e.ExtType,
|
||||
Data: e.Data,
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Remap the from ID if it's in our copied branch.
|
||||
newFromID := e.FromID
|
||||
if mapped, ok := idMap[e.FromID]; ok {
|
||||
newFromID = mapped
|
||||
}
|
||||
newEntry = &BranchSummaryEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeBranchSummary,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
FromID: newFromID,
|
||||
Summary: e.Summary,
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
// Remap the first kept entry ID if it's in our copied branch.
|
||||
newFirstKeptID := e.FirstKeptEntryID
|
||||
if mapped, ok := idMap[e.FirstKeptEntryID]; ok {
|
||||
newFirstKeptID = mapped
|
||||
}
|
||||
newEntry = &CompactionEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeCompaction,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Summary: e.Summary,
|
||||
FirstKeptEntryID: newFirstKeptID,
|
||||
TokensBefore: e.TokensBefore,
|
||||
TokensAfter: e.TokensAfter,
|
||||
MessagesRemoved: e.MessagesRemoved,
|
||||
ReadFiles: e.ReadFiles,
|
||||
ModifiedFiles: e.ModifiedFiles,
|
||||
}
|
||||
}
|
||||
|
||||
if newEntry != nil {
|
||||
if err := newTm.appendAndPersist(newEntry); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to copy entry: %w", err)
|
||||
}
|
||||
prevNewID = newID
|
||||
}
|
||||
}
|
||||
|
||||
// Set the leaf to the last entry in the new session.
|
||||
newTm.leafID = prevNewID
|
||||
|
||||
return newTm, nil
|
||||
}
|
||||
|
||||
// OpenTreeSession opens an existing JSONL session file.
|
||||
func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -181,7 +362,7 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
|
||||
// Set leaf to the last entry.
|
||||
if len(tm.entries) > 0 {
|
||||
tm.leafID = tm.entryID(tm.entries[len(tm.entries)-1])
|
||||
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
|
||||
}
|
||||
|
||||
// Open file for appending.
|
||||
@@ -242,9 +423,14 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
return entry.ID, nil
|
||||
}
|
||||
|
||||
// AppendFantasyMessage converts a fantasy.Message and appends it.
|
||||
// AppendLLMMessage converts an LLM message and appends it.
|
||||
func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromFantasyMessage(msg))
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
@@ -323,11 +509,19 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
// AppendCompaction adds a compaction entry to the tree. The entry records
|
||||
// the summary and the ID of the first entry that should be preserved in the
|
||||
// LLM context. Messages before that entry are replaced by the summary.
|
||||
//
|
||||
// The compaction entry becomes a new "root" for the post-compaction branch
|
||||
// with no parent (empty ParentID). This breaks the parent chain so that old
|
||||
// compacted messages are no longer traversed when building context. The kept
|
||||
// messages are explicitly collected via FirstKeptEntryID in BuildContext.
|
||||
func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokensBefore, tokensAfter, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
entry := NewCompactionEntry(tm.leafID, summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
// The compaction entry has no parent, making it a new "root" for the
|
||||
// post-compaction branch. This ensures old compacted messages are not
|
||||
// traversed when walking from the current leaf.
|
||||
entry := NewCompactionEntry("", summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -497,14 +691,18 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
// Find the last compaction entry on this branch — it determines
|
||||
// which older messages are replaced by the summary.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first.
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -514,28 +712,111 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Determine whether to skip entries (everything before firstKeptEntryID).
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
// No compaction - process the entire branch normally.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
@@ -554,10 +835,6 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (the last one on the branch).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,38 +944,99 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
// Find the last compaction entry for skip logic.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction summary injected, it has no entry ID.
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
if lastCompaction != nil {
|
||||
ids = append(ids, "") // placeholder for the summary system message
|
||||
}
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
continue
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// No compaction - collect IDs from the entire branch.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
@@ -707,9 +1045,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -737,31 +1072,41 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// --- Legacy bridge ---
|
||||
|
||||
// AddFantasyMessages appends multiple fantasy messages as entries. This is
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendFantasyMessage(msg); err != nil {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFantasyMessages builds the context and returns just the messages.
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
func (tm *TreeManager) addEntryToIndex(entry any) {
|
||||
tm.entries = append(tm.entries, entry)
|
||||
|
||||
id := tm.entryID(entry)
|
||||
id := tm.EntryID(entry)
|
||||
parentID := tm.entryParentID(entry)
|
||||
|
||||
if id != "" {
|
||||
@@ -798,8 +1143,8 @@ func (tm *TreeManager) writeEntry(entry any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// entryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) entryID(entry any) string {
|
||||
// EntryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) EntryID(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
return e.ID
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
|
||||
@@ -60,34 +60,38 @@ type MCPConnection struct {
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
// If config is nil, default configuration values will be used. The pool starts a background
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
tokenStoreFactory: tokenStoreFactory,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
|
||||
}
|
||||
|
||||
go pool.startHealthCheck()
|
||||
@@ -103,6 +107,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
|
||||
// When set, the pool can trigger OAuth re-authorization when a tool call fails
|
||||
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oauthFlow = flow
|
||||
}
|
||||
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
@@ -127,9 +140,7 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
} else {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
_ = conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
@@ -232,18 +243,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
client, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry after successful auth
|
||||
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
@@ -269,6 +305,8 @@ func (p *MCPConnectionPool) createMCPClient(ctx context.Context, serverName stri
|
||||
return p.createSSEClient(ctx, serverConfig)
|
||||
case "streamable":
|
||||
return p.createStreamableClient(ctx, serverConfig)
|
||||
case "inprocess":
|
||||
return p.createInProcessClient(serverConfig)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
|
||||
}
|
||||
@@ -325,13 +363,39 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sseClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start SSE client: %w", err)
|
||||
}
|
||||
|
||||
return sseClient, nil
|
||||
@@ -356,18 +420,70 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := streamableClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
|
||||
}
|
||||
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createInProcessClient creates an in-process MCP client that communicates
|
||||
// directly with an *server.MCPServer in the same process. No subprocess is
|
||||
// spawned and no network I/O occurs — calls go through JSON marshal →
|
||||
// MCPServer.HandleMessage → JSON unmarshal, all in-memory.
|
||||
func (p *MCPConnectionPool) createInProcessClient(serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
srv, ok := serverConfig.InProcessServer.(*server.MCPServer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("InProcessServer must be *server.MCPServer, got %T", serverConfig.InProcessServer)
|
||||
}
|
||||
inProcessClient, err := client.NewInProcessClient(srv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create in-process client: %w", err)
|
||||
}
|
||||
return inProcessClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
|
||||
if p.tokenStoreFactory != nil {
|
||||
return p.tokenStoreFactory(serverURL)
|
||||
}
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
@@ -383,7 +499,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
@@ -514,6 +630,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// RemoveConnection closes and removes a single connection from the pool.
|
||||
// Returns an error if the connection does not exist or if closing fails.
|
||||
// Thread-safe for concurrent use.
|
||||
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conn, exists := p.connections[serverName]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection %q not found in pool", serverName)
|
||||
}
|
||||
|
||||
err := conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
@@ -541,6 +678,9 @@ func (p *MCPConnectionPool) Close() error {
|
||||
|
||||
// isConnectionError checks if the error is connection-related
|
||||
func isConnectionError(err error) bool {
|
||||
if IsOAuthError(err) {
|
||||
return false // OAuth errors are recoverable, not connection failures
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "Connection not found") ||
|
||||
strings.Contains(errStr, "transport error") ||
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// mcpFantasyTool adapts an MCP tool to the fantasy.AgentTool interface.
|
||||
// It bridges the MCP tool protocol with fantasy's agent tool system, handling
|
||||
// name prefixing, schema conversion, connection pooling, and result marshaling.
|
||||
type mcpFantasyTool struct {
|
||||
toolInfo fantasy.ToolInfo
|
||||
mapping *toolMapping
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpFantasyTool) Info() fantasy.ToolInfo {
|
||||
return t.toolInfo
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by routing through the connection pool.
|
||||
// It maps the prefixed tool name back to the original name, retrieves a healthy
|
||||
// connection, invokes the tool, and converts the MCP result to a fantasy ToolResponse.
|
||||
func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
// Parse and validate JSON arguments
|
||||
var arguments any
|
||||
input := call.Input
|
||||
if input == "" || input == "{}" {
|
||||
arguments = nil
|
||||
} else {
|
||||
var temp any
|
||||
if err := json.Unmarshal([]byte(input), &temp); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid JSON arguments: %v", err)), nil
|
||||
}
|
||||
arguments = json.RawMessage(input)
|
||||
}
|
||||
|
||||
// Get connection from pool with health check
|
||||
conn, err := t.mapping.manager.connectionPool.GetConnectionWithHealthCheck(
|
||||
ctx, t.mapping.serverName, t.mapping.serverConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
||||
}
|
||||
|
||||
// Call the MCP tool using the original (unprefixed) name
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
marshaledResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
|
||||
// Return as text response, preserving error status from MCP
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(string(marshaledResult)), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(marshaledResult)), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTestInProcessServer creates a simple MCP server with one tool for testing.
|
||||
func newTestInProcessServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("test-server", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description("Name to greet")),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("Hello, " + name + "!"), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestInProcessTransportType(t *testing.T) {
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessTransportTypeInferred(t *testing.T) {
|
||||
// When Type is empty but InProcessServer is set, infer "inprocess".
|
||||
cfg := config.MCPServerConfig{
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessValidation(t *testing.T) {
|
||||
// Valid: InProcessServer is set.
|
||||
validCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := validCfg.Validate(); err != nil {
|
||||
t.Errorf("expected valid config, got error: %v", err)
|
||||
}
|
||||
|
||||
// Invalid: type is inprocess but InProcessServer is nil.
|
||||
invalidCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := invalidCfg.Validate(); err == nil {
|
||||
t.Error("expected validation error for nil InProcessServer, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessClient(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the connection is healthy and functional.
|
||||
if !conn.isHealthy {
|
||||
t.Error("expected connection to be healthy")
|
||||
}
|
||||
|
||||
// List tools to verify the connection works end-to-end.
|
||||
toolsResp, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools failed: %v", err)
|
||||
}
|
||||
if len(toolsResp.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(toolsResp.Tools))
|
||||
}
|
||||
if toolsResp.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", toolsResp.Tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessToolExecution(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the tool.
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: "greet",
|
||||
Arguments: map[string]any{"name": "World"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected at least one content block")
|
||||
}
|
||||
text, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("expected TextContent, got %T", result.Content[0])
|
||||
}
|
||||
if text.Text != "Hello, World!" {
|
||||
t.Errorf("expected 'Hello, World!', got %q", text.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolManagerInProcess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
mgr := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
count, err := mgr.AddServer(ctx, "myserver", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", count)
|
||||
}
|
||||
|
||||
tools := mgr.GetTools()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Name != "myserver__greet" {
|
||||
t.Errorf("expected tool name 'myserver__greet', got %q", tools[0].Name)
|
||||
}
|
||||
|
||||
// Execute the tool.
|
||||
input, _ := json.Marshal(map[string]any{"name": "SDK"})
|
||||
result, err := mgr.ExecuteTool(ctx, "myserver__greet", string(input))
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if result.Content == "" {
|
||||
t.Error("expected non-empty result content")
|
||||
}
|
||||
|
||||
// Verify result contains our greeting.
|
||||
if !strings.Contains(result.Content, "Hello, SDK!") {
|
||||
t.Errorf("expected 'Hello, SDK!' in result, got %q", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessInvalidServer(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pass a non-*server.MCPServer value.
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: "not a server",
|
||||
}
|
||||
|
||||
_, err := pool.GetConnection(ctx, "bad", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid InProcessServer type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessReuse(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
// Get connection twice — should reuse.
|
||||
conn1, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("first GetConnection failed: %v", err)
|
||||
}
|
||||
conn2, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("second GetConnection failed: %v", err)
|
||||
}
|
||||
if conn1 != conn2 {
|
||||
t.Error("expected same connection object on reuse")
|
||||
}
|
||||
}
|
||||
+901
-59
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user