mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
166 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 633fa38b2b | |||
| f905cee48c | |||
| 182c10ea1a | |||
| fcaa52bf1c | |||
| 7e6455732c | |||
| 71301a9035 | |||
| 0974d37ab2 | |||
| 398e825df8 | |||
| 3c51c20be7 | |||
| 25410af440 | |||
| 26c9f009f9 | |||
| e068487ff7 | |||
| 0ffb0ba788 | |||
| 65c6e9f797 | |||
| 68d798d2f4 | |||
| eefd5565f8 | |||
| 9d1b8a102e | |||
| f57e045c69 | |||
| eb5da28a15 | |||
| cd8e2a7654 | |||
| 64da1caf41 | |||
| 7eaeafff8c | |||
| 8ed8d23c73 | |||
| 2de98d32be | |||
| 83127467c5 | |||
| e07c94f49d | |||
| b87146a284 | |||
| 186d9f7f44 | |||
| 3a8ffc2104 | |||
| e54570162e | |||
| 34bb97a40e | |||
| f5c1a16f8a | |||
| b29d7d2166 | |||
| 3ea0db69ea | |||
| 4304a5e899 | |||
| 4019c1e4f7 | |||
| 30ad7c1d0b | |||
| e33564c569 | |||
| 5ff28445fd | |||
| 13d177e5d0 | |||
| 3ffc995f27 | |||
| b2bd016135 | |||
| 812dedaea2 | |||
| f65b6737f2 | |||
| 5d45aa196b | |||
| debb39f56c | |||
| 7ce6f4fd9e | |||
| c2f2bdb3d3 | |||
| 201d14804e | |||
| 7e54710d4a | |||
| 88870be4d2 | |||
| 46bf809715 | |||
| e19e9642a2 | |||
| 32675b8b35 | |||
| aecce001ee | |||
| 32d73171fd | |||
| 265fd2ec0c | |||
| efebf2eba6 | |||
| f7b655ae33 | |||
| 35982b41ad | |||
| 788e3b71fd | |||
| 3496bc2684 | |||
| 997c7d15ff | |||
| 83246e47d5 | |||
| 50e7b78c33 | |||
| b937af3056 | |||
| a5e995c750 | |||
| e95e08a699 | |||
| bcaf92f62a | |||
| ead4afbfe6 | |||
| 685aaf207f | |||
| 76ff6c9639 | |||
| 1cf24ee5de | |||
| c9637090fa | |||
| 0ff0ff42ab | |||
| a4fb32ff2b | |||
| 7d2f078111 | |||
| b0b66941ab | |||
| cbb7387a72 | |||
| 19430b0ecb | |||
| 8e3cfeede5 | |||
| 4fa5775974 | |||
| 4e7d823ee4 | |||
| 7a16c76adc | |||
| 70a21ee73a | |||
| 28d2de8f39 | |||
| 7f192ae850 | |||
| 9f6746ded9 | |||
| 7514d3a0ff | |||
| c83281a52b | |||
| 4515bb92c2 | |||
| e326b84204 | |||
| 1b93049b8e | |||
| 4912449dda | |||
| b70cce4f34 | |||
| 4c566836b2 | |||
| bb3261883a | |||
| 512d0f16ce | |||
| 8159431ce4 | |||
| 9f9f265fb3 | |||
| 9d38349091 | |||
| fec8bac800 | |||
| e76f5f3d45 | |||
| 1ad493c5c7 | |||
| ea6ddc8792 | |||
| 6d4e8bcec5 | |||
| e2ed345280 | |||
| e542eb797e | |||
| e631fc1b17 | |||
| 290c5a4774 | |||
| 287d60c31e | |||
| 3d45d98895 | |||
| db4be4f9a2 | |||
| 80093e69ed | |||
| ef519ba517 | |||
| d79eb1f0fa | |||
| ac8ee6525d | |||
| e35e8382d6 | |||
| fbb3408a25 | |||
| 44fed9a647 | |||
| e7f11487b9 | |||
| 054c417603 | |||
| 94d62a6ef0 | |||
| 91e6dfd2c8 | |||
| b6a0c4b44c | |||
| 8eb0fa855a | |||
| 3bf696c546 | |||
| 3e461a0539 | |||
| a2ece01ecf | |||
| 623c9fb5ad | |||
| 139506f336 | |||
| 6d424554ad | |||
| 5a3d3fdd7d | |||
| c91225629d | |||
| 5a71cde5ff | |||
| 044d3eb206 | |||
| 80f3a642a3 | |||
| 26f0969e3e | |||
| 4af75901b5 | |||
| 49ff4c0678 | |||
| b0802a5c32 | |||
| dfe65ca227 | |||
| d4ec756ce5 | |||
| 2971e73ee8 | |||
| 5aa6c9e116 | |||
| bca08476de | |||
| 6a599d86af | |||
| fd6f200659 | |||
| b295a25946 | |||
| f0e4e2f757 | |||
| d25249506a | |||
| 971521f534 | |||
| 8c00682367 | |||
| 58caf155c1 | |||
| 3f08bf2424 | |||
| 9fbbab05f6 | |||
| b0991c7aa6 | |||
| 9c90563765 | |||
| f36166bee5 | |||
| 879e81f9b5 | |||
| 727b42acfe | |||
| 4830981570 | |||
| dcfebafcc5 | |||
| 1f5c103667 | |||
| 4caa8ba3dc | |||
| 15ef8ad78b |
@@ -0,0 +1,79 @@
|
||||
name: Bug Report
|
||||
description: Report a bug or issue with Kit
|
||||
title: "fix: "
|
||||
labels: ["bug"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Bug Description
|
||||
description: What happened? What did you expect to happen?
|
||||
placeholder: |
|
||||
The BorderColor field in ToolRenderConfig is documented but never applied
|
||||
during tool rendering. I expected the tool block to render with my custom
|
||||
color, but it uses the default styling instead.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide clear steps to reproduce the issue
|
||||
placeholder: |
|
||||
1. Create an extension with `api.RegisterToolRenderer(ext.ToolRenderConfig{...})`
|
||||
2. Set `BorderColor: "#89b4fa"` in the config
|
||||
3. Run a tool that uses this renderer
|
||||
4. Observe the border color is not applied
|
||||
render: markdown
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: code
|
||||
attributes:
|
||||
label: Relevant Code / Configuration
|
||||
description: Paste any code, configuration, or error messages
|
||||
placeholder: |
|
||||
```go
|
||||
api.RegisterToolRenderer(ext.ToolRenderConfig{
|
||||
ToolName: "bash",
|
||||
DisplayName: "Shell",
|
||||
BorderColor: "#a6e3a1", // This is ignored!
|
||||
Background: "#1e1e2e", // This is ignored!
|
||||
})
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: input
|
||||
id: component
|
||||
attributes:
|
||||
label: Affected Component
|
||||
description: Which part of Kit is affected?
|
||||
placeholder: e.g., extensions, ui, tool rendering, session management
|
||||
|
||||
- type: input
|
||||
id: version
|
||||
attributes:
|
||||
label: Kit Version
|
||||
description: What version of Kit are you running?
|
||||
placeholder: e.g., v0.1.0, commit hash, or "main"
|
||||
|
||||
- type: textarea
|
||||
id: context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Any other context, proposed fixes, or related issues
|
||||
placeholder: |
|
||||
The issue appears to be in `internal/ui/messages.go:RenderToolMessage()`
|
||||
which ignores the BorderColor and Background fields from ToolRendererData.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been reported yet
|
||||
required: true
|
||||
- label: I've tested with the latest version of Kit
|
||||
required: false
|
||||
@@ -0,0 +1,11 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Kit Documentation
|
||||
url: https://github.com/mark3labs/kit/tree/main/www/pages
|
||||
about: Check the documentation before filing an issue
|
||||
- name: Extension Examples
|
||||
url: https://github.com/mark3labs/kit/tree/main/examples/extensions
|
||||
about: See working extension examples for reference
|
||||
- name: Discussions
|
||||
url: https://github.com/mark3labs/kit/discussions
|
||||
about: For questions, ideas, or general discussion
|
||||
@@ -0,0 +1,40 @@
|
||||
name: Documentation Issue
|
||||
description: Report missing, incorrect, or unclear documentation
|
||||
title: "docs: "
|
||||
labels: ["documentation"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Documentation Issue
|
||||
description: What's wrong or missing in the documentation?
|
||||
placeholder: |
|
||||
The ToolRenderConfig documentation mentions BorderColor and Background fields,
|
||||
but the code doesn't actually use them. The docs should either be updated
|
||||
to reflect reality, or the bug should be fixed.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: location
|
||||
attributes:
|
||||
label: Documentation Location
|
||||
description: Where is the affected documentation?
|
||||
placeholder: e.g., README.md, examples/extensions/tool-renderer-demo.go, pkg/kit docs
|
||||
|
||||
- type: textarea
|
||||
id: suggestion
|
||||
attributes:
|
||||
label: Suggested Improvement
|
||||
description: How should the documentation be improved?
|
||||
placeholder: |
|
||||
Add a note that BorderColor and Background are not yet implemented,
|
||||
or fix the bug and document the correct behavior.
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've checked that this documentation issue still exists in the latest version
|
||||
required: true
|
||||
@@ -0,0 +1,64 @@
|
||||
name: Feature Request
|
||||
description: Suggest a new feature or enhancement for Kit
|
||||
title: "feat: "
|
||||
labels: ["enhancement"]
|
||||
body:
|
||||
- type: textarea
|
||||
id: description
|
||||
attributes:
|
||||
label: Feature Description
|
||||
description: What would you like to see added or changed?
|
||||
placeholder: |
|
||||
I'd like to be able to customize the border color of tool result blocks
|
||||
dynamically based on the tool type or result status.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: motivation
|
||||
attributes:
|
||||
label: Motivation / Use Case
|
||||
description: Why is this feature needed? What problem does it solve?
|
||||
placeholder: |
|
||||
When running multiple tools in sequence, it's hard to visually distinguish
|
||||
between file reads (blue), shell commands (green), and errors (red)
|
||||
without custom border colors.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposed
|
||||
attributes:
|
||||
label: Proposed Implementation
|
||||
description: How do you think this should work? (optional)
|
||||
placeholder: |
|
||||
Extend `ToolRenderConfig` to accept a function that receives the tool
|
||||
result and returns a color based on the content:
|
||||
|
||||
```go
|
||||
BorderColorFunc: func(result string, isError bool) string {
|
||||
if isError {
|
||||
return "#f38ba8"
|
||||
}
|
||||
return "#89b4fa"
|
||||
}
|
||||
```
|
||||
render: go
|
||||
|
||||
- type: checkboxes
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
options:
|
||||
- label: I've considered workarounds or alternative approaches
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: terms
|
||||
attributes:
|
||||
label: Checklist
|
||||
options:
|
||||
- label: I've searched existing issues and this hasn't been requested yet
|
||||
required: true
|
||||
- label: This feature aligns with Kit's design philosophy (TUI-first, extension-based)
|
||||
required: false
|
||||
@@ -3,6 +3,7 @@
|
||||
.env
|
||||
.kit/*
|
||||
!.kit/extensions/
|
||||
!.kit/prompts/
|
||||
aidocs/
|
||||
*.log
|
||||
/kit
|
||||
|
||||
@@ -28,11 +28,15 @@ type lintResult struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
// Package-level state: set of .go files edited during the current agent turn.
|
||||
var editedFiles map[string]bool
|
||||
|
||||
func Init(api ext.API) {
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint on Go file edits")
|
||||
ctx.Print("go-edit-lint extension loaded - will run gopls and golangci-lint after agent turns that edit Go files")
|
||||
})
|
||||
|
||||
// Track edited .go files — don't lint yet.
|
||||
api.OnToolResult(func(e ext.ToolResultEvent, ctx ext.Context) *ext.ToolResultResult {
|
||||
if e.IsError || !isEditOrWrite(e.ToolName) {
|
||||
return nil
|
||||
@@ -43,30 +47,72 @@ func Init(api ext.API) {
|
||||
return nil
|
||||
}
|
||||
|
||||
report := runGoDiagnostics(ctx.CWD, absPath)
|
||||
|
||||
// Check if there are issues and add explicit prompt for the LLM to react
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
var enhanced string
|
||||
if hasIssues {
|
||||
enhanced = e.Content + "\n\n" + report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review the issues above and fix them before proceeding."
|
||||
} else {
|
||||
enhanced = e.Content + "\n\n" + report
|
||||
if editedFiles == nil {
|
||||
editedFiles = make(map[string]bool)
|
||||
}
|
||||
editedFiles[absPath] = true
|
||||
return nil
|
||||
})
|
||||
|
||||
// After the agent turn ends, lint all collected files.
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if len(editedFiles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Show TUI message block for diagnostics visibility (only if there are issues)
|
||||
// Snapshot and reset immediately so the next turn starts clean.
|
||||
files := editedFiles
|
||||
editedFiles = nil
|
||||
|
||||
// Skip lint on errored turns.
|
||||
if e.StopReason == "error" {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect unique directories and file list for gopls.
|
||||
var allGoplsOutput []string
|
||||
for absPath := range files {
|
||||
res := runGopls(ctx.CWD, absPath)
|
||||
formatted := formatToolResult(res, "")
|
||||
if formatted != "" {
|
||||
allGoplsOutput = append(allGoplsOutput, fmt.Sprintf("# %s\n%s", filepath.Base(absPath), formatted))
|
||||
}
|
||||
}
|
||||
|
||||
lintRes := runGolangCILint(ctx.CWD, "./...")
|
||||
|
||||
goplsSection := "No diagnostics."
|
||||
if len(allGoplsOutput) > 0 {
|
||||
goplsSection = strings.Join(allGoplsOutput, "\n\n")
|
||||
}
|
||||
lintSection := formatToolResult(lintRes, "No lint issues.")
|
||||
|
||||
// Build file list for the report header.
|
||||
var fileNames []string
|
||||
for absPath := range files {
|
||||
fileNames = append(fileNames, filepath.Base(absPath))
|
||||
}
|
||||
|
||||
report := fmt.Sprintf(
|
||||
"<go_diagnostics files=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
strings.Join(fileNames, ", "),
|
||||
goplsSection,
|
||||
lintSection,
|
||||
)
|
||||
|
||||
goplsIssues, lintIssues := countIssues(report)
|
||||
hasIssues := goplsIssues > 0 || lintIssues > 0
|
||||
|
||||
if hasIssues {
|
||||
// Show TUI block so the user sees it too.
|
||||
var msgLines []string
|
||||
msgLines = append(msgLines, fmt.Sprintf("File: %s", filepath.Base(absPath)))
|
||||
msgLines = append(msgLines, fmt.Sprintf("Files: %s", strings.Join(fileNames, ", ")))
|
||||
if goplsIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("gopls: %d issue(s)", goplsIssues))
|
||||
}
|
||||
if lintIssues > 0 {
|
||||
msgLines = append(msgLines, fmt.Sprintf("golangci-lint: %d issue(s)", lintIssues))
|
||||
}
|
||||
msgLines = append(msgLines, "", "⚠️ Please fix these issues before proceeding.")
|
||||
|
||||
borderColor := "#f9e2af" // yellow
|
||||
if goplsIssues > 0 && lintIssues > 0 {
|
||||
@@ -78,9 +124,16 @@ func Init(api ext.API) {
|
||||
BorderColor: borderColor,
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
|
||||
return &ext.ToolResultResult{Content: &enhanced}
|
||||
// Inject a follow-up message so the agent fixes the issues.
|
||||
ctx.SendMessage(report + "\n\n⚠️ DIAGNOSTICS FOUND: Please review and fix the issues above.")
|
||||
} else {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: fmt.Sprintf("Files: %s\n✓ All clean", strings.Join(fileNames, ", ")),
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "go-edit-lint",
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -106,18 +159,6 @@ func resolveGoFilePath(inputJSON, cwd string) (string, bool) {
|
||||
return absPath, true
|
||||
}
|
||||
|
||||
func runGoDiagnostics(cwd, absPath string) string {
|
||||
gopls := runGopls(cwd, absPath)
|
||||
lint := runGolangCILint(cwd, "./...")
|
||||
|
||||
return fmt.Sprintf(
|
||||
"<go_diagnostics file=%q>\n[gopls]\n%s\n\n[golangci-lint]\n%s\n</go_diagnostics>",
|
||||
filepath.Base(absPath),
|
||||
formatToolResult(gopls, "No diagnostics."),
|
||||
formatToolResult(lint, "No lint issues."),
|
||||
)
|
||||
}
|
||||
|
||||
func runGopls(cwd, absPath string) lintResult {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), diagnosticsTimeout)
|
||||
defer cancel()
|
||||
@@ -178,7 +219,9 @@ func formatToolResult(res lintResult, emptyFallback string) string {
|
||||
out := strings.TrimSpace(res.Output)
|
||||
if out == "" {
|
||||
if res.Err == nil {
|
||||
lines = append(lines, emptyFallback)
|
||||
if emptyFallback != "" {
|
||||
lines = append(lines, emptyFallback)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lines = append(lines, out)
|
||||
@@ -197,17 +240,15 @@ func truncate(s string, max int) string {
|
||||
}
|
||||
|
||||
func countIssues(report string) (goplsCount, lintCount int) {
|
||||
// Extract gopls section
|
||||
goplsStart := strings.Index(report, "[gopls]")
|
||||
lintStart := strings.Index(report, "[golangci-lint]")
|
||||
endTag := strings.Index(report, "</go_diagnostics>")
|
||||
|
||||
if goplsStart != -1 && lintStart != -1 {
|
||||
goplsSection := report[goplsStart:lintStart]
|
||||
// Count non-empty lines excluding the header and "No diagnostics." message
|
||||
for _, line := range strings.Split(goplsSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." {
|
||||
if line != "" && line != "[gopls]" && line != "No diagnostics." && !strings.HasPrefix(line, "#") {
|
||||
goplsCount++
|
||||
}
|
||||
}
|
||||
@@ -215,7 +256,6 @@ func countIssues(report string) (goplsCount, lintCount int) {
|
||||
|
||||
if lintStart != -1 && endTag != -1 {
|
||||
lintSection := report[lintStart:endTag]
|
||||
// Count non-empty lines excluding the header and "No lint issues." message
|
||||
for _, line := range strings.Split(lintSection, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line != "" && line != "[golangci-lint]" && line != "No lint issues." {
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
// - No channels in maps (Yaegi panics on range over map[string]chan)
|
||||
// - All ctx.* calls guarded with nil checks
|
||||
// - Simple data structures only
|
||||
// - The extension runner serializes handler calls per-extension, so
|
||||
// concurrent subagent events cannot race on this shared state.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -43,7 +45,8 @@ const (
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Package-level state - all simple types
|
||||
// Package-level state — safe because the runner serializes all handler
|
||||
// invocations for the same extension (per-extension reentrant mutex).
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
var (
|
||||
@@ -282,8 +285,8 @@ func Init(api ext.API) {
|
||||
|
||||
submonPushWidget()
|
||||
|
||||
// Remove the entry immediately (no goroutine to avoid races)
|
||||
newEntries := submonEntries[:0]
|
||||
// Remove the entry — build a new slice to avoid aliasing bugs
|
||||
newEntries := make([]*submonEntry, 0, len(submonEntries))
|
||||
for _, en := range submonEntries {
|
||||
if en.callID != e.ToolCallID {
|
||||
newEntries = append(newEntries, en)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
---
|
||||
description: Run ACP smoke test against opencode/kimi-k2.5 to verify JSON-RPC stdio works
|
||||
---
|
||||
|
||||
Run the ACP smoke test to verify the Kit ACP server works correctly over JSON-RPC stdio with streaming responses.
|
||||
|
||||
## Steps
|
||||
|
||||
1. Build the kit binary:
|
||||
```bash
|
||||
go build -o output/kit ./cmd/kit
|
||||
```
|
||||
|
||||
2. Run the smoke test Python script against opencode/kimi-k2.5:
|
||||
```bash
|
||||
python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
3. Verify the output shows:
|
||||
- `session/new` returns a valid `sessionId`
|
||||
- `session/prompt` streams `agent_thought_chunk` notifications (reasoning)
|
||||
- `session/prompt` streams `agent_message_chunk` notifications (response)
|
||||
- Final result has `stopReason: "end_turn"`
|
||||
- `✓ SMOKE TEST PASSED` at the end
|
||||
|
||||
4. If the test fails, check:
|
||||
- `output/kit` binary exists and is executable
|
||||
- `OPENCODE_API_KEY` or `OPENCODE_ZEN_API_KEY` environment variable is set
|
||||
- `scripts/acp_smoke_test.py` exists
|
||||
- The model `opencode/kimi-k2.5` is available (`kit models opencode | grep kimi-k2.5`)
|
||||
|
||||
5. For testing with a different model, edit the script or set the `MODEL` variable:
|
||||
```bash
|
||||
MODEL=anthropic/claude-sonnet-4-5 python3 scripts/acp_smoke_test.py
|
||||
```
|
||||
|
||||
The smoke test exercises the full ACP protocol: session lifecycle, streaming notifications, and tool-free prompt completion.
|
||||
@@ -0,0 +1,30 @@
|
||||
---
|
||||
description: Stage, commit, and push changes with an auto-generated conventional commit message
|
||||
---
|
||||
|
||||
Review the current git status and diff, then stage all changes, write a concise conventional commit message, commit, and push to the current branch.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Check status**: `git status` — understand what has changed
|
||||
2. **Review the diff**: `git diff` (and `git diff --cached` if anything is already staged) — read the actual changes
|
||||
3. **Stage everything**: `git add -A`
|
||||
4. **Craft the commit message** following Conventional Commits:
|
||||
- Format: `<type>(<scope>): <short summary>`
|
||||
- Types: `feat`, `fix`, `refactor`, `chore`, `docs`, `test`, `perf`, `build`
|
||||
- Scope: optional, the subsystem affected (e.g. `ui`, `cmd`, `config`)
|
||||
- Summary: imperative mood, lowercase, no trailing period, ≤72 chars
|
||||
- Body: add a blank line then bullet points for non-trivial changes
|
||||
- Do **not** include "Generated by" or similar noise
|
||||
5. **Commit**: `git commit -m "<message>"`
|
||||
6. **Push**: `git push`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Read the actual diff — do not guess from filenames alone
|
||||
- Prefer one well-scoped commit; do not split unless the changes are clearly unrelated
|
||||
- Keep the subject line under 72 characters
|
||||
- Use the body to explain *what* and *why*, not *how*
|
||||
- If there is nothing to commit, say so and stop
|
||||
|
||||
$@
|
||||
@@ -0,0 +1,86 @@
|
||||
---
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
| Field | Required | Purpose |
|
||||
|-------|----------|---------|
|
||||
| **Feature Description** | Yes | What should be added or changed |
|
||||
| **Motivation / Use Case** | Yes | Why is this needed? What problem does it solve? |
|
||||
| **Proposed Implementation** | No | How do you think this should work? |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$+`
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
2. **Ask clarifying questions** if needed:
|
||||
- "What problem does this solve for you?"
|
||||
- "How would you expect this to work?"
|
||||
- "Are there similar features in other tools you use?"
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `feat: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Good examples:
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `feat: support custom themes per extension`
|
||||
- `feat: add fuzzy matching to model selector`
|
||||
- Bad examples:
|
||||
- `Feature request: can we have...` (too vague)
|
||||
- `It would be nice if...` (not imperative)
|
||||
|
||||
4. **Build the body** with the template fields:
|
||||
|
||||
**Feature Description:**
|
||||
- Clear statement of what to add/change
|
||||
- Be specific about the behavior
|
||||
- Include UI/UX details if relevant
|
||||
|
||||
**Motivation / Use Case:**
|
||||
- What problem does this solve?
|
||||
- Current workaround (if any) and why it's insufficient
|
||||
- Who benefits from this feature?
|
||||
|
||||
**Proposed Implementation** (optional but helpful):
|
||||
- High-level approach
|
||||
- API changes if applicable
|
||||
- Example usage code
|
||||
|
||||
5. **Create the issue**:
|
||||
```bash
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
```
|
||||
|
||||
6. **Confirm success**:
|
||||
- Show the issue URL and number
|
||||
- Mention it was created with the feature_request template
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Focus on the *problem* first, then the solution
|
||||
- Include concrete examples of how the feature would be used
|
||||
- Consider edge cases and mention them
|
||||
- If proposing API changes, show before/after code
|
||||
- Check if similar features exist in related tools (mention them for reference)
|
||||
- Align with Kit's philosophy: TUI-first, extension-based, keyboard-driven
|
||||
|
||||
## Example
|
||||
|
||||
User: `/feature-request I want to be able to customize tool border colors dynamically`
|
||||
|
||||
You:
|
||||
1. Title: `feat: dynamic border colors for tool results based on status`
|
||||
2. Body:
|
||||
- **Feature Description**: Allow `ToolRenderConfig` to accept a function that determines border color based on tool result content or status, enabling dynamic visual feedback.
|
||||
- **Motivation**: When running multiple tools, it's hard to distinguish file reads (blue), shell commands (green), and errors (red) without custom colors per result.
|
||||
- **Proposed Implementation**: Add `BorderColorFunc` callback that receives `(result string, isError bool)` and returns a color string.
|
||||
|
||||
3. Execute: `gh issue create --template feature_request --title "feat: ..." --body "..."`
|
||||
4. Confirm: Created issue #43 using feature_request template
|
||||
@@ -0,0 +1,100 @@
|
||||
---
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
This repository has structured issue templates. You MUST use the appropriate template:
|
||||
|
||||
| Type | Template | Use For |
|
||||
|------|----------|---------|
|
||||
| `bug` | `bug_report` | Something is broken, not working as expected |
|
||||
| `feat` | `feature_request` | New feature, enhancement, improvement |
|
||||
| `docs` | `documentation` | Missing, incorrect, or unclear documentation |
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$+`:
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
2. **Ask clarifying questions** if critical info is missing:
|
||||
- For bugs: "What were you doing when this happened?" (reproduction steps)
|
||||
- For features: "What problem does this solve?" (motivation)
|
||||
- For docs: "Where did you look for this information?" (location)
|
||||
|
||||
3. **Craft the title** using conventional format:
|
||||
- `<type>: <short description>`
|
||||
- Lowercase, imperative mood, ≤72 chars
|
||||
- Examples:
|
||||
- `fix: ToolRenderConfig BorderColor ignored during rendering`
|
||||
- `feat: add keyboard shortcut for clearing input`
|
||||
- `docs: clarify extension widget lifecycle`
|
||||
|
||||
4. **File the issue** using the template:
|
||||
```bash
|
||||
# For bugs
|
||||
gh issue create --template bug_report --title "fix: ..." --body "..."
|
||||
|
||||
# For features
|
||||
gh issue create --template feature_request --title "feat: ..." --body "..."
|
||||
|
||||
# For documentation
|
||||
gh issue create --template documentation --title "docs: ..." --body "..."
|
||||
```
|
||||
|
||||
The template will guide the user through the required fields. You need to provide:
|
||||
- **Bug reports**: Description, reproduction steps, expected vs actual behavior
|
||||
- **Feature requests**: Description, motivation/use case, optional proposed implementation
|
||||
- **Documentation**: Description, location of docs, suggested improvement
|
||||
|
||||
5. **Confirm success** by showing:
|
||||
- The issue URL
|
||||
- The issue number
|
||||
- Which template was used
|
||||
|
||||
## Template Field Guide
|
||||
|
||||
### Bug Report (`bug_report`)
|
||||
Required fields in the body:
|
||||
- **Bug Description** - what happened vs expected
|
||||
- **Steps to Reproduce** - numbered list to recreate the bug
|
||||
- **Relevant Code** - code snippets, configuration, error messages
|
||||
- **Component** - which part of Kit (ui, extensions, session, etc.)
|
||||
- **Version** - Kit version or commit hash
|
||||
|
||||
### Feature Request (`feature_request`)
|
||||
Required fields in the body:
|
||||
- **Feature Description** - what to add/change
|
||||
- **Motivation / Use Case** - why this is needed
|
||||
- **Proposed Implementation** - how it could work (optional)
|
||||
|
||||
### Documentation (`documentation`)
|
||||
Required fields in the body:
|
||||
- **Documentation Issue** - what's wrong or missing
|
||||
- **Documentation Location** - file or URL where docs exist
|
||||
- **Suggested Improvement** - how to fix the docs
|
||||
|
||||
## Guidelines
|
||||
|
||||
- ALWAYS use `--template <name>` instead of bare `gh issue create`
|
||||
- Include file paths and line numbers when you know them
|
||||
- Use triple backticks for code blocks
|
||||
- Keep the body factual - avoid speculation unless in "Proposed Fix" section
|
||||
- If you're unsure about technical details, say so in the issue
|
||||
- For UI bugs, describe what you see vs what you expect
|
||||
- For API bugs, include the relevant struct/function names
|
||||
|
||||
## Example Usage
|
||||
|
||||
User: `/file-issue The ToolRenderConfig BorderColor field is documented but never used in rendering`
|
||||
|
||||
You:
|
||||
1. Determine this is a **bug** (documented field doesn't work)
|
||||
2. Use `--template bug_report`
|
||||
3. Gather: reproduction steps (register renderer with BorderColor), expected (custom color), actual (default color)
|
||||
4. Create issue with title `fix: ToolRenderConfig BorderColor and Background fields are ignored`
|
||||
5. Confirm: Created issue #42 using bug_report template
|
||||
@@ -0,0 +1,49 @@
|
||||
---
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
A prompt template is a `.md` file in `.kit/prompts/` (project-local) or `~/.kit/prompts/` (global).
|
||||
It becomes a `/slug` slash command in the kit input box — typed as `/filename` with optional arguments.
|
||||
|
||||
## File format
|
||||
|
||||
```
|
||||
---
|
||||
description: One-line description shown in autocomplete
|
||||
---
|
||||
|
||||
Body text of the prompt. Use $@ for all user-supplied arguments,
|
||||
$1 $2 etc. for positional arguments.
|
||||
```
|
||||
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Keep prompts action-oriented — they should tell kit *what to do*, not just *what to think about*
|
||||
- Prefer concrete steps over vague instructions
|
||||
- A prompt that does one thing well beats one that tries to cover every edge case
|
||||
- If the workflow already exists as a prompt, suggest extending it instead of duplicating
|
||||
@@ -0,0 +1,70 @@
|
||||
---
|
||||
description: Semantic version tagging workflow - analyzes commits and tags releases
|
||||
---
|
||||
|
||||
# Release Tagging Workflow
|
||||
|
||||
Tag a new version of this Go project following semantic versioning.
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Fetch remote tags**: `git fetch --tags origin`
|
||||
|
||||
2. **Find latest version**: `git tag -l | sort -V | tail -5` to see recent tags
|
||||
|
||||
3. **Analyze changes since last tag**:
|
||||
- `git log <latest-tag>..HEAD --oneline` - list commits
|
||||
- `git diff <latest-tag>..HEAD --stat` - see file stats
|
||||
- `git diff <latest-tag>..HEAD --name-only` - see changed files
|
||||
|
||||
4. **Determine version bump** (Semantic Versioning):
|
||||
- **MAJOR (X.0.0)**: Breaking API changes, incompatible modifications
|
||||
- **MINOR (0.X.0)**: New features, backward-compatible additions
|
||||
- **PATCH (0.0.X)**: Bug fixes, backward-compatible fixes
|
||||
|
||||
Look for indicators:
|
||||
- `feat:` or `feature:` commits → MINOR
|
||||
- `fix:` or `bugfix:` commits → PATCH
|
||||
- `breaking:` or `BREAKING CHANGE:` → MAJOR
|
||||
- Breaking API changes in `pkg/` or public interfaces → MAJOR
|
||||
- New commands, flags, or features → MINOR
|
||||
- Documentation-only changes → PATCH (or skip)
|
||||
|
||||
5. **Calculate new version**: Increment appropriate segment, reset lower segments to 0
|
||||
|
||||
6. **Draft tag message**:
|
||||
- Summarize key changes from commits
|
||||
- Group by type (Features, Fixes, Breaking Changes)
|
||||
- Keep concise but informative
|
||||
|
||||
7. **Create annotated tag**: `git tag -a vX.Y.Z -m "vX.Y.Z - <summary>\n\n<detailed list>"`
|
||||
|
||||
8. **Push tag**: `git push origin vX.Y.Z`
|
||||
|
||||
## Guidelines
|
||||
|
||||
- Always fetch remote tags first to avoid conflicts
|
||||
- Use annotated tags (`-a`) with descriptive messages
|
||||
- Follow semver strictly - when in doubt, prefer conservative bump (patch over minor)
|
||||
- For Go projects, changes to `pkg/` or exported APIs warrant careful version consideration
|
||||
- If no changes since last tag, suggest skipping the release
|
||||
- Include commit summaries in the tag message body
|
||||
|
||||
## Example Tag Message Format
|
||||
|
||||
```
|
||||
v0.30.1 - Bug fixes for model handling and UI improvements
|
||||
|
||||
Fixes:
|
||||
- Properly handle think tags from Qwen/DeepSeek models
|
||||
- Handle custom provider model persistence and bare model names
|
||||
|
||||
Improvements:
|
||||
- UI style refactoring and cleanup
|
||||
```
|
||||
|
||||
Wait for the user to confirm the version and message before executing tag commands.
|
||||
|
||||
---
|
||||
|
||||
$@
|
||||
@@ -1,22 +1,3 @@
|
||||
<!-- OPENSPEC:START -->
|
||||
# OpenSpec Instructions
|
||||
|
||||
These instructions are for AI assistants working in this project.
|
||||
|
||||
Always open `@/openspec/AGENTS.md` when the request:
|
||||
- Mentions planning or proposals (words like proposal, spec, change, plan)
|
||||
- Introduces new capabilities, breaking changes, architecture shifts, or big performance/security work
|
||||
- Sounds ambiguous and you need the authoritative spec before coding
|
||||
|
||||
Use `@/openspec/AGENTS.md` to learn:
|
||||
- How to create and apply change proposals
|
||||
- Spec format and conventions
|
||||
- Project structure and guidelines
|
||||
|
||||
Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
|
||||
<!-- OPENSPEC:END -->
|
||||
|
||||
# KIT Agent Guidelines
|
||||
|
||||
## Build/Test Commands
|
||||
@@ -42,6 +23,33 @@ Keep this managed block so 'openspec update' can refresh the instructions.
|
||||
- **Extension system** (`internal/extensions/`): Yaegi-interpreted Go, 13 lifecycle events, custom tools/commands/widgets/overlays/editor interceptors
|
||||
- **TUI** (`internal/ui/`): Bubble Tea v2 parent-child model (`AppModel` → `InputComponent`, `StreamComponent`, etc.)
|
||||
- **Decoupling pattern**: `cmd/root.go` has converter functions (e.g. `widgetProviderForUI()`) that bridge `internal/extensions/` types to `internal/ui/` types — the UI never imports extensions directly
|
||||
- **Public SDK** (`pkg/kit/`): The public-facing Go SDK for embedding Kit as a library. See rules below.
|
||||
|
||||
## Public SDK (`pkg/kit/`) Rules
|
||||
|
||||
`pkg/kit/` is the **public API surface** consumed by external Go developers. All exported symbols, types, function names, and godoc comments in this package are part of the SDK contract.
|
||||
|
||||
### No Dependency Name Leakage
|
||||
Internal dependency names (e.g. `charm.land/fantasy`, library-specific jargon) **must not** appear in:
|
||||
- **Exported function/method names** — use generic terms (`LLM`, `Provider`, `Message`) instead of library names
|
||||
- **Exported type names** — type aliases should use domain names (e.g. `LLMMessage`, not `FantasyMessage`)
|
||||
- **Godoc comments** on exported symbols — these are visible in `go doc` output and pkg.go.dev
|
||||
- **Struct field names and tags** on exported types
|
||||
|
||||
Using dependency types directly in **function bodies** (private implementation) is fine — that's invisible to SDK consumers.
|
||||
|
||||
### Naming Conventions for SDK Symbols
|
||||
- Type aliases re-exporting dependency types: use `LLM*` prefix (e.g. `LLMMessage`, `LLMUsage`, `LLMResponse`)
|
||||
- Conversion helpers: use `ConvertToLLM*` / `ConvertFromLLM*` (not the dependency name)
|
||||
- Provider queries: use `GetLLMProviders` (not `GetFantasyProviders`)
|
||||
- When wrapping internal methods, the `pkg/kit/` name should be dependency-agnostic even if the `internal/` method still uses the old name
|
||||
|
||||
### Deprecation Pattern
|
||||
When renaming a public SDK symbol, keep the old name as a deprecated wrapper for one release cycle:
|
||||
```go
|
||||
// Deprecated: Use NewName instead.
|
||||
func OldName() { return NewName() }
|
||||
```
|
||||
|
||||
## Key Patterns
|
||||
|
||||
@@ -92,3 +100,21 @@ Positional args are the prompt. `@file` args attach file content. Key flags: `--
|
||||
- Never guess or manually search the filesystem for external projects
|
||||
- Example: `btca ask -r https://github.com/user/repo -q "How does X work?"`
|
||||
- See `.agents/skills/btca-cli/SKILL.md` for full btca usage
|
||||
|
||||
## BTCA Configured Resources
|
||||
The following external repositories are configured in `btca.config.jsonc` for research:
|
||||
|
||||
- bubbletea
|
||||
- lipgloss
|
||||
- bubbles
|
||||
- glamour
|
||||
- fantasy
|
||||
- catwalk
|
||||
- crush
|
||||
- pi
|
||||
- iteratr
|
||||
- yaegi
|
||||
- acp-go-sdk
|
||||
- opencode
|
||||
- herald
|
||||
- herald-md
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
# Autoscroll Fix - Final Summary
|
||||
|
||||
## Root Cause
|
||||
|
||||
The autoscroll was failing for streaming assistant messages due to a bug in how `GotoBottom()` calculated item heights.
|
||||
|
||||
### The Problem
|
||||
|
||||
1. **Reasoning blocks** (`StreamingMessageItem` with `role="reasoning"`) are **never cached** because they have live duration counters that update every render
|
||||
2. The `Height()` method returns `0` when `cachedRender == ""`
|
||||
3. `GotoBottom()` was calling:
|
||||
```go
|
||||
itemHeight := item.Height() // Returns 0 for reasoning
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width) // Renders but doesn't cache (reasoning)
|
||||
itemHeight = item.Height() // Still returns 0!
|
||||
}
|
||||
```
|
||||
4. This caused incorrect scroll position calculations, especially during reasoning → assistant transitions
|
||||
|
||||
## The Solution
|
||||
|
||||
Changed `GotoBottom()` and `AtBottom()` to calculate height **directly from the rendered string** instead of relying on the cached height:
|
||||
|
||||
```go
|
||||
// OLD: item.Height() which checks cached render
|
||||
itemHeight := item.Height()
|
||||
if itemHeight == 0 {
|
||||
item.Render(s.width)
|
||||
itemHeight = item.Height() // Still might be 0!
|
||||
}
|
||||
|
||||
// NEW: Calculate from rendered string directly
|
||||
rendered := item.Render(s.width)
|
||||
itemHeight := strings.Count(rendered, "\n") + 1
|
||||
```
|
||||
|
||||
This works for **all** items regardless of whether they cache their render or not.
|
||||
|
||||
## Files Changed
|
||||
|
||||
### `internal/ui/scrolllist.go`
|
||||
- **`GotoBottom()`**: Calculate height from rendered string (2 loops)
|
||||
- **`AtBottom()`**: Calculate height from rendered string (1 loop)
|
||||
|
||||
### `internal/ui/model.go`
|
||||
- **`appendStreamingChunk()`**: For existing messages, call `GotoBottom()` directly (iteratr pattern)
|
||||
- **`refreshContent()`**: Simplified to only call `SetItems()` (removed redundant `GotoBottom()`)
|
||||
- **Bash streaming handler**: Removed redundant `GotoBottom()` after `refreshContent()`
|
||||
|
||||
## Testing Results
|
||||
|
||||
✅ **Test prompt**: "explore this repo"
|
||||
|
||||
**Before fix**:
|
||||
- Autoscroll stopped after reasoning block completed
|
||||
- Viewport stuck showing end of reasoning ("Thought for 203ms")
|
||||
- Assistant response streamed off-screen below
|
||||
|
||||
**After fix**:
|
||||
- Autoscroll works throughout reasoning block
|
||||
- Autoscroll continues during reasoning → assistant transition
|
||||
- Viewport stays at bottom showing latest assistant content
|
||||
- Final position shows end of response (build commands section)
|
||||
|
||||
## Behavior Verified
|
||||
|
||||
1. ✅ Streaming text auto-scrolls to bottom
|
||||
2. ✅ Works across reasoning → assistant transition
|
||||
3. ✅ Manual scroll up (PgUp) disables autoscroll
|
||||
4. ✅ Scroll to bottom (Alt+End) re-enables autoscroll
|
||||
5. ✅ Accurate positioning with no offset errors
|
||||
|
||||
## Performance Note
|
||||
|
||||
The fix calls `Render()` on all items during `GotoBottom()` calculations. This is acceptable because:
|
||||
- `Render()` is already optimized with caching for non-reasoning items
|
||||
- `GotoBottom()` is only called during content updates (not every frame)
|
||||
- Reasoning blocks need to render anyway for live duration updates
|
||||
- This matches iteratr's approach of ensuring items are rendered before height calculations
|
||||
@@ -18,7 +18,8 @@ A powerful, extensible AI coding agent CLI with multi-provider support, built-in
|
||||
## Features
|
||||
|
||||
- **Multi-Provider LLM Support**: Anthropic, OpenAI, Google Gemini, Ollama, Azure OpenAI, AWS Bedrock, OpenRouter, and more
|
||||
- **Built-in Core Tools**: bash, read, write, edit, grep, find, ls, spawn_subagent - no MCP overhead
|
||||
- **Built-in Core Tools**: bash (with interactive sudo password prompt), read, write, edit, grep, find, ls, subagent - no MCP overhead
|
||||
- **Smart @ Attachments**: Binary files auto-detected via MIME type, MCP resources via `@mcp:server:uri`
|
||||
- **MCP Integration**: Connect external MCP servers for expanded capabilities
|
||||
- **Extension System**: Write custom tools, commands, widgets, and UI modifications in Go
|
||||
- **Theming**: 22 built-in color themes (KITT, Catppuccin, Dracula, Nord, etc.) with runtime switching, persistence, and custom theme files
|
||||
@@ -191,6 +192,8 @@ mcpServers:
|
||||
--top-p Nucleus sampling 0.0-1.0 (default: 0.95)
|
||||
--top-k Limit top K tokens (default: 40)
|
||||
--stop-sequences Custom stop sequences (comma-separated)
|
||||
--frequency-penalty Penalize frequent tokens 0.0-2.0 (default: 0.0)
|
||||
--presence-penalty Penalize present tokens 0.0-2.0 (default: 0.0)
|
||||
--thinking-level Extended thinking level: off, minimal, low, medium, high (default: off)
|
||||
|
||||
# System
|
||||
@@ -209,7 +212,7 @@ kit auth status # Check authentication status
|
||||
|
||||
# Model database
|
||||
kit models [provider] # List available models (optionally filter by provider)
|
||||
kit models --all # Show all providers (not just Fantasy-compatible)
|
||||
kit models --all # Show all providers (not just LLM-compatible)
|
||||
kit update-models [source] # Update model database (from models.dev, URL, file, or 'embedded')
|
||||
|
||||
# Extension management
|
||||
@@ -307,41 +310,50 @@ kit -e examples/extensions/minimal.go
|
||||
- **Themes**: Register and switch color themes via `RegisterTheme`, `SetTheme`, `ListThemes`
|
||||
- **Custom Events**: Inter-extension communication via `EmitCustomEvent`
|
||||
|
||||
**Bridged SDK APIs** (NEW): Extensions can now access internal SDK capabilities:
|
||||
- **Tree Navigation**: Navigate conversation history (`GetTreeNode`, `GetCurrentBranch`, `NavigateTo`), summarize branches (`SummarizeBranch`), and implement fresh context loops (`CollapseBranch`)
|
||||
- **Skill Loading**: Dynamically load and inject skills at runtime (`LoadSkill`, `DiscoverSkills`, `InjectSkillAsContext`)
|
||||
- **Template Parsing**: Parse and render templates with `{{variables}}` (`ParseTemplate`, `RenderTemplate`), parse CLI-style arguments (`ParseArguments`, `SimpleParseArguments`), and evaluate model conditionals (`EvaluateModelConditional`, `RenderWithModelConditionals`)
|
||||
- **Model Resolution**: Resolve model fallback chains (`ResolveModelChain`), query model capabilities (`GetModelCapabilities`, `CheckModelAvailable`), and extract provider/model ID (`GetCurrentProvider`, `GetCurrentModelID`)
|
||||
|
||||
### Extension Examples
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
|
||||
- `minimal.go` - Clean UI with custom footer
|
||||
- `auto-commit.go` - Auto-commit on shutdown
|
||||
- `bookmark.go` - Bookmark conversations
|
||||
- `branded-output.go` - Branded output rendering
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
- `inline-bash.go` - Inline bash execution
|
||||
- `interactive-shell.go` - Interactive shell integration
|
||||
- `kit-kit.go` - Kit-in-Kit (sub-agent spawning)
|
||||
- `lsp-diagnostics.go` - LSP diagnostic integration
|
||||
- `notify.go` - Desktop notifications
|
||||
- `overlay-demo.go` - Modal dialogs
|
||||
- `permission-gate.go` - Permission gating for tools
|
||||
- `pirate.go` - Pirate-themed personality
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
- `tool-renderer-demo.go` - Custom tool call rendering
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
- [`minimal.go`](examples/extensions/minimal.go) - Clean UI with custom footer
|
||||
- [`auto-commit.go`](examples/extensions/auto-commit.go) - Auto-commit on shutdown
|
||||
- [`bookmark.go`](examples/extensions/bookmark.go) - Bookmark conversations
|
||||
- [`branded-output.go`](examples/extensions/branded-output.go) - Branded output rendering
|
||||
- [`bridge-demo.go`](examples/extensions/bridge_demo.go) - Bridged SDK API demo (tree navigation, skills, templates, model resolution)
|
||||
- [`compact-notify.go`](examples/extensions/compact-notify.go) - Notification on compaction
|
||||
- [`confirm-destructive.go`](examples/extensions/confirm-destructive.go) - Confirm destructive operations
|
||||
- [`context-inject.go`](examples/extensions/context-inject.go) - Inject context into conversations
|
||||
- [`conversation-manager.go`](examples/extensions/conversation-manager.go) - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- [`custom-editor-demo.go`](examples/extensions/custom-editor-demo.go) - Vim-like modal editor
|
||||
- [`dev-reload.go`](examples/extensions/dev-reload.go) - Development live-reload
|
||||
- [`header-footer-demo.go`](examples/extensions/header-footer-demo.go) - Custom headers and footers
|
||||
- [`inline-bash.go`](examples/extensions/inline-bash.go) - Inline bash execution
|
||||
- [`interactive-shell.go`](examples/extensions/interactive-shell.go) - Interactive shell integration
|
||||
- [`kit-kit.go`](examples/extensions/kit-kit.go) - Kit-in-Kit (sub-agent spawning)
|
||||
- [`lsp-diagnostics.go`](examples/extensions/lsp-diagnostics.go) - LSP diagnostic integration
|
||||
- [`notify.go`](examples/extensions/notify.go) - Desktop notifications
|
||||
- [`overlay-demo.go`](examples/extensions/overlay-demo.go) - Modal dialogs
|
||||
- [`permission-gate.go`](examples/extensions/permission-gate.go) - Permission gating for tools
|
||||
- [`pirate.go`](examples/extensions/pirate.go) - Pirate-themed personality
|
||||
- [`plan-mode.go`](examples/extensions/plan-mode.go) - Read-only planning mode
|
||||
- [`project-rules.go`](examples/extensions/project-rules.go) - Project-specific rules
|
||||
- [`prompt-demo.go`](examples/extensions/prompt-demo.go) - Interactive prompts (select/confirm/input)
|
||||
- [`prompt-templates.go`](examples/extensions/prompt-templates.go) - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- [`protected-paths.go`](examples/extensions/protected-paths.go) - Path protection for sensitive files
|
||||
- [`subagent-widget.go`](examples/extensions/subagent-widget.go) - Multi-agent orchestration with status widget
|
||||
- [`subagent-test.go`](examples/extensions/subagent-test.go) - Subagent testing utilities
|
||||
- [`summarize.go`](examples/extensions/summarize.go) - Conversation summarization
|
||||
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
|
||||
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
|
||||
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
|
||||
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
|
||||
|
||||
Also see `.kit/extensions/go-edit-lint.go` (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
@@ -398,7 +410,7 @@ func TestMyExtension(t *testing.T) {
|
||||
- `AssertPrinted()`, `AssertPrintedContains()` — Verify output
|
||||
- `AssertToolRegistered()`, `AssertCommandRegistered()` — Verify registration
|
||||
|
||||
See `examples/extensions/tool-logger_test.go` for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
See [`examples/extensions/tool-logger_test.go`](examples/extensions/tool-logger_test.go) for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
|
||||
### Prompt Templates
|
||||
|
||||
@@ -420,10 +432,13 @@ Focus on $1 specifically.
|
||||
|
||||
**Argument placeholders:**
|
||||
- `$1`, `$2`, etc. — Individual arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments
|
||||
- `$@` or `$ARGUMENTS` — All arguments (zero or more)
|
||||
- `$+` — All arguments (one or more required; error if none given)
|
||||
- `${@:2}` — Arguments from position 2 onwards
|
||||
- `${@:1:3}` — 3 arguments starting at position 1
|
||||
|
||||
Placeholders inside fenced code blocks (```) and inline code spans are ignored.
|
||||
|
||||
Disable templates with `--no-prompt-templates` or load a specific template with `--prompt-template <name>`.
|
||||
|
||||
## Session Management
|
||||
@@ -469,9 +484,18 @@ During an interactive session, use these slash commands:
|
||||
| `/import <path>` | Import and switch to a session from a JSONL file |
|
||||
| `/share` | Upload session to GitHub Gist and get a shareable viewer URL |
|
||||
| `/tree` | Navigate the session tree |
|
||||
| `/fork` | Branch from an earlier message |
|
||||
| `/fork` | Fork to new session from an earlier message |
|
||||
| `/new` | Start a fresh session |
|
||||
|
||||
### Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Description |
|
||||
|----------|-------------|
|
||||
| `Ctrl+X e` | Open `$VISUAL`/`$EDITOR` to compose or edit your prompt |
|
||||
| `Ctrl+X s` | Steer — inject a system-level instruction mid-turn |
|
||||
| `ESC ESC` | Cancel the current operation (tool call or streaming) |
|
||||
| `↑` / `↓` | Navigate prompt history |
|
||||
|
||||
## Go SDK
|
||||
|
||||
Embed Kit in your Go applications:
|
||||
@@ -494,7 +518,7 @@ func main() {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer host.Close()
|
||||
defer func() { _ = host.Close() }()
|
||||
|
||||
// Send a prompt
|
||||
response, err := host.Prompt(ctx, "What is 2+2?")
|
||||
@@ -523,7 +547,12 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
NoSession: true, // Ephemeral mode
|
||||
|
||||
// Tool options
|
||||
ExtraTools: []kit.Tool{...}, // Additional tools alongside defaults
|
||||
Tools: []kit.Tool{...}, // Replace default tool set entirely
|
||||
ExtraTools: []kit.Tool{...}, // Add tools alongside defaults
|
||||
DisableCoreTools: true, // Use no core tools (0 tools, for chat-only)
|
||||
|
||||
// Configuration
|
||||
SkipConfig: true, // Skip .kit.yml files (viper defaults + env vars still apply)
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // Auto-compact near context limit
|
||||
@@ -532,26 +561,51 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
})
|
||||
```
|
||||
|
||||
### Custom Tools
|
||||
|
||||
Create custom tools with automatic schema generation — no external dependencies needed:
|
||||
|
||||
```go
|
||||
type SearchInput struct {
|
||||
Query string `json:"query" description:"Search query"`
|
||||
}
|
||||
|
||||
searchTool := kit.NewTool("search", "Search the codebase",
|
||||
func(ctx context.Context, input SearchInput) (kit.ToolOutput, error) {
|
||||
return kit.TextResult("Found: ..."), nil
|
||||
},
|
||||
)
|
||||
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
ExtraTools: []kit.Tool{searchTool}, // adds alongside built-in tools
|
||||
})
|
||||
```
|
||||
|
||||
Use `kit.NewParallelTool` for tools safe to run concurrently. See the [SDK docs](/sdk/overview) for full details on struct tags, `ToolOutput` fields, and `ToolCallIDFromContext`.
|
||||
|
||||
### With Callbacks
|
||||
|
||||
```go
|
||||
response, err := host.PromptWithCallbacks(
|
||||
unsub := host.OnToolCall(func(e kit.ToolCallEvent) {
|
||||
println("Calling tool:", e.ToolName)
|
||||
})
|
||||
defer unsub()
|
||||
|
||||
unsub2 := host.OnToolResult(func(e kit.ToolResultEvent) {
|
||||
if e.IsError {
|
||||
println("Tool failed:", e.ToolName)
|
||||
}
|
||||
})
|
||||
defer unsub2()
|
||||
|
||||
unsub3 := host.OnStreaming(func(e kit.MessageUpdateEvent) {
|
||||
print(e.Chunk)
|
||||
})
|
||||
defer unsub3()
|
||||
|
||||
response, err := host.Prompt(
|
||||
ctx,
|
||||
"List files in current directory",
|
||||
func(name, args string) {
|
||||
// Tool call started
|
||||
println("Calling tool:", name)
|
||||
},
|
||||
func(name, args, result string, isError bool) {
|
||||
// Tool call completed
|
||||
if isError {
|
||||
println("Tool failed:", name)
|
||||
}
|
||||
},
|
||||
func(chunk string) {
|
||||
// Streaming text chunk
|
||||
print(chunk)
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
@@ -715,7 +769,7 @@ Use `custom/custom` when pointing Kit at any OpenAI-compatible endpoint with `--
|
||||
kit --provider-url "http://localhost:8080/v1" "Hello"
|
||||
```
|
||||
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through fantasy's `openaicompat` provider and supports:
|
||||
This automatically defaults to `custom/custom` without needing to specify a model. The custom provider routes through the `openaicompat` provider and supports:
|
||||
|
||||
- Zero cost tracking (input/output = 0)
|
||||
- 262K context window, 65K output limit
|
||||
|
||||
@@ -82,6 +82,12 @@
|
||||
"name": "herald",
|
||||
"url": "https://github.com/indaco/herald",
|
||||
"branch": "main"
|
||||
},
|
||||
{
|
||||
"type": "git",
|
||||
"name": "herald-md",
|
||||
"url": "https://github.com/indaco/herald-md",
|
||||
"branch": "main"
|
||||
}
|
||||
],
|
||||
"model": "claude-haiku-4-5",
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
acp "github.com/coder/acp-go-sdk"
|
||||
|
||||
"github.com/mark3labs/kit/internal/acpserver"
|
||||
@@ -54,6 +55,8 @@ func runACP(cmd *cobra.Command, _ []string) error {
|
||||
conn.SetLogger(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: slog.LevelDebug,
|
||||
})))
|
||||
// Also set charmbracelet/log level for acpserver package logging
|
||||
log.SetLevel(log.DebugLevel)
|
||||
}
|
||||
|
||||
// Wait for either the client to disconnect or a signal.
|
||||
|
||||
+1
-1
@@ -55,7 +55,7 @@ func printAllProviders(showAll bool) error {
|
||||
if showAll {
|
||||
providerIDs = kit.GetSupportedProviders()
|
||||
} else {
|
||||
providerIDs = kit.GetFantasyProviders()
|
||||
providerIDs = kit.GetLLMProviders()
|
||||
}
|
||||
sort.Strings(providerIDs)
|
||||
|
||||
|
||||
+821
-234
File diff suppressed because it is too large
Load Diff
@@ -41,7 +41,6 @@ func BuildAppOptions(mcpConfig *config.Config, modelName string, serverNames, to
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
Quiet: quietFlag,
|
||||
Debug: viper.GetBool("debug"),
|
||||
CompactMode: viper.GetBool("compact"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -131,7 +130,6 @@ func SetupCLIForNonInteractive(k *kit.Kit) (*ui.CLI, error) {
|
||||
Agent: agentAdapter,
|
||||
ModelString: viper.GetString("model"),
|
||||
Debug: viper.GetBool("debug"),
|
||||
Compact: viper.GetBool("compact"),
|
||||
Quiet: quietFlag,
|
||||
ShowDebug: false,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// TestAllExtensions_Load is a smoke test that verifies every single-file
|
||||
// example extension in this directory can be loaded by the Yaegi interpreter
|
||||
// without errors. This catches syntax errors, missing symbols, bad imports,
|
||||
// and Init signature mismatches.
|
||||
func TestAllExtensions_Load(t *testing.T) {
|
||||
files := extensionFiles(t)
|
||||
|
||||
for _, file := range files {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
ext := harness.LoadFile(file)
|
||||
if ext == nil {
|
||||
t.Fatalf("%s: extension should not be nil after loading", file)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Logf("successfully loaded %d extensions", len(files))
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/extensions"
|
||||
"github.com/mark3labs/kit/pkg/extensions/test"
|
||||
)
|
||||
|
||||
// extensionFiles returns all single-file extensions in the current directory.
|
||||
// It skips test files, the test template, and files without an Init function.
|
||||
func extensionFiles(t *testing.T) []string {
|
||||
t.Helper()
|
||||
|
||||
skip := map[string]bool{
|
||||
"extension_test_template.go": true,
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(".")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read directory: %v", err)
|
||||
}
|
||||
|
||||
var files []string
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() || filepath.Ext(name) != ".go" {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(name, "_test.go") || skip[name] {
|
||||
continue
|
||||
}
|
||||
src, err := os.ReadFile(name)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read %s: %v", name, err)
|
||||
}
|
||||
if !strings.Contains(string(src), "func Init(") {
|
||||
continue
|
||||
}
|
||||
files = append(files, name)
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
t.Fatal("no extensions found — check the directory")
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
// TestAllExtensions_Lifecycle verifies that every extension survives a full
|
||||
// SessionStart → SessionShutdown round-trip without errors.
|
||||
func TestAllExtensions_Lifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "smoke-test-session",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart error: %v", err)
|
||||
}
|
||||
|
||||
_, err = harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_CommandSanity checks that every registered command has
|
||||
// a non-empty name, a non-empty description, no spaces in the name, no
|
||||
// leading slash, a non-nil Execute function, and no duplicate names.
|
||||
func TestAllExtensions_CommandSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
cmds := harness.RegisteredCommands()
|
||||
seen := make(map[string]bool)
|
||||
for _, cmd := range cmds {
|
||||
if cmd.Name == "" {
|
||||
t.Error("command has empty name")
|
||||
}
|
||||
if strings.Contains(cmd.Name, " ") {
|
||||
t.Errorf("command %q contains spaces", cmd.Name)
|
||||
}
|
||||
if strings.HasPrefix(cmd.Name, "/") {
|
||||
t.Errorf("command %q has leading slash (framework adds it)", cmd.Name)
|
||||
}
|
||||
if cmd.Description == "" {
|
||||
t.Errorf("command %q has empty description", cmd.Name)
|
||||
}
|
||||
if cmd.Execute == nil {
|
||||
t.Errorf("command %q has nil Execute function", cmd.Name)
|
||||
}
|
||||
if seen[cmd.Name] {
|
||||
t.Errorf("duplicate command name %q", cmd.Name)
|
||||
}
|
||||
seen[cmd.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ToolSanity checks that every registered tool has a
|
||||
// non-empty name, a non-empty description, at least one executor, valid
|
||||
// JSON in its Parameters field, and no duplicate names.
|
||||
func TestAllExtensions_ToolSanity(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
tools := harness.RegisteredTools()
|
||||
seen := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
if tool.Name == "" {
|
||||
t.Error("tool has empty name")
|
||||
}
|
||||
if tool.Description == "" {
|
||||
t.Errorf("tool %q has empty description", tool.Name)
|
||||
}
|
||||
if tool.Execute == nil && tool.ExecuteWithContext == nil {
|
||||
t.Errorf("tool %q has no executor (both Execute and ExecuteWithContext are nil)", tool.Name)
|
||||
}
|
||||
if tool.Parameters != "" && !json.Valid([]byte(tool.Parameters)) {
|
||||
t.Errorf("tool %q has invalid JSON in Parameters: %s", tool.Name, tool.Parameters)
|
||||
}
|
||||
if seen[tool.Name] {
|
||||
t.Errorf("duplicate tool name %q", tool.Name)
|
||||
}
|
||||
seen[tool.Name] = true
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_ZeroValueEvents fires every event type (as zero-value
|
||||
// structs) at each extension and verifies no errors are returned. Extensions
|
||||
// should be resilient to events they don't handle and to events with empty
|
||||
// fields.
|
||||
func TestAllExtensions_ZeroValueEvents(t *testing.T) {
|
||||
// Build the set of zero-value events for every event type.
|
||||
zeroEvents := []extensions.Event{
|
||||
extensions.ToolCallEvent{},
|
||||
extensions.ToolExecutionStartEvent{},
|
||||
extensions.ToolExecutionEndEvent{},
|
||||
extensions.ToolOutputEvent{},
|
||||
extensions.ToolResultEvent{},
|
||||
extensions.InputEvent{},
|
||||
extensions.BeforeAgentStartEvent{},
|
||||
extensions.AgentStartEvent{},
|
||||
extensions.AgentEndEvent{},
|
||||
extensions.MessageStartEvent{},
|
||||
extensions.MessageUpdateEvent{},
|
||||
extensions.MessageEndEvent{},
|
||||
extensions.SessionStartEvent{},
|
||||
extensions.SessionShutdownEvent{},
|
||||
extensions.ModelChangeEvent{},
|
||||
extensions.ContextPrepareEvent{},
|
||||
extensions.BeforeForkEvent{},
|
||||
extensions.BeforeSessionSwitchEvent{},
|
||||
extensions.BeforeCompactEvent{},
|
||||
extensions.SubagentStartEvent{},
|
||||
extensions.SubagentChunkEvent{},
|
||||
extensions.SubagentEndEvent{},
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for _, ev := range zeroEvents {
|
||||
_, err := harness.Emit(ev)
|
||||
if err != nil {
|
||||
t.Errorf("event %T returned error: %v", ev, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_WidgetSanity emits SessionStart and then checks that
|
||||
// any widgets set during initialization have non-empty IDs and valid
|
||||
// placements.
|
||||
func TestAllExtensions_WidgetSanity(t *testing.T) {
|
||||
validPlacements := map[extensions.WidgetPlacement]bool{
|
||||
"above": true,
|
||||
"below": true,
|
||||
}
|
||||
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
// Trigger SessionStart so extensions that set widgets on init do so.
|
||||
_, _ = harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "widget-sanity-test",
|
||||
})
|
||||
|
||||
// Widgets is an exported field on MockContext; reads are safe
|
||||
// here because Emit returned synchronously.
|
||||
for id, w := range harness.Context().Widgets {
|
||||
if w.ID == "" {
|
||||
t.Errorf("widget stored with key %q has empty ID", id)
|
||||
}
|
||||
if w.ID != id {
|
||||
t.Errorf("widget key %q doesn't match widget ID %q", id, w.ID)
|
||||
}
|
||||
if !validPlacements[w.Placement] {
|
||||
t.Errorf("widget %q has invalid placement %q (want \"above\" or \"below\")", id, w.Placement)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAllExtensions_IdempotentLifecycle verifies that receiving SessionStart
|
||||
// twice and SessionShutdown twice doesn't cause errors — extensions should
|
||||
// be defensive about repeated lifecycle events.
|
||||
func TestAllExtensions_IdempotentLifecycle(t *testing.T) {
|
||||
for _, file := range extensionFiles(t) {
|
||||
t.Run(file, func(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile(file)
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{
|
||||
SessionID: "idempotent-test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range 2 {
|
||||
_, err := harness.Emit(extensions.SessionShutdownEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionShutdown #%d error: %v", i+1, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build ignore
|
||||
|
||||
// bridge_demo.go - Demonstrates the new bridged SDK APIs for extensions.
|
||||
// This extension showcases tree navigation, skill loading, template parsing,
|
||||
// and model resolution capabilities.
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
discoveredSkills []ext.Skill
|
||||
currentBranch []ext.TreeNode
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Register /tree-info command to demonstrate tree navigation
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree-info",
|
||||
Description: "Show current conversation tree information",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
info := fmt.Sprintf("Current branch has %d nodes:\n", len(branch))
|
||||
for i, node := range branch {
|
||||
info += fmt.Sprintf(" [%d] %s (%s): %s...\n", i, node.Type, node.ID[:8], truncate(node.Content, 40))
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /discover-skills command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "discover-skills",
|
||||
Description: "Discover and list available skills",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error != "" {
|
||||
return "", fmt.Errorf("discovery failed: %s", result.Error)
|
||||
}
|
||||
discoveredSkills = result.Skills
|
||||
|
||||
info := fmt.Sprintf("Discovered %d skills:\n", len(result.Skills))
|
||||
for _, s := range result.Skills {
|
||||
info += fmt.Sprintf(" - %s: %s\n", s.Name, s.Description)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /parse-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "parse-template",
|
||||
Description: "Parse a template and show extracted variables",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "Hello {{name}}, welcome to {{place}}!"
|
||||
}
|
||||
tpl := ctx.ParseTemplate("demo", args)
|
||||
info := fmt.Sprintf("Template: %s\nVariables: %v", tpl.Content, tpl.Variables)
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /render-template command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "render-template",
|
||||
Description: "Render a template with variables (usage: /render-template name=John place=Kit)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
tpl := ctx.ParseTemplate("demo", "Hello {{name}}, welcome to {{place}}!")
|
||||
vars := ctx.ParseArguments(args, ext.ArgumentPattern{
|
||||
Flags: map[string]string{"name": "name", "place": "place"},
|
||||
})
|
||||
rendered := ctx.RenderTemplate(tpl, vars.Vars)
|
||||
ctx.PrintInfo("Rendered: " + rendered)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /check-model command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "check-model",
|
||||
Description: "Check model capabilities and availability",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
model := args
|
||||
if model == "" {
|
||||
model = ctx.Model
|
||||
}
|
||||
|
||||
available := ctx.CheckModelAvailable(model)
|
||||
caps, err := ctx.GetModelCapabilities(model)
|
||||
|
||||
info := fmt.Sprintf("Model: %s\n", model)
|
||||
info += fmt.Sprintf("Available: %v\n", available)
|
||||
if err == "" {
|
||||
info += fmt.Sprintf("Provider: %s\n", caps.Provider)
|
||||
info += fmt.Sprintf("Context Limit: %d\n", caps.ContextLimit)
|
||||
info += fmt.Sprintf("Reasoning: %v\n", caps.Reasoning)
|
||||
} else {
|
||||
info += fmt.Sprintf("Error: %s\n", err)
|
||||
}
|
||||
ctx.PrintInfo(info)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /resolve-chain command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "resolve-chain",
|
||||
Description: "Resolve a model chain (usage: /resolve-chain claude-opus,gpt-4o,claude-sonnet)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
args = "anthropic/claude-opus-4,anthropic/claude-sonnet-4,openai/gpt-4o"
|
||||
}
|
||||
prefs := ctx.SimpleParseArguments(args, 1)
|
||||
chain := []string{}
|
||||
if len(prefs) > 1 {
|
||||
// Split the first arg by comma
|
||||
for _, p := range strings.Split(prefs[1], ",") {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
chain = append(chain, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(chain)
|
||||
info, _ := json.MarshalIndent(result, "", " ")
|
||||
ctx.PrintInfo("Resolution Result:\n" + string(info))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Register /test-conditional command
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "test-conditional",
|
||||
Description: "Test model conditional rendering",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
content := `<if-model is="claude-*">This is for Claude models<else>This is for other models</if-model>`
|
||||
rendered := ctx.RenderWithModelConditionals(content)
|
||||
ctx.PrintInfo("Input: " + content)
|
||||
ctx.PrintInfo("Output: " + rendered)
|
||||
ctx.PrintInfo(fmt.Sprintf("Current model matches 'claude-*': %v", ctx.EvaluateModelConditional("claude-*")))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnSessionStart: discover skills automatically
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
result := ctx.DiscoverSkills()
|
||||
if result.Error == "" && len(result.Skills) > 0 {
|
||||
discoveredSkills = result.Skills
|
||||
ctx.SetStatus("bridge-demo", fmt.Sprintf("%d skills", len(result.Skills)), 50)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
//go:build ignore
|
||||
|
||||
// conversation-manager.go - Advanced conversation tree navigation and management.
|
||||
// This extension demonstrates:
|
||||
// - Tree navigation (GetTreeNode, GetCurrentBranch, NavigateTo)
|
||||
// - Branch summarization and collapsing
|
||||
// - Interactive tree exploration
|
||||
//
|
||||
// Commands:
|
||||
// /tree - Show conversation tree structure
|
||||
// /branch - Show current branch path
|
||||
// /goto <entry-id> - Navigate to a specific entry
|
||||
// /summarize <n> - Summarize last N messages
|
||||
// /fresh-context - Collapse branch and start fresh
|
||||
// /loop <n> <prompt> - Execute prompt N times with fresh context each iteration
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
loopActive bool
|
||||
loopCount int
|
||||
loopCurrent int
|
||||
loopPrompt string
|
||||
loopStartNode string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// /tree - Show tree structure
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "tree",
|
||||
Description: "Show conversation tree structure",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showTree(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /branch - Show current branch
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "branch",
|
||||
Description: "Show current conversation branch",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
showBranch(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /goto - Navigate to entry
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "goto",
|
||||
Description: "Navigate to a specific entry ID (usage: /goto <entry-id>)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if args == "" {
|
||||
ctx.PrintError("Usage: /goto <entry-id>")
|
||||
return "", nil
|
||||
}
|
||||
result := ctx.NavigateTo(args)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Navigation failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
ctx.PrintInfo(fmt.Sprintf("Navigated to entry: %s", args))
|
||||
|
||||
// Show the node we navigated to
|
||||
node := ctx.GetTreeNode(args)
|
||||
if node != nil {
|
||||
ctx.PrintInfo(fmt.Sprintf("Entry type: %s, Role: %s", node.Type, node.Role))
|
||||
}
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /summarize - Summarize recent messages
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "summarize",
|
||||
Description: "Summarize last N messages (usage: /summarize [n=5])",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
n := 5
|
||||
if args != "" {
|
||||
if parsed, err := strconv.Atoi(args); err == nil && parsed > 0 {
|
||||
n = parsed
|
||||
}
|
||||
}
|
||||
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 2 {
|
||||
ctx.PrintError("Not enough messages to summarize")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Find range to summarize
|
||||
startIdx := len(branch) - n - 1
|
||||
if startIdx < 0 {
|
||||
startIdx = 0
|
||||
}
|
||||
endIdx := len(branch) - 1
|
||||
|
||||
fromID := branch[startIdx].ID
|
||||
toID := branch[endIdx].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Summarizing messages %d to %d...", startIdx, endIdx))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Summary",
|
||||
})
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /fresh-context - Collapse and restart
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "fresh-context",
|
||||
Description: "Collapse conversation to summary and start fresh",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) < 3 {
|
||||
ctx.PrintError("Not enough context to collapse")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Keep first message (system), summarize rest
|
||||
fromID := branch[1].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo("Generating summary for context collapse...")
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
|
||||
if summary == "" {
|
||||
ctx.PrintError("Failed to generate summary")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Collapse the branch
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if !result.Success {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
return "", nil
|
||||
}
|
||||
|
||||
ctx.PrintInfo("Context collapsed. Starting fresh with summary.")
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Collapsed Context",
|
||||
})
|
||||
|
||||
// Set a widget showing we're in fresh mode
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "fresh-context",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: "🌱 Fresh Context Mode - Previous conversation collapsed"},
|
||||
Style: ext.WidgetStyle{BorderColor: "#a6e3a1"},
|
||||
})
|
||||
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// /loop - Execute with fresh context each iteration
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "loop",
|
||||
Description: "Execute prompt N times with fresh context (usage: /loop 5 analyze this code)",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
if loopActive {
|
||||
ctx.PrintError("Loop already in progress. Wait for completion.")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
parts := strings.SplitN(args, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
ctx.PrintError("Usage: /loop <count> <prompt>")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(parts[0])
|
||||
if err != nil || count <= 0 || count > 10 {
|
||||
ctx.PrintError("Invalid count (must be 1-10)")
|
||||
return "", nil
|
||||
}
|
||||
|
||||
loopCount = count
|
||||
loopCurrent = 0
|
||||
loopPrompt = parts[1]
|
||||
loopActive = true
|
||||
|
||||
// Store current branch position
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 {
|
||||
loopStartNode = branch[len(branch)-1].ID
|
||||
}
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Starting loop: %d iterations", loopCount))
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: 0/%d - %s", loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Start first iteration
|
||||
executeLoopIteration(ctx)
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// OnAgentEnd handles loop continuation
|
||||
api.OnAgentEnd(func(e ext.AgentEndEvent, ctx ext.Context) {
|
||||
if !loopActive {
|
||||
return
|
||||
}
|
||||
|
||||
loopCurrent++
|
||||
|
||||
if loopCurrent >= loopCount {
|
||||
// Loop complete
|
||||
loopActive = false
|
||||
ctx.RemoveWidget("loop-progress")
|
||||
ctx.PrintInfo(fmt.Sprintf("✅ Loop complete: %d/%d iterations", loopCurrent, loopCount))
|
||||
|
||||
// Show final summary
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) > 0 && loopStartNode != "" {
|
||||
summary := ctx.SummarizeBranch(loopStartNode, branch[len(branch)-1].ID)
|
||||
if summary != "" {
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: summary,
|
||||
BorderColor: "#a6e3a1",
|
||||
Subtitle: "conversation-manager · Loop Summary",
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update progress
|
||||
ctx.SetWidget(ext.WidgetConfig{
|
||||
ID: "loop-progress",
|
||||
Placement: ext.WidgetAbove,
|
||||
Content: ext.WidgetContent{Text: fmt.Sprintf("🔄 Loop: %d/%d - %s", loopCurrent, loopCount, loopPrompt)},
|
||||
Style: ext.WidgetStyle{BorderColor: "#fab387"},
|
||||
})
|
||||
|
||||
// Collapse previous iteration for fresh context
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) >= 2 {
|
||||
// Find the user messages (look for the one before the last assistant message)
|
||||
// We want to collapse from the user message that started this iteration
|
||||
// to the last assistant response
|
||||
var collapseStartIdx = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if branch[i].Role == "assistant" {
|
||||
// Found the last assistant message, now find the user message before it
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
if branch[j].Role == "user" {
|
||||
collapseStartIdx = j
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if collapseStartIdx >= 0 {
|
||||
fromID := branch[collapseStartIdx].ID
|
||||
toID := branch[len(branch)-1].ID
|
||||
|
||||
ctx.PrintInfo(fmt.Sprintf("Collapsing iteration %d for fresh context...", loopCurrent))
|
||||
summary := ctx.SummarizeBranch(fromID, toID)
|
||||
if summary != "" {
|
||||
result := ctx.CollapseBranch(fromID, toID, summary)
|
||||
if result.Success {
|
||||
ctx.PrintInfo("Context collapsed successfully")
|
||||
} else {
|
||||
ctx.PrintError(fmt.Sprintf("Collapse failed: %s", result.Error))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay to let UI update
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Trigger next iteration
|
||||
executeLoopIteration(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// showTree displays the conversation tree structure
|
||||
func showTree(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("Tree is empty")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Conversation Tree (%d nodes):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
prefix := " "
|
||||
if i == len(branch)-1 {
|
||||
prefix = "▶ " // Current node
|
||||
} else {
|
||||
prefix = " "
|
||||
}
|
||||
|
||||
roleIcon := "💬"
|
||||
switch node.Role {
|
||||
case "user":
|
||||
roleIcon = "👤"
|
||||
case "assistant":
|
||||
roleIcon = "🤖"
|
||||
case "system":
|
||||
roleIcon = "⚙️"
|
||||
}
|
||||
|
||||
content := truncate(node.Content, 50)
|
||||
if node.Type == "branch_summary" {
|
||||
roleIcon = "📋"
|
||||
content = "[Summary] " + truncate(node.Content, 40)
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s%s %s: %s (%s...)\n", prefix, roleIcon, node.Role, node.ID[:8], content))
|
||||
|
||||
// Show children count if any
|
||||
children := ctx.GetChildren(node.ID)
|
||||
if len(children) > 0 {
|
||||
output.WriteString(fmt.Sprintf(" └─ %d branch(es)\n", len(children)))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#89b4fa",
|
||||
Subtitle: "conversation-manager · Tree View",
|
||||
})
|
||||
}
|
||||
|
||||
// showBranch displays the current branch path
|
||||
func showBranch(ctx ext.Context) {
|
||||
branch := ctx.GetCurrentBranch()
|
||||
if len(branch) == 0 {
|
||||
ctx.PrintInfo("No active branch")
|
||||
return
|
||||
}
|
||||
|
||||
var output strings.Builder
|
||||
output.WriteString(fmt.Sprintf("Current Branch (%d nodes from root to leaf):\n\n", len(branch)))
|
||||
|
||||
for i, node := range branch {
|
||||
marker := " "
|
||||
if i == len(branch)-1 {
|
||||
marker = "▶ " // Current leaf
|
||||
}
|
||||
|
||||
output.WriteString(fmt.Sprintf("%s[%d] %s (%s): %s\n",
|
||||
marker, i, node.Type, node.ID[:8], truncate(node.Content, 40)))
|
||||
}
|
||||
|
||||
// Show current node details
|
||||
leaf := branch[len(branch)-1]
|
||||
output.WriteString(fmt.Sprintf("\nCurrent Leaf:\n"))
|
||||
output.WriteString(fmt.Sprintf(" ID: %s\n", leaf.ID))
|
||||
output.WriteString(fmt.Sprintf(" Type: %s\n", leaf.Type))
|
||||
output.WriteString(fmt.Sprintf(" Role: %s\n", leaf.Role))
|
||||
output.WriteString(fmt.Sprintf(" Model: %s\n", leaf.Model))
|
||||
output.WriteString(fmt.Sprintf(" Children: %d\n", len(leaf.Children)))
|
||||
|
||||
ctx.PrintBlock(ext.PrintBlockOpts{
|
||||
Text: output.String(),
|
||||
BorderColor: "#cba6f7",
|
||||
Subtitle: "conversation-manager · Branch View",
|
||||
})
|
||||
}
|
||||
|
||||
// executeLoopIteration triggers the next loop iteration
|
||||
func executeLoopIteration(ctx ext.Context) {
|
||||
iterationPrompt := fmt.Sprintf("[%d/%d] %s", loopCurrent+1, loopCount, loopPrompt)
|
||||
ctx.SendMessage(iterationPrompt)
|
||||
}
|
||||
|
||||
// truncate helper
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-3] + "..."
|
||||
}
|
||||
@@ -7,10 +7,12 @@
|
||||
// development: edit your extension source, then type /reload to pick up
|
||||
// changes immediately.
|
||||
//
|
||||
// Event handlers, slash commands, tool renderers, message renderers, and
|
||||
// keyboard shortcuts update immediately. Extension-defined tools are NOT
|
||||
// updated (they are baked into the agent at creation time and require a
|
||||
// restart).
|
||||
// Note: Extensions in autoloaded directories (~/.config/kit/extensions/
|
||||
// and .kit/extensions/) are automatically reloaded on save. The /reload
|
||||
// command is useful for extensions loaded via -e from other locations.
|
||||
//
|
||||
// Event handlers, slash commands, tool definitions, tool renderers,
|
||||
// message renderers, and keyboard shortcuts all update immediately.
|
||||
//
|
||||
// Commands:
|
||||
// /reload — hot-reload all extensions from disk
|
||||
|
||||
@@ -10,13 +10,21 @@ import (
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// re matches !{...} with non-greedy content.
|
||||
var re = regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
|
||||
// Init expands inline bash expressions in user prompts before they reach the
|
||||
// LLM. Text like !{git branch --show-current} is replaced with the command's
|
||||
// stdout.
|
||||
// LLM. Text like !{git rev-parse --abbrev-ref HEAD} is replaced with the
|
||||
// command's stdout.
|
||||
//
|
||||
// In interactive mode the expansion happens at submit time via an editor
|
||||
// interceptor, so the expanded text is also visible in the user message
|
||||
// block on screen. In non-interactive mode (CLI, script, queue) the
|
||||
// expansion happens via OnInput transform.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "Fix the tests on !{git branch --show-current}"
|
||||
// "Fix the tests on !{git rev-parse --abbrev-ref HEAD}"
|
||||
// → "Fix the tests on main"
|
||||
//
|
||||
// "The current directory is !{pwd}"
|
||||
@@ -24,29 +32,59 @@ import (
|
||||
//
|
||||
// Usage: kit -e examples/extensions/inline-bash.go
|
||||
func Init(api ext.API) {
|
||||
// Matches !{...} with non-greedy content.
|
||||
re := regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
// ── Interactive mode: editor interceptor ──────────────────────────
|
||||
// Intercept Enter / Ctrl+D so we can expand !{...} BEFORE the
|
||||
// SubmitMsg is created. This ensures the expanded text appears in
|
||||
// the user message block on screen as well as in the LLM prompt.
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
if !ctx.Interactive {
|
||||
return
|
||||
}
|
||||
ctx.SetEditor(ext.EditorConfig{
|
||||
HandleKey: func(key string, currentText string) ext.EditorKeyAction {
|
||||
if (key == "enter" || key == "ctrl+d") && re.MatchString(currentText) {
|
||||
expanded := expand(currentText)
|
||||
// Clear the textarea asynchronously — calling
|
||||
// SetEditorText synchronously from inside Update()
|
||||
// would deadlock the BubbleTea event loop.
|
||||
go ctx.SetEditorText("")
|
||||
return ext.EditorKeyAction{
|
||||
Type: ext.EditorKeySubmit,
|
||||
SubmitText: expanded,
|
||||
}
|
||||
}
|
||||
return ext.EditorKeyAction{Type: ext.EditorKeyPassthrough}
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// ── Non-interactive fallback: OnInput transform ──────────────────
|
||||
// For CLI, script, and queue sources the editor interceptor is not
|
||||
// active, so we fall back to OnInput which still rewrites the
|
||||
// prompt text sent to the LLM.
|
||||
api.OnInput(func(ev ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if !re.MatchString(ev.Text) {
|
||||
if ev.Source == "interactive" || !re.MatchString(ev.Text) {
|
||||
return nil
|
||||
}
|
||||
|
||||
expanded := re.ReplaceAllStringFunc(ev.Text, func(match string) string {
|
||||
// Extract the command between !{ and }.
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
|
||||
return &ext.InputResult{
|
||||
Action: "transform",
|
||||
Text: expanded,
|
||||
Text: expand(ev.Text),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// expand replaces every !{cmd} in text with the command's stdout.
|
||||
// On error the original !{cmd} token is preserved.
|
||||
func expand(text string) string {
|
||||
return re.ReplaceAllStringFunc(text, func(match string) string {
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -168,6 +168,10 @@ var (
|
||||
// Test
|
||||
pendingTest *PendingTest
|
||||
|
||||
// Typing indicator
|
||||
typingTicker *time.Ticker
|
||||
typingStop chan struct{}
|
||||
|
||||
// Latest context for background goroutines
|
||||
latestCtx ext.Context
|
||||
latestCtxSet bool
|
||||
@@ -203,8 +207,23 @@ func configDir() string {
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func globalConfigDir() string {
|
||||
home, _ := os.UserHomeDir()
|
||||
return filepath.Join(home, ".config", "kit")
|
||||
}
|
||||
|
||||
func configPath() string {
|
||||
return filepath.Join(configDir(), "kit-telegram.json")
|
||||
// Prefer project-local config, fall back to global config.
|
||||
local := filepath.Join(configDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(local); err == nil {
|
||||
return local
|
||||
}
|
||||
global := filepath.Join(globalConfigDir(), "kit-telegram.json")
|
||||
if _, err := os.Stat(global); err == nil {
|
||||
return global
|
||||
}
|
||||
// Neither exists — return local path (will be created on connect).
|
||||
return local
|
||||
}
|
||||
|
||||
func failureLogDir() string {
|
||||
@@ -387,6 +406,14 @@ func tgEditMessageText(token string, chatID int64, messageID int, text string) (
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func tgSendChatAction(token string, chatID int64, action string) error {
|
||||
_, err := telegramRequest(token, "sendChatAction", map[string]any{
|
||||
"chat_id": chatID,
|
||||
"action": action,
|
||||
}, 15)
|
||||
return err
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Error classification
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -637,6 +664,48 @@ func clearHealthTimer() {
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Typing indicator
|
||||
// ──────────────────────────────────────────────
|
||||
|
||||
func startTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
return
|
||||
}
|
||||
cfg := config
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return
|
||||
}
|
||||
token := cfg.BotToken
|
||||
chatID := cfg.ChatID
|
||||
typingTicker = time.NewTicker(4 * time.Second)
|
||||
typingStop = make(chan struct{})
|
||||
// Send immediately, then every 4 seconds.
|
||||
go func() {
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
for {
|
||||
select {
|
||||
case <-typingTicker.C:
|
||||
tgSendChatAction(token, chatID, "typing")
|
||||
case <-typingStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func stopTypingLoop() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if typingTicker != nil {
|
||||
typingTicker.Stop()
|
||||
close(typingStop)
|
||||
typingTicker = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────
|
||||
// Polling lifecycle
|
||||
// ──────────────────────────────────────────────
|
||||
@@ -908,7 +977,7 @@ func summarizeToolAction(toolName string, inputJSON string) string {
|
||||
return "searching " + getStr("pattern", "text")
|
||||
case "ls":
|
||||
return "listing " + getStr("path", "directory")
|
||||
case "spawn_subagent":
|
||||
case "subagent":
|
||||
return "spawning subagent"
|
||||
default:
|
||||
return "using " + toolName
|
||||
@@ -2105,6 +2174,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
sendShutdownDisconnectedMessage()
|
||||
stopTypingLoop()
|
||||
stopPolling()
|
||||
clearHealthTimer()
|
||||
clearFooter()
|
||||
@@ -2128,6 +2198,7 @@ func Init(api ext.API) {
|
||||
mu.Unlock()
|
||||
|
||||
report("run.start", fmt.Sprintf("runId=%d", run.ID))
|
||||
startTypingLoop()
|
||||
ensureProgressMessage()
|
||||
updateProgressMessage()
|
||||
})
|
||||
@@ -2140,6 +2211,8 @@ func Init(api ext.API) {
|
||||
run := activeRun
|
||||
mu.Unlock()
|
||||
|
||||
stopTypingLoop()
|
||||
|
||||
if run != nil {
|
||||
// Capture final response from event
|
||||
if e.Response != "" {
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
|
||||
// lsp-diagnostics.go — LSP-powered diagnostics for Kit's edit tool.
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits,
|
||||
// following the same pattern used by Charm's crush editor:
|
||||
//
|
||||
// Starts language servers on demand and surfaces diagnostics after file edits:
|
||||
// 1. After an edit, notify the LSP server of the file change
|
||||
// 2. Wait for the server to publish fresh diagnostics
|
||||
// 3. Append diagnostic output to the edit tool's result
|
||||
@@ -412,7 +410,7 @@ func (c *lspClient) changeFile(absPath, content string) {
|
||||
}
|
||||
|
||||
// waitForDiagnostics polls until the server publishes new diagnostics or
|
||||
// the timeout elapses. Mirrors crush's WaitForDiagnostics pattern.
|
||||
// the timeout elapses.
|
||||
func (c *lspClient) waitForDiagnostics(timeout time.Duration) {
|
||||
c.diagMu.Lock()
|
||||
startVersion := c.diagVersion
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
//go:build ignore
|
||||
|
||||
// prompt-templates.go - Frontmatter-driven prompt templates with model switching.
|
||||
// This extension demonstrates the new bridged SDK APIs:
|
||||
// - Tree navigation for conversation management
|
||||
// - Template parsing with {{variable}} substitution
|
||||
// - Model resolution with fallback chains
|
||||
// - Skill injection
|
||||
//
|
||||
// Usage:
|
||||
// 1. Create ~/.config/kit/prompts/debug.md with frontmatter:
|
||||
// ---
|
||||
// description: Debug Python code
|
||||
// model: claude-sonnet-4-20250514
|
||||
// skill: python
|
||||
// ---
|
||||
// Help me debug this Python code: {{input}}
|
||||
//
|
||||
// 2. In Kit: /debug my_script.py
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// PromptTemplate represents a loaded template with frontmatter
|
||||
type PromptTemplate struct {
|
||||
Name string
|
||||
Description string
|
||||
Model string
|
||||
Skill string
|
||||
Content string
|
||||
Variables []string
|
||||
Path string
|
||||
}
|
||||
|
||||
var (
|
||||
templates = make(map[string]PromptTemplate)
|
||||
templateDir string
|
||||
)
|
||||
|
||||
func Init(api ext.API) {
|
||||
// Determine template directory
|
||||
home, _ := os.UserHomeDir()
|
||||
templateDir = filepath.Join(home, ".config", "kit", "prompts")
|
||||
|
||||
// Ensure directory exists
|
||||
os.MkdirAll(templateDir, 0755)
|
||||
|
||||
// Register commands
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: "reload-templates",
|
||||
Description: "Reload prompt templates from disk",
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
loadTemplates(ctx)
|
||||
ctx.PrintInfo(fmt.Sprintf("Loaded %d templates from %s", len(templates), templateDir))
|
||||
return "", nil
|
||||
},
|
||||
})
|
||||
|
||||
// Dynamic template commands are registered after loading
|
||||
api.OnSessionStart(func(e ext.SessionStartEvent, ctx ext.Context) {
|
||||
loadTemplates(ctx)
|
||||
registerTemplateCommands(api, ctx)
|
||||
})
|
||||
}
|
||||
|
||||
// loadTemplates discovers and loads all template files
|
||||
func loadTemplates(ctx ext.Context) {
|
||||
templates = make(map[string]PromptTemplate)
|
||||
|
||||
entries, err := os.ReadDir(templateDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(templateDir, entry.Name())
|
||||
tpl, err := loadTemplateFile(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
name := strings.TrimSuffix(entry.Name(), ".md")
|
||||
templates[name] = tpl
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplateFile parses a template with YAML frontmatter
|
||||
func loadTemplateFile(path string) (PromptTemplate, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return PromptTemplate{}, err
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
tpl := PromptTemplate{Path: path}
|
||||
|
||||
// Parse frontmatter
|
||||
if strings.HasPrefix(content, "---") {
|
||||
parts := strings.SplitN(content[3:], "---", 2)
|
||||
if len(parts) == 2 {
|
||||
frontmatter := strings.TrimSpace(parts[0])
|
||||
body := strings.TrimSpace(parts[1])
|
||||
|
||||
// Simple line-by-line frontmatter parsing
|
||||
for _, line := range strings.Split(frontmatter, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
key, value, found := strings.Cut(line, ":")
|
||||
if found {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
switch key {
|
||||
case "description":
|
||||
tpl.Description = value
|
||||
case "model":
|
||||
tpl.Model = value
|
||||
case "skill":
|
||||
tpl.Skill = value
|
||||
}
|
||||
}
|
||||
}
|
||||
tpl.Content = body
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
} else {
|
||||
tpl.Content = content
|
||||
}
|
||||
|
||||
// Parse {{variables}} using simple string parsing
|
||||
// (Can't use ctx.ParseTemplate here since we're in Init, not a handler)
|
||||
var vars []string
|
||||
for {
|
||||
start := strings.Index(tpl.Content, "{{")
|
||||
if start == -1 {
|
||||
break
|
||||
}
|
||||
end := strings.Index(tpl.Content[start:], "}}")
|
||||
if end == -1 {
|
||||
break
|
||||
}
|
||||
varName := strings.TrimSpace(tpl.Content[start+2 : start+end])
|
||||
vars = append(vars, varName)
|
||||
tpl.Content = tpl.Content[:start] + "{{" + varName + "}}" + tpl.Content[start+end+2:]
|
||||
}
|
||||
tpl.Variables = vars
|
||||
|
||||
return tpl, nil
|
||||
}
|
||||
|
||||
// registerTemplateCommands dynamically registers commands for each template
|
||||
func registerTemplateCommands(api ext.API, ctx ext.Context) {
|
||||
for name, tpl := range templates {
|
||||
// Skip if already registered (we'd need to track this)
|
||||
tplCopy := tpl // Capture for closure
|
||||
nameCopy := name
|
||||
|
||||
// Build description with metadata
|
||||
desc := tplCopy.Description
|
||||
if desc == "" {
|
||||
desc = fmt.Sprintf("Run %s template", nameCopy)
|
||||
}
|
||||
if tplCopy.Model != "" {
|
||||
desc += fmt.Sprintf(" [%s", tplCopy.Model)
|
||||
if tplCopy.Skill != "" {
|
||||
desc += fmt.Sprintf(" +%s", tplCopy.Skill)
|
||||
}
|
||||
desc += "]"
|
||||
}
|
||||
|
||||
api.RegisterCommand(ext.CommandDef{
|
||||
Name: nameCopy,
|
||||
Description: desc,
|
||||
Execute: func(args string, ctx ext.Context) (string, error) {
|
||||
return executeTemplate(ctx, tplCopy, args)
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplate runs a template with the given arguments
|
||||
func executeTemplate(ctx ext.Context, tpl PromptTemplate, args string) (string, error) {
|
||||
// Store original model for restoration
|
||||
originalModel := ctx.Model
|
||||
|
||||
// 1. Resolve and switch model if specified
|
||||
if tpl.Model != "" {
|
||||
// Parse model chain (comma-separated)
|
||||
preferences := strings.Split(tpl.Model, ",")
|
||||
for i := range preferences {
|
||||
preferences[i] = strings.TrimSpace(preferences[i])
|
||||
}
|
||||
|
||||
result := ctx.ResolveModelChain(preferences)
|
||||
if result.Error != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Model resolution failed: %s", result.Error))
|
||||
// Continue with current model
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Switching to model: %s", result.Model))
|
||||
if err := ctx.SetModel(result.Model); err != nil {
|
||||
ctx.PrintError(fmt.Sprintf("Failed to switch model: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Inject skill if specified
|
||||
if tpl.Skill != "" {
|
||||
err := ctx.InjectSkillAsContext(tpl.Skill)
|
||||
if err != "" {
|
||||
ctx.PrintError(fmt.Sprintf("Skill injection failed: %s", err))
|
||||
} else {
|
||||
ctx.PrintInfo(fmt.Sprintf("Injected skill: %s", tpl.Skill))
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Parse and render template
|
||||
parsed := ctx.ParseTemplate(tpl.Name, tpl.Content)
|
||||
|
||||
// Build variable map
|
||||
vars := make(map[string]string)
|
||||
|
||||
// Simple argument parsing: first arg is $1 (input), rest is $@
|
||||
if len(parsed.Variables) > 0 {
|
||||
argsList := ctx.SimpleParseArguments(args, len(parsed.Variables))
|
||||
for i, varName := range parsed.Variables {
|
||||
if i < len(parsed.Variables) && i+1 < len(argsList) {
|
||||
vars[varName] = argsList[i+1]
|
||||
}
|
||||
}
|
||||
// If single variable, use full args
|
||||
if len(parsed.Variables) == 1 && vars[parsed.Variables[0]] == "" {
|
||||
vars[parsed.Variables[0]] = args
|
||||
}
|
||||
}
|
||||
|
||||
// Render with model conditionals
|
||||
content := ctx.RenderWithModelConditionals(tpl.Content)
|
||||
rendered := ctx.RenderTemplate(ext.PromptTemplate{Name: tpl.Name, Content: content, Variables: parsed.Variables}, vars)
|
||||
|
||||
// 4. Send the rendered prompt
|
||||
ctx.SendMessage(rendered)
|
||||
|
||||
// 5. Schedule model restoration after turn completes
|
||||
// We use a goroutine to wait and restore
|
||||
if tpl.Model != "" && originalModel != "" {
|
||||
go func() {
|
||||
// Note: In a real implementation, we'd use OnAgentEnd event
|
||||
// For now, the user can manually switch back
|
||||
ctx.SetStatus("template-mode", fmt.Sprintf("Template: %s (model will restore)", tpl.Name), 20)
|
||||
}()
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Executing template: %s", tpl.Name), nil
|
||||
}
|
||||
@@ -130,6 +130,58 @@ func TestSubagentMonitor_MultipleSubagents(t *testing.T) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_ConcurrentSubagents verifies no panics when multiple
|
||||
// subagents emit events concurrently from different goroutines.
|
||||
func TestSubagentMonitor_ConcurrentSubagents(t *testing.T) {
|
||||
harness := test.New(t)
|
||||
harness.LoadFile("../../.kit/extensions/subagent-monitor.go")
|
||||
|
||||
_, err := harness.Emit(extensions.SessionStartEvent{SessionID: "test-session"})
|
||||
if err != nil {
|
||||
t.Fatalf("SessionStart should not error: %v", err)
|
||||
}
|
||||
|
||||
// Start 5 subagents concurrently
|
||||
done := make(chan struct{}, 5)
|
||||
for i := range 5 {
|
||||
go func(idx int) {
|
||||
defer func() { done <- struct{}{} }()
|
||||
|
||||
callID := fmt.Sprintf("concurrent-%d", idx)
|
||||
task := fmt.Sprintf("concurrent task %d", idx)
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentStartEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
})
|
||||
|
||||
// Emit many chunks rapidly
|
||||
for j := range 20 {
|
||||
_, _ = harness.Emit(extensions.SubagentChunkEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
ChunkType: "text",
|
||||
Content: fmt.Sprintf("agent %d chunk %d", idx, j),
|
||||
})
|
||||
}
|
||||
|
||||
_, _ = harness.Emit(extensions.SubagentEndEvent{
|
||||
ToolCallID: callID,
|
||||
Task: task,
|
||||
Response: "done",
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for range 5 {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Allow any final processing
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// TestSubagentMonitor_SessionShutdown verifies shutdown doesn't panic
|
||||
// even with nil ctx functions.
|
||||
func TestSubagentMonitor_SessionShutdown(t *testing.T) {
|
||||
|
||||
@@ -37,7 +37,7 @@ func Init(api ext.API) {
|
||||
"Subagent Test Extension loaded\n\n" +
|
||||
"/subtest <task> Spawn blocking subagent\n" +
|
||||
"/subbg <task> Spawn background subagent\n\n" +
|
||||
"The LLM can also use the spawn_subagent tool.")
|
||||
"The LLM can also use the subagent tool.")
|
||||
})
|
||||
|
||||
api.OnAgentEnd(func(_ ext.AgentEndEvent, ctx ext.Context) {
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
//go:build ignore
|
||||
|
||||
// sudo-handler.go - Extension to handle sudo password prompts securely
|
||||
//
|
||||
// This extension intercepts bash commands containing "sudo" and:
|
||||
// 1. Checks if sudo credentials are already cached (via sudo -n)
|
||||
// 2. If not cached, prompts the user for their password (with masking)
|
||||
// 3. Temporarily sets SUDO_PASSWORD environment variable for execution
|
||||
// 4. The bash tool automatically uses sudo -S -p '' to pipe the password
|
||||
//
|
||||
// Usage: kit -e examples/extensions/sudo-handler.go
|
||||
//
|
||||
// Security notes:
|
||||
// - Password is only stored in memory for the duration of the session
|
||||
// - Password is never logged or displayed
|
||||
// - Each session requires re-authentication (sudo -k is used)
|
||||
// - The SUDO_PASSWORD env var is set only during tool execution
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
var (
|
||||
// cachedPassword stores the sudo password for the session
|
||||
cachedPassword string
|
||||
// hasCachedPassword tracks if we have a valid cached password
|
||||
hasCachedPassword bool
|
||||
// mu protects cached password access
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// Init sets up the sudo handler extension
|
||||
func Init(api ext.API) {
|
||||
api.OnToolCall(func(tc ext.ToolCallEvent, ctx ext.Context) *ext.ToolCallResult {
|
||||
if tc.ToolName != "bash" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse the command from tool input
|
||||
var input struct {
|
||||
Command string `json:"command"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.Input), &input); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if command contains sudo
|
||||
if !containsSudo(input.Command) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if we already have cached credentials
|
||||
mu.RLock()
|
||||
password := cachedPassword
|
||||
hasCached := hasCachedPassword
|
||||
mu.RUnlock()
|
||||
|
||||
if hasCached {
|
||||
// Use cached password
|
||||
os.Setenv("SUDO_PASSWORD", password)
|
||||
return nil
|
||||
}
|
||||
|
||||
// No cached password - prompt user
|
||||
result := ctx.PromptInput(ext.PromptInputConfig{
|
||||
Message: "🔐 Sudo password required for:\n " + truncateCommand(input.Command, 60),
|
||||
Placeholder: "Enter your password",
|
||||
})
|
||||
|
||||
if result.Cancelled {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "Sudo password prompt cancelled by user",
|
||||
}
|
||||
}
|
||||
|
||||
if result.Value == "" {
|
||||
return &ext.ToolCallResult{
|
||||
Block: true,
|
||||
Reason: "No password provided",
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the password for this session
|
||||
mu.Lock()
|
||||
cachedPassword = result.Value
|
||||
hasCachedPassword = true
|
||||
mu.Unlock()
|
||||
|
||||
// Set environment variable for the bash tool to use
|
||||
os.Setenv("SUDO_PASSWORD", result.Value)
|
||||
|
||||
// Show confirmation (without revealing password)
|
||||
ctx.PrintInfo("Sudo password cached for this session")
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Clear cached password when session ends
|
||||
api.OnSessionShutdown(func(event ext.SessionShutdownEvent, ctx ext.Context) {
|
||||
mu.Lock()
|
||||
cachedPassword = ""
|
||||
hasCachedPassword = false
|
||||
mu.Unlock()
|
||||
os.Unsetenv("SUDO_PASSWORD")
|
||||
})
|
||||
}
|
||||
|
||||
// containsSudo checks if the command contains sudo as a command (not in a string)
|
||||
func containsSudo(command string) bool {
|
||||
// Simple check for sudo as a word, not inside quotes or as part of another word
|
||||
lower := strings.ToLower(command)
|
||||
|
||||
// Check for sudo at start or after separators
|
||||
patterns := []string{
|
||||
"sudo ",
|
||||
"sudo\t",
|
||||
";sudo ",
|
||||
"&& sudo ",
|
||||
"|| sudo ",
|
||||
"| sudo ",
|
||||
"$(sudo ",
|
||||
"`sudo ",
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
if strings.Contains(lower, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if command starts with sudo
|
||||
if strings.HasPrefix(lower, "sudo ") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// truncateCommand truncates a long command for display
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
@@ -1,98 +1,95 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.1
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.0.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.5
|
||||
charm.land/fantasy v0.17.2
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.48.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/term v0.42.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go v0.123.0 // indirect
|
||||
cloud.google.com/go/auth v0.19.0 // indirect
|
||||
cloud.google.com/go/auth v0.20.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 // indirect
|
||||
github.com/aws/smithy-go v1.24.2 // indirect
|
||||
github.com/aymerick/douceur v0.2.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 // indirect
|
||||
github.com/aws/smithy-go v1.24.3 // indirect
|
||||
github.com/catppuccin/go v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/charmbracelet/anthropic-sdk-go v0.0.0-20260223140439-63879b0b8dab // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.3 // indirect
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 // indirect
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
github.com/charmbracelet/x/windows v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.11.0 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0 // indirect
|
||||
github.com/dlclark/regexp2 v1.11.5 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/go-json-experiment/json v0.0.0-20260214004413-d219187c3433 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 // indirect
|
||||
github.com/gorilla/css v1.0.1 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/indaco/herald v0.9.0 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.2.12 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.0 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.6 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 // indirect
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
github.com/muesli/mango-pflag v0.2.0 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/roff v0.1.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
|
||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||
@@ -106,41 +103,39 @@ require (
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/yuin/goldmark v1.8.2 // indirect
|
||||
github.com/yuin/goldmark-emoji v1.0.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect
|
||||
go.opentelemetry.io/otel v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.42.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.42.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
|
||||
go.opentelemetry.io/otel v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.273.0 // indirect
|
||||
google.golang.org/genai v1.51.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect
|
||||
google.golang.org/grpc v1.79.3 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/glamour v1.0.0
|
||||
github.com/charmbracelet/x/ansi v0.11.6
|
||||
github.com/charmbracelet/x/ansi v0.11.7
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.21 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.21 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.23 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.42.0 // indirect
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.36.0
|
||||
)
|
||||
|
||||
@@ -1,29 +1,29 @@
|
||||
charm.land/bubbles/v2 v2.0.0 h1:tE3eK/pHjmtrDiRdoC9uGNLgpopOd8fjhEe31B/ai5s=
|
||||
charm.land/bubbles/v2 v2.0.0/go.mod h1:rCHoleP2XhU8um45NTuOWBPNVHxnkXKTiZqcclL/qOI=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.5 h1:TQlLFqxo39AAHSVuOhJ5D3nH7O9Nk8JGinsfWQ4y1U4=
|
||||
charm.land/bubbletea/v2 v2.0.5/go.mod h1:dvbsYZD+MHkdIZl+Z67D212hEvB+GII2tfH8f9SnoDw=
|
||||
charm.land/fantasy v0.17.2 h1:ojTMufMxY/PVH7TzYUxht2SVkvD90iCTJfmPR6c8BR8=
|
||||
charm.land/fantasy v0.17.2/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ=
|
||||
cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
|
||||
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
|
||||
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 h1:fou+2+WFTib47nS+nz/ozhEBnvU96bKHy6LjRsY4E28=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0/go.mod h1:t76Ruy8AHvUAC8GfMWJMa0ElSbuIcO03NLpynfbgsPA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0/go.mod h1:7dCRMLwisfRH3dBupKeNCioWYUZ4SS09Z14H+7i8ZoY=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||
github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0=
|
||||
@@ -34,42 +34,40 @@ github.com/alecthomas/repr v0.5.2 h1:SU73FTI9D1P5UNtvseffFSGmdNci/O6RsqzeXJtP0Qs
|
||||
github.com/alecthomas/repr v0.5.2/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4 h1:10f50G7WyU02T56ox1wWXq+zTX9I1zxG46HYuG1hH/k=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.4/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV/yY=
|
||||
github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o=
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12 h1:O3csC7HUGn2895eNrLytOJQdoL2xyJy0iYXhoZ1OmP0=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.12/go.mod h1:96zTvoOFR4FURjI+/5wY1vc1ABceROO4lWgWJuxgy0g=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12 h1:oqtA6v+y5fZg//tcTWahyN9PEn5eDU/Wpvc2+kJ4aY8=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.12/go.mod h1:U3R1RtSHx6NB0DvEQFGyf/0sbrpJrluENHdPy1j/3TE=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20 h1:zOgq3uezl5nznfoK3ODuqbhVg1JzAGDUhXOsU0IDCAo=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.20/go.mod h1:z/MVwUARehy6GAg/yQ1GO2IMl0k++cu1ohP9zo887wE=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20 h1:CNXO7mvgThFGqOFgbNAP2nol2qAWBOGfqR/7tQlvLmc=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.20/go.mod h1:oydPDJKcfMhgfcgBUZaG+toBbwy8yPWubJXBVERtI4o=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20 h1:tN6W/hg+pkM+tf9XDkWUbDEjGLb+raoBMFsTodcoYKw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.20/go.mod h1:YJ898MhD067hSHA6xYCx5ts/jEd8BSOLtQDL3iZsvbc=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI=
|
||||
github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI=
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g=
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4=
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21/go.mod h1:A/kJFst/nm//cyqonihbdpQZwiUhhzpqTsdbhDdRF9c=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21 h1:PEgGVtPoB6NTpPrBgqSE5hE/o47Ij9qk/SEZFbUOe9A=
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.21/go.mod h1:p+hz+PRAYlY3zcpJhPwXlLC4C+kqn70WIHwnzAfs6ps=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6 h1:qYQ4pzQ2Oz6WpQ8T3HvGHnZydA72MnLuFK9tJwmrbHw=
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.6/go.mod h1:O3h0IK87yXci+kg6flUKzJnWeziQUKciKrLjcatSNcY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7 h1:5EniKhLZe4xzL7a+fU3C2tfUN4nWIqlLesfrjkuPFTY=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.7/go.mod h1:x0nZssQ3qZSnIcePWLvcoFisRXJzcTVvYpAAdYX8+GI=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20 h1:2HvVAIq+YqgGotK6EkMf+KIEqTISmTYh5zLpYyeTo1Y=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.20/go.mod h1:V4X406Y666khGa8ghKmphma/7C0DAtEQYhkq9z4vpbk=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8 h1:0GFOLzEbOyZABS3PhYfBIx2rNBACYcKty+XGkTgw1ow=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.8/go.mod h1:LXypKvk85AROkKhOG6/YEcHFPoX+prKTowKnVdcaIxE=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13 h1:kiIDLZ005EcKomYYITtfsjn7dtOwHDOFy7IbPXKek2o=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.13/go.mod h1:2h/xGEowcW/g38g06g3KpRWDlT+OTfxxI0o1KqayAB8=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17 h1:jzKAXIlhZhJbnYwHbvUQZEB8KfgAEuG0dc08Bkda7NU=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.17/go.mod h1:Al9fFsXjv4KfbzQHGe6V4NZSZQXecFcvaIF4e70FoRA=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9 h1:Cng+OOwCHmFljXIxpEVXAGMnBia8MSU6Ch5i9PgBkcU=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.9/go.mod h1:LrlIndBDdjA/EeXeyNBle+gyCwTlizzW5ycgWnvIxkk=
|
||||
github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng=
|
||||
github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 h1:c31//R3xgIJMSC8S6hEVq+38DcvUlgFY0FM6mSI5oto=
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21/go.mod h1:r6+pf23ouCB718FUxaqzZdbpYFyDtehyZcmP5KL9FkA=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg=
|
||||
github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM=
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys=
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U=
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw=
|
||||
github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg=
|
||||
github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o=
|
||||
github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w=
|
||||
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
|
||||
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
|
||||
github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
@@ -80,8 +78,6 @@ github.com/charmbracelet/colorprofile v0.4.3 h1:QPa1IWkYI+AOB+fE+mg/5/4HRMZcaXex
|
||||
github.com/charmbracelet/colorprofile v0.4.3/go.mod h1:/zT4BhpD5aGFpqQQqw7a+VtHCzu+zrQtt1zhMt9mR4Q=
|
||||
github.com/charmbracelet/fang v1.0.0 h1:jESBY40agJOlLYnnv9jE0mLqDGTxEk0hkOnx7YGyRlQ=
|
||||
github.com/charmbracelet/fang v1.0.0/go.mod h1:P5/DNb9DddQ0Z0dbc0P3ol4/ix5Po7Ofr2KMBfAqoCo=
|
||||
github.com/charmbracelet/glamour v1.0.0 h1:AWMLOVFHTsysl4WV8T8QgkQ0s/ZNZo7CiE4WKhk8l08=
|
||||
github.com/charmbracelet/glamour v1.0.0/go.mod h1:DSdohgOBkMr2ZQNhw4LZxSGpx3SvpeujNoXrQyH2hxo=
|
||||
github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
|
||||
github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 h1:ZR7e0ro+SZZiIZD7msJyA+NjkCNNavuiPBLgerbOziE=
|
||||
@@ -90,24 +86,26 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502 h1:hzWNs3UQRSUTS6YCbLaQnwqKBFXT5Yh1OOw6+26apqg=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260316091819-b93f6a3b8502/go.mod h1:mkUCcxn9w9j89JJp3pOza5tmDQZPgIB75UfmQlFYvas=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e h1:O5hZFj55wZQWxMiRtQLa3uLKhZGZGS/j8M3OXinQlrw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4jxsAuEs=
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIRg8gGWwk=
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca h1:62yAoS1Ynbuzwcn1LkNBxi3IMF5p0E0cHCoaLOOmN9w=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 h1:6F/6bu5nBLjodsvaU5xAszTaxtHrDU5UiJarpMPZj48=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca h1:QQoyQLgUzojMNWHVHToN6d9qTvT0KWtxUKIRPx/Ox5o=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260323091123-df7b1bcffcca/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 h1:RiZFY92Ug9iz1CenzxSSQla2Z3WflsR7bIuXq40JlpU=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -177,43 +175,40 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8=
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs=
|
||||
github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8=
|
||||
github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
|
||||
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/indaco/herald v0.9.0 h1:LrAfXEHkKz8WmctUKdndppIU/qFpylSbZ8galS0DVAc=
|
||||
github.com/indaco/herald v0.9.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/kaptinlin/go-i18n v0.2.12 h1:ywDsvb4KDFddMC2dpI/rrIzGU2mWUSvHmWUm9BMsdl4=
|
||||
github.com/kaptinlin/go-i18n v0.2.12/go.mod h1:pVcu9qsW5pOIOoZFJXesRYmLos1vMQrby70JPAoWmJU=
|
||||
github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.4.0 h1:i7L3U2yurg+xhokITtJ0k+mjHnXqkoyz8ju5Wb7W8Oc=
|
||||
github.com/kaptinlin/go-i18n v0.4.0/go.mod h1:njA6x0+4MWGcLWT0KLrwekhRPmze1Hnstf2+VJFzwpM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.6 h1:UUMqZGFAk7nOzQsYAxvgygm4wpDp/nwXxA4VP9mCPCs=
|
||||
github.com/kaptinlin/jsonschema v0.7.6/go.mod h1:GGk/oE+F1lWUfYrzKaCf4QWZmMdytt0LL4XdFEFB0LE=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18 h1:RBlHVWgZyoxTcUgGWBsl2AcyScq/urqbLZvzgryTmSI=
|
||||
github.com/kaptinlin/messageformat-go v0.4.18/go.mod h1:ntI3154RnqJgr7GaC+vZBnIExl2V3sv9selvRNNEM24=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 h1:a0ufTd5liiUubIGeGxpSTnNS8ZSrN4DV01/wGFmfzMs=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20/go.mod h1:FqdEPfQLkqVBX7OBRMPgYwUPvKYJohFD9Ok1BMzCfIo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w=
|
||||
github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk=
|
||||
github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.48.0 h1:o+MXuGW/HCeR2ny5LcAcZQn2bo6I2xaZMEHnpRG+dtw=
|
||||
github.com/mark3labs/mcp-go v0.48.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 h1:vGKWl0YJqUNxE8d+h8f6NJLcCJrgbhC4NcD46KavDd4=
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2/go.mod h1:MG3aRVU/N29oo/V/IhBX8GR/zz4kQkprJgF2EVszyDE=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -224,8 +219,6 @@ github.com/muesli/mango-cobra v1.3.0 h1:vQy5GvPg3ndOSpduxutqFoINhWk3vD5K2dXo5E8p
|
||||
github.com/muesli/mango-cobra v1.3.0/go.mod h1:Cj1ZrBu3806Qw7UjxnAUgE+7tllUBj1NCLQDwwGx19E=
|
||||
github.com/muesli/mango-pflag v0.2.0 h1:QViokgKDZQCzKhYe1zH8D+UlPJzBSGoP9yx0hBG0t5k=
|
||||
github.com/muesli/mango-pflag v0.2.0/go.mod h1:X9LT1p/pbGA1wjvEbtwnixujKErkP0jVmrxwrw3fL0Y=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8=
|
||||
github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig=
|
||||
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
|
||||
@@ -238,8 +231,6 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgm
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
@@ -281,59 +272,56 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
|
||||
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
|
||||
github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs=
|
||||
github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg=
|
||||
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
|
||||
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
|
||||
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
|
||||
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo=
|
||||
go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc=
|
||||
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
|
||||
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0 h1:0Qx7VGBacMm9ZENQ7TnNObTYI4ShC+lHI16seduaxZo=
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.68.0/go.mod h1:Sje3i3MjSPKTSPvVWCaL8ugBzJwik3u4smCjUeuupqg=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 h1:CqXxU8VOmDefoh0+ztfGaymYbhdB/tT3zs79QaZTNGY=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0/go.mod h1:BuhAPThV8PBHBvg8ZzZ/Ok3idOdhWIodywz2xEcRbJo=
|
||||
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
|
||||
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ=
|
||||
google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew=
|
||||
google.golang.org/genai v1.51.0 h1:IZGuUqgfx40INv3hLFGCbOSGp0qFqm7LVmDghzNIYqg=
|
||||
google.golang.org/genai v1.51.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
|
||||
google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 h1:RmoJA1ujG+/lRGNfUnOMfhCy5EipVMyvUE+KNbPbTlw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+226
-16
@@ -7,8 +7,11 @@ package acpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
@@ -20,7 +23,6 @@ import (
|
||||
// Version is injected at build time; fallback to "dev".
|
||||
var Version = "dev"
|
||||
|
||||
// Agent implements the acp.Agent interface, delegating to Kit for LLM
|
||||
// execution, tool calls, and session management.
|
||||
type Agent struct {
|
||||
conn *acp.AgentSideConnection
|
||||
@@ -111,13 +113,20 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
)
|
||||
}
|
||||
|
||||
// Extract text from prompt content blocks.
|
||||
promptText := extractPromptText(params.Prompt)
|
||||
if promptText == "" {
|
||||
// Extract text and file attachments from prompt content blocks.
|
||||
promptText, files := extractPromptContent(params.Prompt)
|
||||
if promptText == "" && len(files) == 0 {
|
||||
return acp.PromptResponse{}, acp.NewInvalidParams("empty prompt")
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText))
|
||||
// If we have files but no text prompt, add a default prompt
|
||||
// This is required because the underlying LLM library needs a non-empty prompt
|
||||
// when there are no previous messages in the conversation.
|
||||
if promptText == "" && len(files) > 0 {
|
||||
promptText = "Please analyze the attached file."
|
||||
}
|
||||
|
||||
log.Debug("acp: prompt", "session", sessionID, "prompt_len", len(promptText), "files", len(files))
|
||||
|
||||
// Create a cancellable context for this prompt turn.
|
||||
promptCtx, cancel := context.WithCancel(ctx)
|
||||
@@ -129,7 +138,13 @@ func (a *Agent) Prompt(ctx context.Context, params acp.PromptRequest) (acp.Promp
|
||||
defer unsub()
|
||||
|
||||
// Run the prompt through Kit's full turn lifecycle.
|
||||
_, err := sess.kit.PromptResult(promptCtx, promptText)
|
||||
// Use PromptResultWithFiles when file attachments are present.
|
||||
var err error
|
||||
if len(files) > 0 {
|
||||
_, err = sess.kit.PromptResultWithFiles(promptCtx, promptText, files)
|
||||
} else {
|
||||
_, err = sess.kit.PromptResult(promptCtx, promptText)
|
||||
}
|
||||
if err != nil {
|
||||
if promptCtx.Err() != nil {
|
||||
return acp.PromptResponse{
|
||||
@@ -162,6 +177,24 @@ func (a *Agent) SetSessionMode(_ context.Context, _ acp.SetSessionModeRequest) (
|
||||
return acp.SetSessionModeResponse{}, nil
|
||||
}
|
||||
|
||||
// SetSessionModel changes the active model for a session.
|
||||
func (a *Agent) SetSessionModel(ctx context.Context, params acp.SetSessionModelRequest) (acp.SetSessionModelResponse, error) {
|
||||
sessionID := string(params.SessionId)
|
||||
sess, ok := a.registry.get(sessionID)
|
||||
if !ok {
|
||||
return acp.SetSessionModelResponse{}, acp.NewInvalidParams(fmt.Sprintf("session not found: %s", sessionID))
|
||||
}
|
||||
|
||||
modelID := string(params.ModelId)
|
||||
log.Debug("acp: set_session_model", "session", sessionID, "model", modelID)
|
||||
|
||||
if err := sess.kit.SetModel(ctx, modelID); err != nil {
|
||||
return acp.SetSessionModelResponse{}, fmt.Errorf("set model: %w", err)
|
||||
}
|
||||
|
||||
return acp.SetSessionModelResponse{}, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Event streaming: Kit events → ACP SessionUpdate notifications
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -231,19 +264,196 @@ func (a *Agent) subscribeEvents(ctx context.Context, k *kit.Kit, sessionID acp.S
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// extractPromptText extracts the concatenated text content from ACP content
|
||||
// blocks. Non-text blocks are ignored for now.
|
||||
func extractPromptText(blocks []acp.ContentBlock) string {
|
||||
var text string
|
||||
for _, block := range blocks {
|
||||
if block.Text != nil {
|
||||
if text != "" {
|
||||
text += "\n"
|
||||
// extractPromptContent extracts text and file attachments from ACP content blocks.
|
||||
// It converts supported content blocks (image, audio, resource) to Kit's LLMFilePart.
|
||||
func extractPromptContent(blocks []acp.ContentBlock) (string, []kit.LLMFilePart) {
|
||||
var textParts []string
|
||||
var files []kit.LLMFilePart
|
||||
|
||||
log.Debug("acp: extracting content", "blocks", len(blocks))
|
||||
|
||||
for i, block := range blocks {
|
||||
switch {
|
||||
// Text content
|
||||
case block.Text != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "text", "len", len(block.Text.Text))
|
||||
textParts = append(textParts, block.Text.Text)
|
||||
|
||||
// Image data (base64)
|
||||
case block.Image != nil:
|
||||
mimeType := block.Image.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "image/png" // Default fallback
|
||||
}
|
||||
text += block.Text.Text
|
||||
log.Debug("acp: content block", "index", i, "type", "image", "mime", mimeType, "data_len", len(block.Image.Data))
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Image.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "image.png",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode image", "error", err)
|
||||
}
|
||||
|
||||
// Audio data (base64)
|
||||
case block.Audio != nil:
|
||||
mimeType := block.Audio.MimeType
|
||||
if mimeType == "" {
|
||||
mimeType = "audio/wav" // Default fallback
|
||||
}
|
||||
log.Debug("acp: content block", "index", i, "type", "audio", "mime", mimeType)
|
||||
if data, err := base64.StdEncoding.DecodeString(block.Audio.Data); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: "audio.wav",
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode audio", "error", err)
|
||||
}
|
||||
|
||||
// Embedded resource (text or binary file content)
|
||||
case block.Resource != nil:
|
||||
log.Debug("acp: content block", "index", i, "type", "resource")
|
||||
res := block.Resource.Resource
|
||||
// Text resource - append as text content with file reference
|
||||
if res.TextResourceContents != nil {
|
||||
uri := res.TextResourceContents.Uri
|
||||
content := res.TextResourceContents.Text
|
||||
mimeType := "text/plain"
|
||||
if res.TextResourceContents.MimeType != nil {
|
||||
mimeType = *res.TextResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: text resource", "uri", uri, "mime", mimeType, "len", len(content))
|
||||
// Text files are included as formatted text, NOT as FilePart
|
||||
// FilePart is for binary files (images, audio, PDFs) only
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, content))
|
||||
}
|
||||
// Binary resource (base64 blob) - these become FilePart
|
||||
if res.BlobResourceContents != nil {
|
||||
uri := res.BlobResourceContents.Uri
|
||||
mimeType := "application/octet-stream"
|
||||
if res.BlobResourceContents.MimeType != nil {
|
||||
mimeType = *res.BlobResourceContents.MimeType
|
||||
}
|
||||
log.Debug("acp: binary resource", "uri", uri, "mime", mimeType, "blob_len", len(res.BlobResourceContents.Blob))
|
||||
if data, err := base64.StdEncoding.DecodeString(res.BlobResourceContents.Blob); err == nil {
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: data,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
} else {
|
||||
log.Debug("acp: failed to decode binary resource", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Resource link (file reference without embedded content)
|
||||
case block.ResourceLink != nil:
|
||||
uri := block.ResourceLink.Uri
|
||||
name := block.ResourceLink.Name
|
||||
log.Debug("acp: content block", "index", i, "type", "resource_link", "uri", uri, "name", name)
|
||||
// For resource links, we'll try to read the file from disk
|
||||
// This requires the file URI to be accessible (file:// scheme)
|
||||
if content, err := readResourceFromURI(uri); err == nil {
|
||||
// Detect if it's a text file or binary file
|
||||
mimeType := "text/plain"
|
||||
if block.ResourceLink.MimeType != nil {
|
||||
mimeType = *block.ResourceLink.MimeType
|
||||
}
|
||||
log.Debug("acp: resource link loaded", "uri", uri, "mime", mimeType, "size", len(content))
|
||||
|
||||
// Only create FilePart for binary files (images, audio, PDFs, etc.)
|
||||
// Text files are included as formatted text in the message
|
||||
if isTextMimeType(mimeType) || looksLikeText(content) {
|
||||
textParts = append(textParts, fmt.Sprintf("[File: %s]\n```\n%s\n```", uri, string(content)))
|
||||
} else {
|
||||
// Binary file - create FilePart for models that support it
|
||||
files = append(files, kit.LLMFilePart{
|
||||
Filename: extractFilenameFromURI(uri),
|
||||
Data: content,
|
||||
MediaType: mimeType,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If we can't read it, include as a text reference
|
||||
log.Debug("acp: resource link failed to load", "uri", uri, "error", err)
|
||||
textParts = append(textParts, fmt.Sprintf("[Referenced file: %s]", uri))
|
||||
}
|
||||
|
||||
default:
|
||||
log.Debug("acp: content block", "index", i, "type", "unknown/unhandled")
|
||||
}
|
||||
}
|
||||
return text
|
||||
|
||||
// Debug log the extracted content
|
||||
for i, f := range files {
|
||||
log.Debug("acp: extracted file", "index", i, "filename", f.Filename, "mime", f.MediaType, "size", len(f.Data))
|
||||
}
|
||||
|
||||
return strings.Join(textParts, "\n"), files
|
||||
}
|
||||
|
||||
// isTextMimeType returns true if the MIME type indicates text content.
|
||||
func isTextMimeType(mimeType string) bool {
|
||||
return strings.HasPrefix(mimeType, "text/") ||
|
||||
mimeType == "application/json" ||
|
||||
mimeType == "application/xml" ||
|
||||
mimeType == "application/javascript" ||
|
||||
mimeType == "application/typescript" ||
|
||||
mimeType == "application/x-sh" ||
|
||||
mimeType == "application/x-python" ||
|
||||
mimeType == "application/x-yaml" ||
|
||||
mimeType == "application/x-toml"
|
||||
}
|
||||
|
||||
// looksLikeText checks if the content appears to be text (not binary).
|
||||
// It samples the first 512 bytes and checks for null bytes or high
|
||||
// concentration of non-printable characters.
|
||||
func looksLikeText(data []byte) bool {
|
||||
if len(data) == 0 {
|
||||
return true
|
||||
}
|
||||
// Check first 512 bytes (or less if file is smaller)
|
||||
sampleSize := min(len(data), 512)
|
||||
sample := data[:sampleSize]
|
||||
|
||||
// Count non-printable characters
|
||||
nonPrintable := 0
|
||||
for _, b := range sample {
|
||||
// Null byte indicates binary
|
||||
if b == 0 {
|
||||
return false
|
||||
}
|
||||
// Count control characters (except common whitespace)
|
||||
if b < 32 && b != '\n' && b != '\r' && b != '\t' {
|
||||
nonPrintable++
|
||||
}
|
||||
}
|
||||
|
||||
// If more than 30% non-printable, consider it binary
|
||||
return float64(nonPrintable)/float64(sampleSize) < 0.3
|
||||
}
|
||||
|
||||
// extractFilenameFromURI extracts a filename from a file URI or path.
|
||||
func extractFilenameFromURI(uri string) string {
|
||||
// Handle file:// URIs
|
||||
uri = strings.TrimPrefix(uri, "file://")
|
||||
// Extract basename
|
||||
if idx := strings.LastIndex(uri, "/"); idx >= 0 {
|
||||
return uri[idx+1:]
|
||||
}
|
||||
return uri
|
||||
}
|
||||
|
||||
// readResourceFromURI attempts to read file content from a file:// URI.
|
||||
func readResourceFromURI(uri string) ([]byte, error) {
|
||||
if !strings.HasPrefix(uri, "file://") {
|
||||
return nil, fmt.Errorf("unsupported URI scheme: %s", uri)
|
||||
}
|
||||
path := uri[7:] // Remove file:// prefix
|
||||
return os.ReadFile(path)
|
||||
}
|
||||
|
||||
// parseToolArgs attempts to parse a JSON tool args string into a map for
|
||||
|
||||
@@ -62,8 +62,8 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
// work in ACP mode. TUI-dependent features (widgets, prompts, editor)
|
||||
// become no-ops or return cancelled; all data/model/tool APIs work
|
||||
// identically to interactive mode.
|
||||
if kitInstance.HasExtensions() {
|
||||
kitInstance.SetExtensionContext(extensions.Context{
|
||||
if kitInstance.Extensions().HasExtensions() {
|
||||
kitInstance.Extensions().SetContext(extensions.Context{
|
||||
SessionID: sessionID,
|
||||
CWD: cwd,
|
||||
Model: kitInstance.GetModelString(),
|
||||
@@ -121,31 +121,31 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
MessageCount: s.MessageCount,
|
||||
}
|
||||
},
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionFilePath() },
|
||||
GetMessages: func() []extensions.SessionMessage { return kitInstance.Extensions().GetSessionMessages() },
|
||||
GetSessionPath: func() string { return kitInstance.GetSessionPath() },
|
||||
AppendEntry: func(entryType, data string) (string, error) {
|
||||
return kitInstance.AppendExtensionEntry(entryType, data)
|
||||
return kitInstance.Extensions().AppendEntry(entryType, data)
|
||||
},
|
||||
GetEntries: func(entryType string) []extensions.ExtensionEntry {
|
||||
return kitInstance.GetExtensionEntries(entryType)
|
||||
return kitInstance.Extensions().GetEntries(entryType)
|
||||
},
|
||||
|
||||
// Options, model, and tool management.
|
||||
GetOption: func(name string) string { return kitInstance.GetExtensionOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.SetExtensionOption(name, value) },
|
||||
GetOption: func(name string) string { return kitInstance.Extensions().GetOption(name) },
|
||||
SetOption: func(name, value string) { kitInstance.Extensions().SetOption(name, value) },
|
||||
SetModel: func(modelString string) error {
|
||||
previousModel := kitInstance.GetExtensionContext().Model
|
||||
previousModel := kitInstance.Extensions().GetContext().Model
|
||||
if err := kitInstance.SetModel(context.Background(), modelString); err != nil {
|
||||
return err
|
||||
}
|
||||
kitInstance.UpdateExtensionContextModel(modelString)
|
||||
kitInstance.EmitModelChange(modelString, previousModel, "extension")
|
||||
kitInstance.Extensions().UpdateContextModel(modelString)
|
||||
kitInstance.Extensions().EmitModelChange(modelString, previousModel, "extension")
|
||||
return nil
|
||||
},
|
||||
GetAvailableModels: func() []extensions.ModelInfoEntry { return kitInstance.GetAvailableModels() },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.EmitExtensionCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.GetExtensionToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.SetExtensionActiveTools(names) },
|
||||
EmitCustomEvent: func(name, data string) { kitInstance.Extensions().EmitCustomEvent(name, data) },
|
||||
GetAllTools: func() []extensions.ToolInfo { return kitInstance.Extensions().GetToolInfos() },
|
||||
SetActiveTools: func(names []string) { kitInstance.Extensions().SetActiveTools(names) },
|
||||
|
||||
// LLM completions and subagents.
|
||||
Complete: func(req extensions.CompleteRequest) (extensions.CompleteResponse, error) {
|
||||
@@ -173,7 +173,7 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
}
|
||||
extResult := &extensions.SubagentResult{
|
||||
Response: result.Response,
|
||||
Error: result.Error,
|
||||
Error: err,
|
||||
SessionID: result.SessionID,
|
||||
Elapsed: result.Elapsed,
|
||||
}
|
||||
@@ -188,15 +188,15 @@ func (r *sessionRegistry) create(ctx context.Context, cwd string) (*acpSession,
|
||||
|
||||
// Render — fall back to logging.
|
||||
RenderMessage: func(name, content string) {
|
||||
renderer := kitInstance.GetExtensionMessageRenderer(name)
|
||||
renderer := kitInstance.Extensions().GetMessageRenderer(name)
|
||||
if renderer != nil && renderer.Render != nil {
|
||||
content = renderer.Render(content, 80)
|
||||
}
|
||||
log.Info("extension: message", "renderer", name, "content", content)
|
||||
},
|
||||
ReloadExtensions: func() error { return kitInstance.ReloadExtensions() },
|
||||
ReloadExtensions: func() error { return kitInstance.Extensions().Reload() },
|
||||
})
|
||||
kitInstance.EmitSessionStart()
|
||||
kitInstance.Extensions().EmitSessionStart()
|
||||
}
|
||||
|
||||
sess := &acpSession{
|
||||
|
||||
+457
-145
@@ -25,19 +25,39 @@ type AgentConfig struct {
|
||||
StreamingEnabled bool
|
||||
DebugLogger tools.DebugLogger
|
||||
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
|
||||
// ToolWrapper is an optional function that wraps the combined tool list
|
||||
// before it is passed to the Fantasy agent. Used by the extensions system
|
||||
// before it is passed to the LLM agent. Used by the extensions system
|
||||
// to intercept tool calls/results.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ExtraTools are additional tools to include alongside core and MCP tools.
|
||||
// Used by extensions to register custom tools.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). The callback receives the server
|
||||
// name, tool count, and any error. Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// ToolCallHandler is a function type for handling tool calls as they happen.
|
||||
@@ -63,6 +83,10 @@ type ToolCallContentHandler func(content string)
|
||||
// ReasoningDeltaHandler is a function type for handling streaming reasoning/thinking deltas.
|
||||
type ReasoningDeltaHandler func(delta string)
|
||||
|
||||
// ReasoningCompleteHandler is a function type for handling reasoning/thinking completion.
|
||||
// Called when the last reasoning token has been processed, before text streaming starts.
|
||||
type ReasoningCompleteHandler func()
|
||||
|
||||
// ToolOutputHandler is a function type for handling streaming tool output chunks.
|
||||
// Used by tools like bash to stream output as it arrives rather than waiting
|
||||
// for the command to complete. The isStderr flag indicates if the chunk
|
||||
@@ -70,15 +94,33 @@ type ReasoningDeltaHandler func(delta string)
|
||||
// Note: This is an alias for core.ToolOutputCallback to avoid import cycles.
|
||||
type ToolOutputHandler = core.ToolOutputCallback
|
||||
|
||||
// PasswordPromptHandler is a function type for password prompts.
|
||||
// Used by the bash tool when sudo requires a password. The handler receives
|
||||
// a prompt message and returns the password and whether it was cancelled.
|
||||
// Note: This is an alias for core.PasswordPromptCallback.
|
||||
type PasswordPromptHandler = core.PasswordPromptCallback
|
||||
|
||||
// StepMessagesHandler is a function type for persisting messages after each
|
||||
// complete step in a multi-step agent turn. The handler receives the messages
|
||||
// produced by the step (typically an assistant message with tool calls followed
|
||||
// by a tool-role message with results, or a final assistant message with text).
|
||||
// This enables incremental session persistence so that progress is saved as
|
||||
// it happens rather than only at the end of the turn.
|
||||
type StepMessagesHandler func(stepMessages []fantasy.Message)
|
||||
|
||||
// StepUsageHandler is a function type for handling token usage after each
|
||||
// complete step in a multi-step agent turn. This enables real-time cost
|
||||
// tracking during long-running tool-calling conversations.
|
||||
type StepUsageHandler func(inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64)
|
||||
|
||||
// Agent represents an AI agent with core tool integration using the fantasy library.
|
||||
// Agent represents an AI agent with core tool integration using the LLM library.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are registered as direct
|
||||
// fantasy.AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// AgentTool implementations — no MCP layer, no serialization overhead.
|
||||
// Additional tools from external MCP servers can be loaded alongside core tools.
|
||||
//
|
||||
// When MCP servers are configured, tool loading happens in the background so the
|
||||
// agent (and UI) can start immediately. The first LLM call automatically waits
|
||||
// for MCP tools to finish loading before proceeding.
|
||||
type Agent struct {
|
||||
toolManager *tools.MCPToolManager
|
||||
fantasyAgent fantasy.Agent
|
||||
@@ -92,6 +134,24 @@ type Agent struct {
|
||||
coreTools []fantasy.AgentTool
|
||||
extraTools []fantasy.AgentTool
|
||||
toolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool // stored for SetModel rebuild
|
||||
|
||||
// providerOptions and modelConfig are stored for rebuilding the fantasy
|
||||
// agent when MCP tools arrive asynchronously or on SetModel.
|
||||
providerOptions fantasy.ProviderOptions
|
||||
skipMaxOutputTokens bool
|
||||
modelConfig *models.ProviderConfig
|
||||
|
||||
// authHandler and tokenStoreFactory are stored from AgentConfig so that
|
||||
// AddMCPServer() can propagate them when creating a new MCPToolManager
|
||||
// at runtime (i.e. when no MCP servers were configured at init time).
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
// mcpErr holds any error from background MCP loading.
|
||||
mcpErr error
|
||||
}
|
||||
|
||||
// GenerateWithLoopResult contains the result and conversation history from an agent interaction.
|
||||
@@ -100,54 +160,50 @@ type GenerateWithLoopResult struct {
|
||||
FinalResponse *fantasy.Response
|
||||
// ConversationMessages contains all messages in the conversation including tool calls and results
|
||||
ConversationMessages []fantasy.Message
|
||||
// Messages contains the conversation as custom content blocks (crush-style)
|
||||
// Messages contains the conversation as custom content blocks
|
||||
Messages []message.Message
|
||||
// TotalUsage contains aggregate token usage across all steps
|
||||
TotalUsage fantasy.Usage
|
||||
// StopReason is the LLM provider's finish reason for the final response.
|
||||
StopReason string
|
||||
// PersistedMessageCount is the number of new messages (beyond the original
|
||||
// input) that were already persisted incrementally via OnStepMessages during
|
||||
// generation. The caller should skip these when doing post-generation
|
||||
// persistence to avoid duplicates.
|
||||
PersistedMessageCount int
|
||||
}
|
||||
|
||||
// NewAgent creates a new Agent with core tools and optional MCP tool integration.
|
||||
// Core tools (bash, read, write, edit, grep, find, ls) are always registered.
|
||||
// External MCP tools are loaded from the config if any MCP servers are configured.
|
||||
// If MCP servers are configured, their tools are loaded in the background —
|
||||
// the agent returns immediately and is usable with core tools only. The first
|
||||
// LLM call (GenerateWithLoop) automatically waits for MCP tools to finish
|
||||
// loading and rebuilds the agent with the full tool set.
|
||||
func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
// Create the LLM provider via fantasy
|
||||
// Create the LLM provider
|
||||
providerResult, err := models.CreateProvider(ctx, agentConfig.ModelConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Register core tools (direct fantasy implementations, no MCP overhead).
|
||||
// Register core tools (direct AgentTool implementations, no MCP overhead).
|
||||
// Use caller-provided tools if set, otherwise default to all core tools.
|
||||
coreTools := agentConfig.CoreTools
|
||||
if len(coreTools) == 0 {
|
||||
// DisableCoreTools allows explicitly having zero tools (for chat-only mode).
|
||||
var coreTools []fantasy.AgentTool
|
||||
if agentConfig.DisableCoreTools && len(agentConfig.CoreTools) == 0 {
|
||||
// Explicitly zero tools - chat-only mode
|
||||
coreTools = nil
|
||||
} else if len(agentConfig.CoreTools) > 0 {
|
||||
// Custom tools provided - use them
|
||||
coreTools = agentConfig.CoreTools
|
||||
} else {
|
||||
// Default: load all core tools
|
||||
coreTools = core.AllTools()
|
||||
}
|
||||
|
||||
// Build the combined tool list: core tools + any external MCP tools
|
||||
// Build the initial tool list: core tools + extension tools (no MCP yet).
|
||||
allTools := make([]fantasy.AgentTool, len(coreTools))
|
||||
copy(allTools, coreTools)
|
||||
|
||||
// Load external MCP tools if configured
|
||||
var toolManager *tools.MCPToolManager
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager = tools.NewMCPToolManager()
|
||||
toolManager.SetModel(providerResult.Model)
|
||||
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
// MCP tool loading failures are non-fatal; core tools still work
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
} else {
|
||||
mcpTools := toolManager.GetTools()
|
||||
allTools = append(allTools, mcpTools...)
|
||||
}
|
||||
}
|
||||
|
||||
// Append any extra tools provided by extensions.
|
||||
if len(agentConfig.ExtraTools) > 0 {
|
||||
allTools = append(allTools, agentConfig.ExtraTools...)
|
||||
@@ -158,7 +214,149 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
allTools = agentConfig.ToolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Build fantasy agent options
|
||||
// Build agent options
|
||||
agentOpts := buildAgentOptions(agentConfig, providerResult, allTools)
|
||||
|
||||
// Create the agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
}
|
||||
}
|
||||
|
||||
a := &Agent{
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
providerOptions: providerResult.ProviderOptions,
|
||||
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
// The mcpReady channel is closed when loading completes (success or failure).
|
||||
if agentConfig.MCPConfig != nil && len(agentConfig.MCPConfig.MCPServers) > 0 {
|
||||
toolManager := tools.NewMCPToolManager()
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
if agentConfig.TokenStoreFactory != nil {
|
||||
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
|
||||
}
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
// Set per-server loaded callback if provided.
|
||||
if agentConfig.OnMCPServerLoaded != nil {
|
||||
toolManager.SetOnServerLoaded(agentConfig.OnMCPServerLoaded)
|
||||
}
|
||||
a.toolManager = toolManager
|
||||
a.mcpReady = make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(a.mcpReady)
|
||||
if err := toolManager.LoadTools(ctx, agentConfig.MCPConfig); err != nil {
|
||||
a.mcpErr = err
|
||||
fmt.Printf("Warning: Failed to load MCP tools: %v\n", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// WaitForMCPTools blocks until background MCP tool loading completes.
|
||||
// Returns nil if no MCP servers are configured or if loading succeeded.
|
||||
// Returns the loading error if all servers failed. Safe to call multiple times.
|
||||
func (a *Agent) WaitForMCPTools() error {
|
||||
if a.mcpReady == nil {
|
||||
return nil
|
||||
}
|
||||
<-a.mcpReady
|
||||
return a.mcpErr
|
||||
}
|
||||
|
||||
// MCPToolsReady returns true if MCP tool loading has completed (or was never
|
||||
// started). This is a non-blocking check useful for UI status display.
|
||||
func (a *Agent) MCPToolsReady() bool {
|
||||
if a.mcpReady == nil {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-a.mcpReady:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ensureMCPTools waits for MCP tools to load and rebuilds the fantasy agent
|
||||
// with the full tool set. Called lazily before the first LLM call.
|
||||
// This is idempotent — subsequent calls after the first rebuild are no-ops.
|
||||
func (a *Agent) ensureMCPTools() {
|
||||
if a.mcpReady == nil {
|
||||
return
|
||||
}
|
||||
<-a.mcpReady
|
||||
|
||||
// If there are MCP tools, rebuild the fantasy agent to include them.
|
||||
if a.toolManager != nil && len(a.toolManager.GetTools()) > 0 {
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// Nil out the channel so future calls are instant no-ops and we
|
||||
// don't rebuild again.
|
||||
a.mcpReady = nil
|
||||
}
|
||||
|
||||
// rebuildFantasyAgent reconstructs the fantasy agent with the current full
|
||||
// tool set (core + MCP + extension tools). Used after MCP tools arrive
|
||||
// asynchronously and by SetModel.
|
||||
func (a *Agent) rebuildFantasyAgent() {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
providerResult := &models.ProviderResult{
|
||||
Model: a.model,
|
||||
ProviderOptions: a.providerOptions,
|
||||
SkipMaxOutputTokens: a.skipMaxOutputTokens,
|
||||
}
|
||||
agentOpts := buildAgentOptions(&AgentConfig{
|
||||
ModelConfig: a.modelConfig,
|
||||
SystemPrompt: a.systemPrompt,
|
||||
MaxSteps: a.maxSteps,
|
||||
}, providerResult, allTools)
|
||||
|
||||
a.fantasyAgent = fantasy.NewAgent(a.model, agentOpts...)
|
||||
}
|
||||
|
||||
// buildAgentOptions constructs the fantasy.AgentOption slice from config,
|
||||
// provider result, and the combined tool list. Shared by NewAgent,
|
||||
// rebuildFantasyAgent, and SetModel.
|
||||
func buildAgentOptions(agentConfig *AgentConfig, providerResult *models.ProviderResult, allTools []fantasy.AgentTool) []fantasy.AgentOption {
|
||||
var agentOpts []fantasy.AgentOption
|
||||
|
||||
if agentConfig.SystemPrompt != "" {
|
||||
@@ -196,33 +394,15 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.ModelConfig.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*agentConfig.ModelConfig.TopK)))
|
||||
}
|
||||
}
|
||||
|
||||
// Create the fantasy agent
|
||||
fantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Determine provider type from model string
|
||||
providerType := "default"
|
||||
if agentConfig.ModelConfig != nil && agentConfig.ModelConfig.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(agentConfig.ModelConfig.ModelString); err == nil {
|
||||
providerType = p
|
||||
if agentConfig.ModelConfig.FrequencyPenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithFrequencyPenalty(float64(*agentConfig.ModelConfig.FrequencyPenalty)))
|
||||
}
|
||||
if agentConfig.ModelConfig.PresencePenalty != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithPresencePenalty(float64(*agentConfig.ModelConfig.PresencePenalty)))
|
||||
}
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
toolManager: toolManager,
|
||||
fantasyAgent: fantasyAgent,
|
||||
model: providerResult.Model,
|
||||
providerCloser: providerResult.Closer,
|
||||
maxSteps: agentConfig.MaxSteps,
|
||||
systemPrompt: agentConfig.SystemPrompt,
|
||||
loadingMessage: providerResult.Message,
|
||||
providerType: providerType,
|
||||
streamingEnabled: agentConfig.StreamingEnabled,
|
||||
coreTools: coreTools,
|
||||
extraTools: agentConfig.ExtraTools,
|
||||
toolWrapper: agentConfig.ToolWrapper,
|
||||
}, nil
|
||||
return agentOpts
|
||||
}
|
||||
|
||||
// GenerateWithLoop processes messages with a custom loop that displays tool calls in real-time.
|
||||
@@ -231,38 +411,54 @@ func (a *Agent) GenerateWithLoop(ctx context.Context, messages []fantasy.Message
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
return a.GenerateWithLoopAndStreaming(ctx, messages, onToolCall, onToolExecution, onToolResult,
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil)
|
||||
onResponse, onToolCallContent, nil, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// GenerateWithLoopAndStreaming processes messages using the fantasy agent with streaming and callbacks.
|
||||
// Fantasy handles the tool call loop internally. We map fantasy's rich callback system
|
||||
// GenerateWithLoopAndStreaming processes messages using the agent with streaming and callbacks.
|
||||
// The agent handles the tool call loop internally. We map the rich callback system
|
||||
// to kit's existing callback interface for UI integration.
|
||||
func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fantasy.Message,
|
||||
onToolCall ToolCallHandler, onToolExecution ToolExecutionHandler, onToolResult ToolResultHandler,
|
||||
onResponse ResponseHandler, onToolCallContent ToolCallContentHandler,
|
||||
onStreamingResponse StreamingResponseHandler,
|
||||
onReasoningDelta ReasoningDeltaHandler,
|
||||
onReasoningComplete ReasoningCompleteHandler,
|
||||
onToolOutput ToolOutputHandler,
|
||||
onStepMessages StepMessagesHandler,
|
||||
onStepUsage StepUsageHandler,
|
||||
onPasswordPrompt PasswordPromptHandler,
|
||||
) (*GenerateWithLoopResult, error) {
|
||||
|
||||
// Wait for background MCP tool loading to complete and rebuild the
|
||||
// fantasy agent with the full tool set. This is a no-op when no MCP
|
||||
// servers are configured or tools have already been integrated.
|
||||
a.ensureMCPTools()
|
||||
|
||||
// Inject tool output handler into context for use by core tools (e.g., bash).
|
||||
if onToolOutput != nil {
|
||||
ctx = core.ContextWithToolOutputCallback(ctx, onToolOutput)
|
||||
}
|
||||
|
||||
// Fantasy requires the current user input as Prompt, with prior messages as history.
|
||||
// Inject password prompt handler into context for use by bash tool.
|
||||
if onPasswordPrompt != nil {
|
||||
ctx = core.ContextWithPasswordPrompt(ctx, onPasswordPrompt)
|
||||
}
|
||||
|
||||
// The agent requires the current user input as Prompt, with prior messages as history.
|
||||
// Extract the last user message text and files as the prompt, and pass everything
|
||||
// before it as Messages. Files (e.g. clipboard images) are passed via the Files
|
||||
// field so Fantasy includes them in the API request.
|
||||
// field so the agent includes them in the API request.
|
||||
prompt, files, history := splitPromptAndHistory(messages)
|
||||
|
||||
// Track current tool call info for callbacks
|
||||
var currentToolName string
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
history = applyCacheControlToMessages(history)
|
||||
|
||||
// Track current tool call args for callbacks
|
||||
var currentToolArgs string
|
||||
|
||||
// Use the streaming path when streaming is enabled OR when any callbacks are
|
||||
// provided. Fantasy only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// provided. The agent only exposes tool/step callbacks on AgentStreamCall, so
|
||||
// Stream is required to observe tool execution in real time. The non-streaming
|
||||
// Generate path is reserved for the simple case with no callbacks at all.
|
||||
hasCallbacks := onToolCall != nil || onToolExecution != nil || onToolResult != nil ||
|
||||
@@ -270,12 +466,16 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
|
||||
if a.streamingEnabled || hasCallbacks {
|
||||
// Track completed step messages so we can return partial results
|
||||
// on cancellation. Fantasy's Stream() discards accumulated steps
|
||||
// on cancellation. The agent's Stream() discards accumulated steps
|
||||
// when it returns an error, but the OnStepFinish callback fires
|
||||
// for every step that completed before the error occurred.
|
||||
var completedStepMessages []fantasy.Message
|
||||
// persistedCount tracks how many new messages (beyond the original
|
||||
// input) were persisted incrementally via onStepMessages, so the
|
||||
// caller can skip them during post-generation persistence.
|
||||
var persistedCount int
|
||||
|
||||
// Use fantasy's streaming agent
|
||||
// Use the streaming agent
|
||||
streamCall := fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
Files: files,
|
||||
@@ -292,6 +492,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil
|
||||
},
|
||||
|
||||
// Reasoning/thinking complete callback
|
||||
OnReasoningEnd: func(id string, _ fantasy.ReasoningContent) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if onReasoningComplete != nil {
|
||||
onReasoningComplete()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
|
||||
// Text streaming callback
|
||||
OnTextDelta: func(id, text string) error {
|
||||
if ctx.Err() != nil {
|
||||
@@ -308,7 +519,6 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
currentToolName = tc.ToolName
|
||||
currentToolArgs = tc.Input
|
||||
|
||||
// Notify about the tool call
|
||||
@@ -349,6 +559,13 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
// persisted even if a later step is cancelled.
|
||||
completedStepMessages = append(completedStepMessages, step.Messages...)
|
||||
|
||||
// Persist step messages incrementally so progress is saved
|
||||
// as it happens rather than only at the end of the turn.
|
||||
if onStepMessages != nil && len(step.Messages) > 0 {
|
||||
onStepMessages(step.Messages)
|
||||
persistedCount += len(step.Messages)
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -379,7 +596,7 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
opts fantasy.PrepareStepFunctionOptions,
|
||||
) (context.Context, fantasy.PrepareStepResult, error) {
|
||||
// Drain all pending steer messages (non-blocking).
|
||||
var steered []string
|
||||
var steered []SteerMessage
|
||||
for {
|
||||
select {
|
||||
case msg := <-steerCh:
|
||||
@@ -396,15 +613,20 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
if len(steered) > 0 {
|
||||
// Inject each steer message as a user message so the
|
||||
// LLM sees the redirection on the next step.
|
||||
for _, text := range steered {
|
||||
for _, sm := range steered {
|
||||
result.Messages = append(result.Messages,
|
||||
fantasy.NewUserMessage(text))
|
||||
fantasy.NewUserMessage(sm.Text, sm.Files...))
|
||||
}
|
||||
// Notify that steer messages were consumed.
|
||||
if onConsumed != nil {
|
||||
onConsumed(len(steered))
|
||||
}
|
||||
}
|
||||
|
||||
// Apply message-level cache control for Anthropic models.
|
||||
// This avoids type conflicts with provider-level options.
|
||||
result.Messages = applyCacheControlToMessages(result.Messages)
|
||||
|
||||
return stepCtx, result, nil
|
||||
}
|
||||
}
|
||||
@@ -422,19 +644,25 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
partialMessages = append(partialMessages, messages...)
|
||||
partialMessages = append(partialMessages, completedStepMessages...)
|
||||
return &GenerateWithLoopResult{
|
||||
ConversationMessages: partialMessages,
|
||||
ConversationMessages: partialMessages,
|
||||
PersistedMessageCount: persistedCount,
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fire the response callback for callers that use it (e.g. non-streaming
|
||||
// callers that still want the final response notification).
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// Fire the response callback so callers (e.g. the TUI) can reset
|
||||
// streaming state. This must fire even when the response text is
|
||||
// empty (e.g. reasoning-only responses) so the UI properly resets
|
||||
// the stream component and avoids duplicate content on the next
|
||||
// flush.
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
r := convertAgentResult(result, messages)
|
||||
r.PersistedMessageCount = persistedCount
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// Non-streaming path with no callbacks — use the simpler Generate call.
|
||||
@@ -447,18 +675,17 @@ func (a *Agent) GenerateWithLoopAndStreaming(ctx context.Context, messages []fan
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For non-streaming, fire the response callback with the final text
|
||||
if onResponse != nil && result.Response.Content.Text() != "" {
|
||||
// For non-streaming, fire the response callback so callers can reset
|
||||
// streaming state (see streaming path comment above).
|
||||
if onResponse != nil {
|
||||
onResponse(result.Response.Content.Text())
|
||||
}
|
||||
|
||||
_ = currentToolName // satisfy compiler for non-streaming path
|
||||
|
||||
return convertAgentResult(result, messages), nil
|
||||
}
|
||||
|
||||
// splitPromptAndHistory extracts the last user message as the prompt string,
|
||||
// and returns everything before it as conversation history. Fantasy's agent
|
||||
// and returns everything before it as conversation history. The agent's
|
||||
// requires the current turn's input as Prompt (string), with prior messages
|
||||
// passed separately as Messages (history).
|
||||
func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePart, []fantasy.Message) {
|
||||
@@ -501,8 +728,8 @@ func splitPromptAndHistory(messages []fantasy.Message) (string, []fantasy.FilePa
|
||||
return "", nil, messages
|
||||
}
|
||||
|
||||
// convertAgentResult converts a fantasy AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the legacy fantasy.Message slice and the new custom content blocks.
|
||||
// convertAgentResult converts an AgentResult to our GenerateWithLoopResult.
|
||||
// It builds both the message slice and the new custom content blocks.
|
||||
func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.Message) *GenerateWithLoopResult {
|
||||
// Collect all conversation messages: original + all step messages
|
||||
var allFantasyMessages []fantasy.Message
|
||||
@@ -515,7 +742,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
// Convert to custom content blocks
|
||||
var allMessages []message.Message
|
||||
for _, fm := range allFantasyMessages {
|
||||
allMessages = append(allMessages, message.FromFantasyMessage(fm))
|
||||
allMessages = append(allMessages, message.FromLLMMessage(fm))
|
||||
}
|
||||
|
||||
return &GenerateWithLoopResult{
|
||||
@@ -527,7 +754,7 @@ func convertAgentResult(result *fantasy.AgentResult, originalMessages []fantasy.
|
||||
}
|
||||
}
|
||||
|
||||
// extractToolResultText extracts the text and error status from a fantasy ToolResultContent.
|
||||
// extractToolResultText extracts the text and error status from a ToolResultContent.
|
||||
// For core tools, the result is already clean text (no MCP JSON wrapping).
|
||||
// For MCP tools, it unwraps the MCP content structure.
|
||||
func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
@@ -540,7 +767,7 @@ func extractToolResultText(tr fantasy.ToolResultContent) (string, bool) {
|
||||
return errResult.Error.Error(), true
|
||||
}
|
||||
|
||||
// Get text directly from the Fantasy result type.
|
||||
// Get text directly from the result type.
|
||||
if textResult, ok := tr.Result.(fantasy.ToolResultOutputContentText); ok {
|
||||
// Try to unwrap MCP JSON structure (for external MCP tools).
|
||||
// Core tools return plain text, so this is a no-op for them.
|
||||
@@ -592,7 +819,7 @@ func (a *Agent) GetTools() []fantasy.AgentTool {
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
allTools = append(allTools, mcpToolsToAgentTools(a.toolManager.GetTools(), a.toolManager)...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
@@ -618,6 +845,72 @@ func (a *Agent) GetExtensionToolCount() int {
|
||||
return len(a.extraTools)
|
||||
}
|
||||
|
||||
// SetExtraTools replaces the agent's extra tools (e.g. extension-registered
|
||||
// tools) and rebuilds the internal agent with the updated tool list. The
|
||||
// model, system prompt, and all other configuration are preserved.
|
||||
func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
|
||||
a.extraTools = extraTools
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent. Returns the number of tools loaded.
|
||||
// If the agent has no tool manager (no MCP servers were configured at init),
|
||||
// one is created automatically.
|
||||
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
if a.authHandler != nil {
|
||||
a.toolManager.SetAuthHandler(a.authHandler)
|
||||
}
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
}
|
||||
|
||||
count, err := a.toolManager.AddServer(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but only if it was wired. Ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
|
||||
func (a *Agent) RemoveMCPServer(name string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
err := a.toolManager.RemoveServer(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPToolManager returns the underlying MCP tool manager.
|
||||
// Returns nil if no MCP servers have been configured.
|
||||
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
|
||||
return a.toolManager
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
@@ -631,78 +924,88 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
|
||||
// system prompt, and configuration are preserved. The old provider is closed
|
||||
// if it has a closer. Returns the previous model string for notification.
|
||||
// GetMCPPrompts returns all prompts discovered from connected MCP servers.
|
||||
// Returns nil if no MCP servers are configured or no prompts were found.
|
||||
func (a *Agent) GetMCPPrompts() []tools.MCPPrompt {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetPrompts()
|
||||
}
|
||||
|
||||
// GetMCPPrompt retrieves and expands a specific prompt from an MCP server.
|
||||
// This is a lazy call — the server is contacted each time.
|
||||
func (a *Agent) GetMCPPrompt(ctx context.Context, serverName, promptName string, args map[string]string) (*tools.MCPPromptResult, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.GetPrompt(ctx, serverName, promptName, args)
|
||||
}
|
||||
|
||||
// GetMCPResources returns all resources discovered from connected MCP servers.
|
||||
func (a *Agent) GetMCPResources() []tools.MCPResource {
|
||||
if a.toolManager == nil {
|
||||
return nil
|
||||
}
|
||||
return a.toolManager.GetResources()
|
||||
}
|
||||
|
||||
// ReadMCPResource reads a specific resource from an MCP server by URI.
|
||||
func (a *Agent) ReadMCPResource(ctx context.Context, serverName, uri string) (*tools.MCPResourceContent, error) {
|
||||
if a.toolManager == nil {
|
||||
return nil, fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.ReadResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SubscribeMCPResource subscribes to change notifications for a resource.
|
||||
func (a *Agent) SubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.SubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// UnsubscribeMCPResource cancels change notifications for a resource.
|
||||
func (a *Agent) UnsubscribeMCPResource(ctx context.Context, serverName, uri string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers configured")
|
||||
}
|
||||
return a.toolManager.UnsubscribeResource(ctx, serverName, uri)
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
|
||||
// it has a closer.
|
||||
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
|
||||
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
|
||||
// before the first LLM call).
|
||||
a.ensureMCPTools()
|
||||
|
||||
providerResult, err := models.CreateProvider(ctx, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create model provider: %v", err)
|
||||
}
|
||||
|
||||
// Rebuild tool list (same as NewAgent).
|
||||
allTools := make([]fantasy.AgentTool, len(a.coreTools))
|
||||
copy(allTools, a.coreTools)
|
||||
if a.toolManager != nil {
|
||||
allTools = append(allTools, a.toolManager.GetTools()...)
|
||||
}
|
||||
if len(a.extraTools) > 0 {
|
||||
allTools = append(allTools, a.extraTools...)
|
||||
}
|
||||
if a.toolWrapper != nil {
|
||||
allTools = a.toolWrapper(allTools)
|
||||
}
|
||||
|
||||
// Rebuild fantasy agent options.
|
||||
var agentOpts []fantasy.AgentOption
|
||||
if a.systemPrompt != "" {
|
||||
agentOpts = append(agentOpts, fantasy.WithSystemPrompt(a.systemPrompt))
|
||||
}
|
||||
if len(allTools) > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithTools(allTools...))
|
||||
}
|
||||
if a.maxSteps > 0 {
|
||||
agentOpts = append(agentOpts, fantasy.WithStopConditions(
|
||||
fantasy.StepCountIs(a.maxSteps),
|
||||
))
|
||||
}
|
||||
|
||||
// Pass provider-specific options (e.g. OpenAI Responses API reasoning settings).
|
||||
if providerResult.ProviderOptions != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithProviderOptions(providerResult.ProviderOptions))
|
||||
}
|
||||
|
||||
// Pass generation parameters when available.
|
||||
// Skip max_output_tokens for providers that don't support it (e.g., Codex OAuth)
|
||||
if config.MaxTokens > 0 && !providerResult.SkipMaxOutputTokens {
|
||||
agentOpts = append(agentOpts, fantasy.WithMaxOutputTokens(int64(config.MaxTokens)))
|
||||
}
|
||||
if config.Temperature != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTemperature(float64(*config.Temperature)))
|
||||
}
|
||||
if config.TopP != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopP(float64(*config.TopP)))
|
||||
}
|
||||
if config.TopK != nil {
|
||||
agentOpts = append(agentOpts, fantasy.WithTopK(int64(*config.TopK)))
|
||||
}
|
||||
|
||||
newFantasyAgent := fantasy.NewAgent(providerResult.Model, agentOpts...)
|
||||
|
||||
// Close old provider.
|
||||
if a.providerCloser != nil {
|
||||
_ = a.providerCloser.Close()
|
||||
}
|
||||
|
||||
// Update model info on MCP tool manager.
|
||||
if a.toolManager != nil {
|
||||
a.toolManager.SetModel(providerResult.Model)
|
||||
}
|
||||
|
||||
// Swap fields.
|
||||
a.fantasyAgent = newFantasyAgent
|
||||
a.model = providerResult.Model
|
||||
a.providerCloser = providerResult.Closer
|
||||
a.providerOptions = providerResult.ProviderOptions
|
||||
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
|
||||
a.modelConfig = config
|
||||
|
||||
// Update system prompt when the config carries one (from per-model
|
||||
// settings or the global config). This allows model-specific system
|
||||
// prompts to take effect on model switch.
|
||||
if config.SystemPrompt != "" {
|
||||
a.systemPrompt = config.SystemPrompt
|
||||
}
|
||||
|
||||
// Update provider type.
|
||||
if config.ModelString != "" {
|
||||
@@ -711,16 +1014,25 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild the fantasy agent with the new model and current tool set.
|
||||
a.rebuildFantasyAgent()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModel returns the underlying fantasy LanguageModel.
|
||||
// GetModel returns the underlying LanguageModel.
|
||||
func (a *Agent) GetModel() fantasy.LanguageModel {
|
||||
return a.model
|
||||
}
|
||||
|
||||
// Close closes the agent and cleans up resources.
|
||||
// If MCP tools are still loading in the background, Close waits for them
|
||||
// to finish before closing connections to avoid resource leaks.
|
||||
func (a *Agent) Close() error {
|
||||
// Wait for background MCP loading to finish before closing connections.
|
||||
if a.mcpReady != nil {
|
||||
<-a.mcpReady
|
||||
}
|
||||
var toolErr error
|
||||
if a.toolManager != nil {
|
||||
toolErr = a.toolManager.Close()
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// mockModel is a minimal LanguageModel that satisfies the interface
|
||||
// without making real API calls. Used to test tool management wiring.
|
||||
type mockModel struct{}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) Provider() string { return "mock" }
|
||||
func (m *mockModel) Model() string { return "mock-model" }
|
||||
|
||||
// testdataDir returns the absolute path to the tools testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// mockAuthHandler is a minimal MCPAuthHandler for testing that auth handler
|
||||
// propagation works without requiring a real OAuth server.
|
||||
type mockAuthHandler struct {
|
||||
redirectURI string
|
||||
}
|
||||
|
||||
func (h *mockAuthHandler) RedirectURI() string { return h.redirectURI }
|
||||
func (h *mockAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// newTestAgent creates a minimal Agent with a mock model and no core tools,
|
||||
// suitable for testing MCP server management without an API key.
|
||||
func newTestAgent() *Agent {
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Initially no MCP tools.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Add a server.
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
// Verify tools are in the agent's tool list.
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
allTools := a.GetTools()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range allTools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo' in agent tools")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet' in agent tools")
|
||||
}
|
||||
|
||||
// Verify loaded server names.
|
||||
names := a.GetLoadedServerNames()
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected 'echo' in loaded server names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add then remove.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools removed.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Verify agent's tool list has no MCP tools.
|
||||
for _, tool := range a.GetTools() {
|
||||
if strings.Contains(tool.Info().Name, "echo__") {
|
||||
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
err := a.RemoveMCPServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no tool manager exists")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no MCP servers loaded") {
|
||||
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist.
|
||||
if a.GetMCPToolManager() == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add → Remove → Add cycle.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First add failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-add failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgent_AddMCPServer_InheritsAuthHandler verifies that AddMCPServer()
|
||||
// propagates the agent's authHandler and tokenStoreFactory to a newly created
|
||||
// MCPToolManager (fix for issue #3).
|
||||
func TestAgent_AddMCPServer_InheritsAuthHandler(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := &mockAuthHandler{redirectURI: "http://localhost:9999/oauth/callback"}
|
||||
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
authHandler: handler,
|
||||
tokenStoreFactory: nil, // nil is fine; we just test authHandler propagation
|
||||
}
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist and have the auth handler set.
|
||||
tm := a.GetMCPToolManager()
|
||||
if tm == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
|
||||
// Verify the auth handler was propagated by checking the field directly.
|
||||
if tm.GetAuthHandler() == nil {
|
||||
t.Fatal("Expected auth handler to be propagated to tool manager")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/anthropic"
|
||||
)
|
||||
|
||||
// cacheControlOptions returns provider options for Anthropic cache control.
|
||||
// This is used at the message level to avoid type conflicts with provider-level options.
|
||||
func cacheControlOptions() fantasy.ProviderOptions {
|
||||
return anthropic.NewProviderCacheControlOptions(&anthropic.ProviderCacheControlOptions{
|
||||
CacheControl: anthropic.CacheControl{
|
||||
Type: "ephemeral",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// applyCacheControlToMessages adds cache control to specific messages.
|
||||
// Anthropic allows max 4 cache blocks per request.
|
||||
// Counts existing cache blocks and only adds new ones up to the limit.
|
||||
func applyCacheControlToMessages(messages []fantasy.Message) []fantasy.Message {
|
||||
if len(messages) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// Make a copy to avoid modifying the original slice
|
||||
result := make([]fantasy.Message, len(messages))
|
||||
copy(result, messages)
|
||||
|
||||
cacheOpts := cacheControlOptions()
|
||||
maxCacheBlocks := 4
|
||||
|
||||
// Helper to check if message already has cache control
|
||||
hasCache := func(msg fantasy.Message) bool {
|
||||
if msg.ProviderOptions == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := msg.ProviderOptions["anthropic"]; ok {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Count existing cache blocks
|
||||
existingCacheCount := 0
|
||||
for _, msg := range result {
|
||||
if hasCache(msg) {
|
||||
existingCacheCount++
|
||||
}
|
||||
}
|
||||
|
||||
// If we're already at or over the limit, don't add more
|
||||
if existingCacheCount >= maxCacheBlocks {
|
||||
return result
|
||||
}
|
||||
|
||||
// How many new cache blocks can we add?
|
||||
remaining := maxCacheBlocks - existingCacheCount
|
||||
|
||||
// First: find and cache the last system message (most important)
|
||||
lastSystemIdx := -1
|
||||
for i, msg := range result {
|
||||
if msg.Role == fantasy.MessageRoleSystem {
|
||||
lastSystemIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if lastSystemIdx >= 0 && remaining > 0 && !hasCache(result[lastSystemIdx]) {
|
||||
result[lastSystemIdx].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
// Second: cache the most recent messages (up to remaining limit)
|
||||
// Work backwards from the end to prioritize recent context
|
||||
for i := len(result) - 1; i >= 0 && remaining > 0; i-- {
|
||||
if hasCache(result[i]) {
|
||||
continue
|
||||
}
|
||||
result[i].ProviderOptions = cacheOpts
|
||||
remaining--
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
+27
-10
@@ -36,13 +36,26 @@ type AgentCreationOptions struct {
|
||||
SpinnerFunc SpinnerFunc // Function to show spinner (provided by caller)
|
||||
// DebugLogger is an optional logger for debugging MCP communications
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
// ToolWrapper wraps the combined tool list before Fantasy agent creation.
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ToolWrapper wraps the combined tool list before agent creation.
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
// ExtraTools are additional tools to include (e.g. from extensions).
|
||||
ExtraTools []fantasy.AgentTool
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// CreateAgent creates an agent with optional spinner for Ollama models.
|
||||
@@ -50,15 +63,19 @@ type AgentCreationOptions struct {
|
||||
// Returns the created agent or an error if creation fails.
|
||||
func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error) {
|
||||
agentConfig := &AgentConfig{
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
ModelConfig: opts.ModelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: opts.SystemPrompt,
|
||||
MaxSteps: opts.MaxSteps,
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
ExtraTools: opts.ExtraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
}
|
||||
|
||||
var agent *Agent
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/tools"
|
||||
)
|
||||
|
||||
// mcpAgentTool adapts an tools.MCPTool to the fantasy.AgentTool interface.
|
||||
// This keeps the fantasy dependency confined to the agent layer — the tools
|
||||
// package is a pure MCP client library with no LLM framework dependency.
|
||||
type mcpAgentTool struct {
|
||||
tool tools.MCPTool
|
||||
manager *tools.MCPToolManager
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpAgentTool) Info() fantasy.ToolInfo {
|
||||
return fantasy.ToolInfo{
|
||||
Name: t.tool.Name,
|
||||
Description: t.tool.Description,
|
||||
Parameters: t.tool.Parameters,
|
||||
Required: t.tool.Required,
|
||||
}
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by delegating to the MCPToolManager.
|
||||
func (t *mcpAgentTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
result, err := t.manager.ExecuteTool(ctx, t.tool.Name, call.Input)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("mcp tool execution failed: %w", err)
|
||||
}
|
||||
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(result.Content), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(result.Content), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpAgentTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
|
||||
// mcpToolsToAgentTools converts a slice of MCPTool to fantasy.AgentTool
|
||||
// implementations that route execution through the MCPToolManager.
|
||||
func mcpToolsToAgentTools(mcpTools []tools.MCPTool, manager *tools.MCPToolManager) []fantasy.AgentTool {
|
||||
agentTools := make([]fantasy.AgentTool, len(mcpTools))
|
||||
for i, t := range mcpTools {
|
||||
agentTools[i] = &mcpAgentTool{
|
||||
tool: t,
|
||||
manager: manager,
|
||||
}
|
||||
}
|
||||
return agentTools
|
||||
}
|
||||
+15
-4
@@ -1,6 +1,17 @@
|
||||
package agent
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// SteerMessage carries a steering prompt and optional file attachments
|
||||
// (e.g. clipboard images) through the steer channel.
|
||||
type SteerMessage struct {
|
||||
Text string
|
||||
Files []fantasy.FilePart
|
||||
}
|
||||
|
||||
// steerChKey is the context key for the steer channel.
|
||||
type steerChKey struct{}
|
||||
@@ -11,7 +22,7 @@ type steerConsumedKey struct{}
|
||||
// ContextWithSteerCh returns a new context with the steer channel attached.
|
||||
// The agent's PrepareStep function checks this channel between steps and
|
||||
// injects any pending steer messages as user messages before the next LLM call.
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan string) context.Context {
|
||||
func ContextWithSteerCh(ctx context.Context, ch <-chan SteerMessage) context.Context {
|
||||
return context.WithValue(ctx, steerChKey{}, ch)
|
||||
}
|
||||
|
||||
@@ -23,8 +34,8 @@ func ContextWithSteerConsumed(ctx context.Context, fn func(count int)) context.C
|
||||
}
|
||||
|
||||
// steerChFromContext extracts the steer channel from the context, or nil.
|
||||
func steerChFromContext(ctx context.Context) <-chan string {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan string)
|
||||
func steerChFromContext(ctx context.Context) <-chan SteerMessage {
|
||||
ch, _ := ctx.Value(steerChKey{}).(<-chan SteerMessage)
|
||||
return ch
|
||||
}
|
||||
|
||||
|
||||
+349
-53
@@ -3,9 +3,11 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/fantasy"
|
||||
@@ -18,7 +20,7 @@ import (
|
||||
// queueItem holds a prompt and optional image attachments for the execution queue.
|
||||
type queueItem struct {
|
||||
Prompt string
|
||||
Files []fantasy.FilePart
|
||||
Files []kit.LLMFilePart
|
||||
}
|
||||
|
||||
// App is the application-layer orchestrator. It owns the agentic loop,
|
||||
@@ -67,11 +69,20 @@ type App struct {
|
||||
// rootCtx/rootCancel are used to signal shutdown to all goroutines.
|
||||
rootCtx context.Context
|
||||
rootCancel context.CancelFunc
|
||||
|
||||
// widgetUpdatePending is set to true when a WidgetUpdateEvent has been
|
||||
// sent to the TUI but not yet consumed by its event loop. While the flag
|
||||
// is set, subsequent NotifyWidgetUpdate calls are coalesced (dropped) to
|
||||
// prevent fast extension tickers from flooding the BubbleTea mailbox with
|
||||
// redundant re-render triggers. The flag is cleared after a short debounce
|
||||
// (~1 frame) so new updates are always let through once the TUI has had a
|
||||
// chance to process the pending event.
|
||||
widgetUpdatePending atomic.Bool
|
||||
}
|
||||
|
||||
// New creates a new App with the provided options and pre-loaded messages.
|
||||
// initialMessages may be nil or empty for a fresh session.
|
||||
func New(opts Options, initialMessages []fantasy.Message) *App {
|
||||
func New(opts Options, initialMessages []kit.LLMMessage) *App {
|
||||
rootCtx, rootCancel := context.WithCancel(context.Background())
|
||||
return &App{
|
||||
opts: opts,
|
||||
@@ -115,9 +126,8 @@ func (a *App) Run(prompt string) int {
|
||||
// If the app is idle the prompt executes immediately; otherwise it is queued.
|
||||
// Returns the current queue depth (0 = started immediately, >0 = queued).
|
||||
//
|
||||
// Satisfies ui.AppController (via RunWithImages which converts ImageAttachment
|
||||
// to fantasy.FilePart).
|
||||
func (a *App) RunWithFiles(prompt string, files []fantasy.FilePart) int {
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) RunWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -152,6 +162,24 @@ func (a *App) CancelCurrentStep() {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// IsBusy returns true when the agent is currently processing a turn.
|
||||
func (a *App) IsBusy() bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.busy
|
||||
}
|
||||
|
||||
// Abort cancels the current agent step (if running) and clears the queue.
|
||||
// Unlike InterruptAndSend, no new message is injected — the agent simply
|
||||
// stops. Safe to call when idle (no-op).
|
||||
func (a *App) Abort() {
|
||||
a.mu.Lock()
|
||||
a.queue = a.queue[:0]
|
||||
cancel := a.cancelStep
|
||||
a.mu.Unlock()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// QueueLength returns the number of prompts currently waiting in the queue.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
@@ -177,6 +205,15 @@ func (a *App) QueueLength() int {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) Steer(prompt string) int {
|
||||
return a.SteerWithFiles(prompt, nil)
|
||||
}
|
||||
|
||||
// SteerWithFiles injects a steering message with optional file attachments
|
||||
// (e.g. pasted images) into the currently running agent turn. Behaves like
|
||||
// Steer but includes file parts alongside the text.
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) SteerWithFiles(prompt string, files []kit.LLMFilePart) int {
|
||||
a.mu.Lock()
|
||||
|
||||
if a.closed {
|
||||
@@ -185,8 +222,8 @@ func (a *App) Steer(prompt string) int {
|
||||
}
|
||||
|
||||
if !a.busy {
|
||||
// Not busy — start immediately, same as Run().
|
||||
item := queueItem{Prompt: prompt}
|
||||
// Not busy — start immediately, same as RunWithFiles().
|
||||
item := queueItem{Prompt: prompt, Files: files}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
@@ -201,7 +238,7 @@ func (a *App) Steer(prompt string) int {
|
||||
// execution, before next LLM call). If PrepareStep doesn't fire
|
||||
// (text-only response), drainQueue will pick it up after the turn.
|
||||
if a.opts.Kit != nil {
|
||||
a.opts.Kit.InjectSteer(prompt)
|
||||
a.opts.Kit.InjectSteerWithFiles(prompt, files)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
@@ -260,6 +297,17 @@ func (a *App) ClearMessages() {
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadMessagesFromTree clears the in-memory message store and reloads it
|
||||
// from the tree session's current branch. Unlike ClearMessages, this does NOT
|
||||
// reset the tree session's leaf pointer. Used after Branch() to sync the
|
||||
// store with the new branch position.
|
||||
func (a *App) ReloadMessagesFromTree() {
|
||||
a.store.Clear()
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
// GetTreeSession returns the tree session manager, or nil if not configured.
|
||||
func (a *App) GetTreeSession() *session.TreeManager {
|
||||
return a.opts.TreeSession
|
||||
@@ -281,7 +329,7 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
// Reload messages from new session.
|
||||
a.store.Clear()
|
||||
if ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -292,12 +340,12 @@ func (a *App) SwitchTreeSession(ts *session.TreeManager) {
|
||||
//
|
||||
// Satisfies ui.AppController.
|
||||
func (a *App) AddContextMessage(text string) {
|
||||
msg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(msg)
|
||||
kitMsg := fantasy.NewUserMessage(text)
|
||||
a.store.Add(kitMsg)
|
||||
|
||||
// Persist to tree session if active.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
_, _ = ts.AppendFantasyMessage(msg)
|
||||
_, _ = ts.AppendLLMMessage(fantasy.NewUserMessage(text))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -335,6 +383,15 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
@@ -347,7 +404,7 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetFantasyMessages())
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
@@ -360,6 +417,78 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompactAsync is like CompactConversation but calls onComplete/onError
|
||||
// callbacks instead of sending TUI events. Used by the extension API's
|
||||
// ctx.Compact() which needs callback-based notification.
|
||||
func (a *App) CompactAsync(customInstructions string, onComplete func(), onError func(string)) error {
|
||||
a.mu.Lock()
|
||||
if a.closed {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("app is closed")
|
||||
}
|
||||
if a.busy {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("cannot compact while the agent is working")
|
||||
}
|
||||
if a.opts.Kit == nil {
|
||||
a.mu.Unlock()
|
||||
return fmt.Errorf("SDK instance not available")
|
||||
}
|
||||
a.busy = true
|
||||
a.wg.Add(1)
|
||||
a.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
defer a.wg.Done()
|
||||
defer func() {
|
||||
a.mu.Lock()
|
||||
a.busy = false
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
|
||||
// Subscribe to SDK events for streaming compaction summary to the TUI.
|
||||
sendFn := func(msg tea.Msg) {
|
||||
if a.program != nil {
|
||||
a.program.Send(msg)
|
||||
}
|
||||
}
|
||||
unsub := a.subscribeSDKEvents(sendFn, nil)
|
||||
defer unsub()
|
||||
|
||||
result, err := a.opts.Kit.Compact(a.rootCtx, nil, customInstructions)
|
||||
if err != nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: err})
|
||||
if onError != nil {
|
||||
onError(err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
if result == nil {
|
||||
a.sendEvent(CompactErrorEvent{Err: fmt.Errorf("nothing to compact")})
|
||||
if onError != nil {
|
||||
onError("nothing to compact")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Sync in-memory store with the compacted session.
|
||||
if a.opts.TreeSession != nil {
|
||||
a.store.Replace(a.opts.TreeSession.GetLLMMessages())
|
||||
}
|
||||
|
||||
a.sendEvent(CompactCompleteEvent{
|
||||
Summary: result.Summary,
|
||||
OriginalTokens: result.OriginalTokens,
|
||||
CompactedTokens: result.CompactedTokens,
|
||||
MessagesRemoved: result.MessagesRemoved,
|
||||
})
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Non-interactive execution
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -368,6 +497,12 @@ func (a *App) CompactConversation(customInstructions string) error {
|
||||
// response text to stdout. No intermediate events are emitted. Blocks until
|
||||
// the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
return a.RunOnceWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithFiles executes a single agent step synchronously with optional
|
||||
// multimodal file attachments. Prints the response to stdout and returns.
|
||||
func (a *App) RunOnceWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -375,7 +510,7 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, nil, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -390,6 +525,12 @@ func (a *App) RunOnce(ctx context.Context, prompt string) error {
|
||||
// full TurnResult without printing anything. This is used by --json mode to
|
||||
// capture structured output for serialization.
|
||||
func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult, error) {
|
||||
return a.RunOnceResultWithFiles(ctx, prompt, nil)
|
||||
}
|
||||
|
||||
// RunOnceResultWithFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments and returns the full TurnResult.
|
||||
func (a *App) RunOnceResultWithFiles(ctx context.Context, prompt string, files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -397,7 +538,7 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
return a.executeStep(stepCtx, prompt, nil, nil)
|
||||
return a.executeStep(stepCtx, prompt, nil, files)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplay executes a single agent step synchronously, sending
|
||||
@@ -411,6 +552,12 @@ func (a *App) RunOnceResult(ctx context.Context, prompt string) (*kit.TurnResult
|
||||
//
|
||||
// Blocks until the step completes or ctx is cancelled.
|
||||
func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn func(tea.Msg)) error {
|
||||
return a.RunOnceWithDisplayAndFiles(ctx, prompt, eventFn, nil)
|
||||
}
|
||||
|
||||
// RunOnceWithDisplayAndFiles executes a single agent step synchronously with
|
||||
// optional multimodal file attachments, sending intermediate display events.
|
||||
func (a *App) RunOnceWithDisplayAndFiles(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) error {
|
||||
stepCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
@@ -418,7 +565,7 @@ func (a *App) RunOnceWithDisplay(ctx context.Context, prompt string, eventFn fun
|
||||
a.cancelStep = cancel
|
||||
a.mu.Unlock()
|
||||
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, nil)
|
||||
result, err := a.executeStep(stepCtx, prompt, eventFn, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -484,11 +631,10 @@ func (a *App) drainQueue(first queueItem) {
|
||||
a.mu.Lock()
|
||||
items = append(items, a.queue...)
|
||||
a.queue = a.queue[:0] // Clear the queue
|
||||
queueLen := len(a.queue)
|
||||
a.mu.Unlock()
|
||||
|
||||
// Send queue updated event (queue is now empty)
|
||||
a.sendEvent(QueueUpdatedEvent{Length: queueLen})
|
||||
// Notify UI: all queued messages have been consumed into this batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
|
||||
// Process all collected items as a single batch
|
||||
a.runQueueBatch(items)
|
||||
@@ -501,8 +647,8 @@ func (a *App) drainQueue(first queueItem) {
|
||||
if leftover := a.opts.Kit.DrainSteer(); len(leftover) > 0 {
|
||||
a.mu.Lock()
|
||||
steerItems := make([]queueItem, len(leftover))
|
||||
for i, text := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: text}
|
||||
for i, sm := range leftover {
|
||||
steerItems[i] = queueItem{Prompt: sm.Text, Files: sm.Files}
|
||||
}
|
||||
a.queue = append(steerItems, a.queue...)
|
||||
a.mu.Unlock()
|
||||
@@ -521,6 +667,11 @@ func (a *App) drainQueue(first queueItem) {
|
||||
}
|
||||
a.mu.Unlock()
|
||||
|
||||
if hasMore {
|
||||
// Notify UI: these newly queued messages have been consumed into the next batch.
|
||||
a.sendEvent(QueueUpdatedEvent{Length: 0})
|
||||
}
|
||||
|
||||
if !hasMore {
|
||||
// No more items, we're done
|
||||
break
|
||||
@@ -568,7 +719,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// call/result pairs; only the in-progress message or tool
|
||||
// call is discarded. Sync the in-memory store to match.
|
||||
if ts := a.opts.TreeSession; ts != nil {
|
||||
a.store.Replace(ts.GetFantasyMessages())
|
||||
a.store.Replace(ts.GetLLMMessages())
|
||||
}
|
||||
a.sendEvent(StepCancelledEvent{})
|
||||
return
|
||||
@@ -587,7 +738,7 @@ func (a *App) runQueueBatch(items []queueItem) {
|
||||
// executeStep runs a single agentic step by delegating to the SDK's
|
||||
// PromptResult() (or PromptResultWithFiles for multimodal), which handles
|
||||
// session persistence, hooks, extension events, and the generation loop.
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []fantasy.FilePart) (*kit.TurnResult, error) {
|
||||
func (a *App) executeStep(ctx context.Context, prompt string, eventFn func(tea.Msg), files []kit.LLMFilePart) (*kit.TurnResult, error) {
|
||||
// Test hook: bypass SDK entirely.
|
||||
if a.opts.PromptFunc != nil {
|
||||
return a.opts.PromptFunc(ctx, prompt)
|
||||
@@ -684,8 +835,8 @@ func (a *App) executeBatch(ctx context.Context, items []queueItem, eventFn func(
|
||||
messages = append(messages, item.Prompt)
|
||||
}
|
||||
|
||||
// TODO: Handle file attachments in batch mode
|
||||
// For now, files are ignored in batch mode (rare edge case)
|
||||
// File attachments are not supported in batch mode; fall back to
|
||||
// processing only the first item that carries files.
|
||||
if hasFiles {
|
||||
// If files exist, fall back to processing just the first item with files
|
||||
for _, item := range items {
|
||||
@@ -754,6 +905,8 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
sendFn(StreamChunkEvent{Content: ev.Chunk})
|
||||
case kit.ReasoningDeltaEvent:
|
||||
sendFn(ReasoningChunkEvent{Delta: ev.Delta})
|
||||
case kit.ReasoningCompleteEvent:
|
||||
sendFn(ReasoningCompleteEvent{})
|
||||
case kit.ToolOutputEvent:
|
||||
sendFn(ToolOutputEvent{
|
||||
ToolCallID: ev.ToolCallID,
|
||||
@@ -765,6 +918,20 @@ func (a *App) subscribeSDKEvents(sendFn func(tea.Msg), stepUsageSeen *atomic.Boo
|
||||
sendFn(SteerConsumedEvent{})
|
||||
case kit.StepUsageEvent:
|
||||
a.recordStepUsage(ev, stepUsageSeen)
|
||||
case kit.PasswordPromptEvent:
|
||||
// Convert SDK PasswordPromptEvent to app PasswordPromptEvent
|
||||
// The TUI will handle this and send the response back
|
||||
responseCh := make(chan PasswordPromptResponse, 1)
|
||||
sendFn(PasswordPromptEvent{
|
||||
Prompt: ev.Prompt,
|
||||
ResponseCh: responseCh,
|
||||
})
|
||||
// Wait for TUI response and forward to SDK
|
||||
resp := <-responseCh
|
||||
ev.ResponseCh <- kit.PasswordPromptResponse{
|
||||
Password: resp.Password,
|
||||
Cancelled: resp.Cancelled,
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
@@ -795,7 +962,8 @@ func (a *App) QuitFromExtension() {
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
// appropriate renderer. In non-interactive mode it falls back to stderr with
|
||||
// a level prefix so errors are distinguishable from plain output.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
@@ -804,8 +972,16 @@ func (a *App) PrintFromExtension(level, text string) {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
// Non-interactive fallback: write to stderr with a level prefix so that
|
||||
// errors and info messages are distinguishable from plain output.
|
||||
switch level {
|
||||
case "error":
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
|
||||
default:
|
||||
fmt.Println(text)
|
||||
}
|
||||
}
|
||||
|
||||
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
|
||||
@@ -833,12 +1009,73 @@ func (a *App) NotifyModelChanged(provider, model string) {
|
||||
// NotifyWidgetUpdate sends a WidgetUpdateEvent to the TUI so it re-renders
|
||||
// extension widgets. Called from the extension context's SetWidget/RemoveWidget
|
||||
// closures. In non-interactive mode this is a no-op (widgets are TUI-only).
|
||||
//
|
||||
// Coalescing: if a WidgetUpdateEvent is already queued and not yet consumed
|
||||
// by the TUI event loop, additional calls within the same ~16 ms window are
|
||||
// dropped. This prevents fast extension tickers from flooding BubbleTea's
|
||||
// mailbox with redundant re-render triggers.
|
||||
func (a *App) NotifyWidgetUpdate() {
|
||||
// Coalesce: only one pending update at a time.
|
||||
if !a.widgetUpdatePending.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(WidgetUpdateEvent{})
|
||||
// Reset the pending flag after a short debounce so subsequent calls
|
||||
// within the same render cycle are also coalesced, but new updates
|
||||
// after the cycle are allowed through.
|
||||
go func() {
|
||||
time.Sleep(16 * time.Millisecond) // ~1 frame at 60 fps
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}()
|
||||
} else {
|
||||
// No program registered (non-interactive mode); clear the flag so
|
||||
// future calls are never permanently blocked.
|
||||
a.widgetUpdatePending.Store(false)
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyContentReload sends a ContentReloadEvent to the TUI so it refreshes
|
||||
// prompt templates and skills from their provider callbacks. Called by file
|
||||
// watchers when .md/.txt files change in prompt or skill directories.
|
||||
// In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyContentReload() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(ContentReloadEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPToolsReady sends an MCPToolsReadyEvent to the TUI so it refreshes
|
||||
// tool names and MCP tool count from provider callbacks. Called when background
|
||||
// MCP tool loading completes. In non-interactive mode this is a no-op.
|
||||
func (a *App) NotifyMCPToolsReady() {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPToolsReadyEvent{})
|
||||
}
|
||||
}
|
||||
|
||||
// NotifyMCPServerLoaded sends an MCPServerLoadedEvent to the TUI so it can
|
||||
// display a system message when a single MCP server finishes loading. Called
|
||||
// per server as background MCP tool loading progresses.
|
||||
func (a *App) NotifyMCPServerLoaded(serverName string, toolCount int, err error) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
a.mu.Unlock()
|
||||
if prog != nil {
|
||||
prog.Send(MCPServerLoadedEvent{
|
||||
ServerName: serverName,
|
||||
ToolCount: toolCount,
|
||||
Error: err,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -926,11 +1163,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
// Non-interactive fallback: render a simple framed block to stderr so
|
||||
// it is visually distinct from plain stdout output.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -939,6 +1177,10 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
// the usage widget accurate on all stop paths.
|
||||
func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool) {
|
||||
hasUsage := ev.InputTokens > 0 || ev.OutputTokens > 0 || ev.CacheReadTokens > 0 || ev.CacheWriteTokens > 0
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] recordStepUsage: hasUsage=%v input=%d output=%d cacheRead=%d cacheWrite=%d",
|
||||
hasUsage, ev.InputTokens, ev.OutputTokens, ev.CacheReadTokens, ev.CacheWriteTokens)
|
||||
}
|
||||
if !hasUsage {
|
||||
return
|
||||
}
|
||||
@@ -954,8 +1196,11 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheReadTokens),
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// Keep context fill reasonably fresh during long/partial turns.
|
||||
a.opts.UsageTracker.SetContextTokens(int(ev.InputTokens + ev.OutputTokens))
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
@@ -963,35 +1208,86 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
//
|
||||
// When sawStepUsage is true, totals were already accumulated incrementally via
|
||||
// StepUsageEvent callbacks; in that case this method only updates context fill.
|
||||
// Otherwise it falls back to TotalUsage (or estimation) to keep costs/tokens
|
||||
// visible for providers/modes that don't emit per-step usage.
|
||||
// Otherwise it falls back to TotalUsage from the API response.
|
||||
//
|
||||
// NOTE: We only use ACTUAL token counts from API responses for cost tracking.
|
||||
// Estimation is never used for costs - only API-reported tokens are accurate.
|
||||
func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt string, sawStepUsage bool) {
|
||||
if a.opts.UsageTracker == nil || result == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
if !sawStepUsage {
|
||||
if result.TotalUsage != nil && result.TotalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
// Debug logging for token tracking
|
||||
if a.opts.Debug {
|
||||
if result.TotalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult TotalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
// Provider didn't report token counts — fall back to character-based
|
||||
// estimates so the footer shows something rather than nothing.
|
||||
a.opts.UsageTracker.EstimateAndUpdateUsage(userPrompt, result.Response)
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: TotalUsage=nil")
|
||||
}
|
||||
if result.FinalUsage != nil {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult FinalUsage: input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.FinalUsage.InputTokens, result.FinalUsage.OutputTokens,
|
||||
result.FinalUsage.CacheReadTokens, result.FinalUsage.CacheCreationTokens)
|
||||
} else {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: FinalUsage=nil")
|
||||
}
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: sawStepUsage=%v", sawStepUsage)
|
||||
}
|
||||
|
||||
// --- Accumulate cost/token totals for the session ---
|
||||
// Only use actual API-reported tokens for cost tracking.
|
||||
// If sawStepUsage is true, totals were already updated via StepUsageEvent.
|
||||
// Check any token field > 0 (not just InputTokens) because cached prompts
|
||||
// can result in InputTokens=0 while OutputTokens>0 (OpenAI-compatible behavior).
|
||||
hasTotalUsage := result.TotalUsage != nil &&
|
||||
(result.TotalUsage.InputTokens > 0 ||
|
||||
result.TotalUsage.OutputTokens > 0 ||
|
||||
result.TotalUsage.CacheReadTokens > 0 ||
|
||||
result.TotalUsage.CacheCreationTokens > 0)
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: hasTotalUsage=%v", hasTotalUsage)
|
||||
}
|
||||
if !sawStepUsage && hasTotalUsage {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling UpdateUsage input=%d output=%d cacheRead=%d cacheCreate=%d",
|
||||
result.TotalUsage.InputTokens, result.TotalUsage.OutputTokens,
|
||||
result.TotalUsage.CacheReadTokens, result.TotalUsage.CacheCreationTokens)
|
||||
}
|
||||
a.opts.UsageTracker.UpdateUsage(
|
||||
int(result.TotalUsage.InputTokens),
|
||||
int(result.TotalUsage.OutputTokens),
|
||||
int(result.TotalUsage.CacheReadTokens),
|
||||
int(result.TotalUsage.CacheCreationTokens),
|
||||
)
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens: the input token count of the last API call
|
||||
// equals the number of tokens currently occupying the context window.
|
||||
// Adding OutputTokens would overstate fill since the response is not part
|
||||
// of the context that was *sent* to the model.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
// Calculate context fill from the LAST API call's usage. The context
|
||||
// window is filled by everything sent to and received from the model:
|
||||
//
|
||||
// InputTokens — non-cached input (may be small with prompt caching)
|
||||
// CacheReadTokens — input tokens served from cache
|
||||
// CacheCreationTokens — input tokens written to cache this call
|
||||
// OutputTokens — assistant output (becomes input next turn)
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can drop to near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. We must sum all four to
|
||||
// get the true context window utilization.
|
||||
//
|
||||
// We use FinalUsage (last step only), NOT TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps — and each step re-sends the full
|
||||
// conversation, so TotalUsage massively overstates the actual window fill.
|
||||
if result.FinalUsage != nil {
|
||||
u := result.FinalUsage
|
||||
contextFill := int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheCreate=%d + Output=%d)",
|
||||
contextFill, u.InputTokens, u.CacheReadTokens, u.CacheCreationTokens, u.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+81
-11
@@ -7,8 +7,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
@@ -534,7 +532,9 @@ func TestQueueLength_reflects(t *testing.T) {
|
||||
}
|
||||
|
||||
// TestRecordStepUsage_updatesTracker verifies that per-step usage updates are
|
||||
// recorded immediately (including context tokens) for stop-path correctness.
|
||||
// recorded immediately for cost tracking. Context tokens are NOT updated here
|
||||
// (only via updateUsageFromTurnResult) to avoid display jumps during multi-step
|
||||
// tool calls.
|
||||
func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
@@ -557,11 +557,9 @@ func TestRecordStepUsage_updatesTracker(t *testing.T) {
|
||||
t.Fatalf("unexpected usage update payload: in=%d out=%d cache_read=%d cache_write=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput, usage.lastUpdateCacheRead, usage.lastUpdateCacheWrite)
|
||||
}
|
||||
if usage.contextCalls != 1 {
|
||||
t.Fatalf("expected 1 context token update, got %d", usage.contextCalls)
|
||||
}
|
||||
if usage.lastContextTokens != 165 {
|
||||
t.Fatalf("expected context tokens 165, got %d", usage.lastContextTokens)
|
||||
// Context tokens should NOT be updated by recordStepUsage (only by updateUsageFromTurnResult)
|
||||
if usage.contextCalls != 0 {
|
||||
t.Fatalf("expected 0 context token updates from recordStepUsage, got %d", usage.contextCalls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,13 +572,13 @@ func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &fantasy.Usage{
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 999,
|
||||
OutputTokens: 111,
|
||||
CacheReadTokens: 7,
|
||||
CacheCreationTokens: 3,
|
||||
},
|
||||
FinalUsage: &fantasy.Usage{InputTokens: 456},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 456},
|
||||
}, "prompt", true)
|
||||
|
||||
usage.mu.Lock()
|
||||
@@ -592,7 +590,79 @@ func TestUpdateUsageFromTurnResult_skipsTotalsWhenStepUsageSeen(t *testing.T) {
|
||||
if usage.estimateCalls != 0 {
|
||||
t.Fatalf("expected no estimate update when sawStepUsage=true, got %d", usage.estimateCalls)
|
||||
}
|
||||
// Context tokens should be InputTokens only (456)
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 456 {
|
||||
t.Fatalf("expected final context tokens=456, got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
t.Fatalf("expected final context tokens=456 (InputTokens only), got calls=%d tokens=%d", usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero verifies that usage
|
||||
// is recorded when InputTokens=0 but OutputTokens>0 (OpenAI-compatible cache behavior).
|
||||
func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
// Simulate OpenAI-compatible behavior: all prompt tokens cached, InputTokens=0
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 0, // All cached - subtracted from prompt
|
||||
OutputTokens: 150, // Actual generated tokens
|
||||
CacheReadTokens: 500, // Cache hit
|
||||
CacheCreationTokens: 0,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{InputTokens: 0, OutputTokens: 150},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
if usage.updateCalls != 1 {
|
||||
t.Fatalf("expected 1 update call when InputTokens=0 but OutputTokens>0, got %d", usage.updateCalls)
|
||||
}
|
||||
if usage.lastUpdateInput != 0 || usage.lastUpdateOutput != 150 {
|
||||
t.Fatalf("expected input=0 output=150, got input=%d output=%d",
|
||||
usage.lastUpdateInput, usage.lastUpdateOutput)
|
||||
}
|
||||
if usage.lastUpdateCacheRead != 500 {
|
||||
t.Fatalf("expected cache_read=500, got %d", usage.lastUpdateCacheRead)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories verifies that
|
||||
// context window fill uses all token categories from the final API call:
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3,
|
||||
OutputTokens: 5,
|
||||
CacheReadTokens: 0,
|
||||
CacheCreationTokens: 4317,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 3, // Non-cached input (small with caching)
|
||||
OutputTokens: 5, // Assistant output
|
||||
CacheReadTokens: 0, // No cache reads on first call
|
||||
CacheCreationTokens: 4317, // System prompt + tools written to cache
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be Input + CacheRead + CacheCreate + Output = 4325
|
||||
expected := 3 + 0 + 4317 + 5
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != expected {
|
||||
t.Fatalf("expected context tokens=%d (all categories), got calls=%d tokens=%d",
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
+45
-3
@@ -1,6 +1,6 @@
|
||||
package app
|
||||
|
||||
import "charm.land/fantasy"
|
||||
import kit "github.com/mark3labs/kit/pkg/kit"
|
||||
|
||||
// StreamChunkEvent is sent by the app layer when a streaming text delta arrives
|
||||
// from the LLM. Each chunk contains an incremental portion of the response.
|
||||
@@ -16,6 +16,11 @@ type ReasoningChunkEvent struct {
|
||||
Delta string
|
||||
}
|
||||
|
||||
// ReasoningCompleteEvent is sent when reasoning/thinking is finished, after
|
||||
// the last reasoning token has been processed. The TUI uses this to freeze
|
||||
// the reasoning duration counter.
|
||||
type ReasoningCompleteEvent struct{}
|
||||
|
||||
// ToolCallStartedEvent is sent when a tool call has been parsed and is about to execute.
|
||||
// It carries the tool name and its arguments for display purposes.
|
||||
type ToolCallStartedEvent struct {
|
||||
@@ -74,6 +79,24 @@ type ToolCallContentEvent struct {
|
||||
Content string
|
||||
}
|
||||
|
||||
// PasswordPromptEvent is sent when a sudo command needs a password.
|
||||
// The TUI should display a password prompt overlay and send the result back.
|
||||
type PasswordPromptEvent struct {
|
||||
// Prompt is the message to display to the user.
|
||||
Prompt string
|
||||
// ResponseCh receives the password from the TUI.
|
||||
// The TUI must send exactly one value.
|
||||
ResponseCh chan<- PasswordPromptResponse
|
||||
}
|
||||
|
||||
// PasswordPromptResponse carries the user's password input.
|
||||
type PasswordPromptResponse struct {
|
||||
// Password is the entered password.
|
||||
Password string
|
||||
// Cancelled is true if the user cancelled the prompt.
|
||||
Cancelled bool
|
||||
}
|
||||
|
||||
// ResponseCompleteEvent is sent when the LLM produces a final (non-streaming) response.
|
||||
// In streaming mode, this may be empty if all content was delivered via StreamChunkEvents.
|
||||
type ResponseCompleteEvent struct {
|
||||
@@ -118,8 +141,8 @@ type SpinnerEvent struct {
|
||||
// MessageCreatedEvent is sent when a new message is added to the message store.
|
||||
// This allows the TUI to stay in sync with the conversation history.
|
||||
type MessageCreatedEvent struct {
|
||||
// Message is the fantasy message that was added to the store.
|
||||
Message fantasy.Message
|
||||
// Message is the message that was added to the store.
|
||||
Message kit.LLMMessage
|
||||
}
|
||||
|
||||
// CompactCompleteEvent is sent when a /compact operation finishes successfully.
|
||||
@@ -162,6 +185,25 @@ type ModelChangedEvent struct {
|
||||
// from its WidgetProvider on the next render cycle.
|
||||
type WidgetUpdateEvent struct{}
|
||||
|
||||
// ContentReloadEvent is sent when prompt templates or skills are reloaded
|
||||
// from disk (e.g. by a file watcher detecting changes). The TUI refreshes
|
||||
// its autocomplete entries and internal state from the provider callbacks.
|
||||
type ContentReloadEvent struct{}
|
||||
|
||||
// MCPToolsReadyEvent is sent when background MCP tool loading completes.
|
||||
// The TUI refreshes its tool names and MCP tool count from provider callbacks
|
||||
// so that /tools and the startup info bar reflect the loaded MCP tools.
|
||||
type MCPToolsReadyEvent struct{}
|
||||
|
||||
// MCPServerLoadedEvent is sent when a single MCP server finishes loading
|
||||
// (successfully or with error). The TUI displays a system message so users
|
||||
// see real-time progress as each server initializes.
|
||||
type MCPServerLoadedEvent struct {
|
||||
ServerName string
|
||||
ToolCount int
|
||||
Error error // nil on success
|
||||
}
|
||||
|
||||
// EditorTextSetEvent is sent when an extension calls ctx.SetEditorText to
|
||||
// pre-fill the input editor with text. The TUI handles this by setting the
|
||||
// textarea content and moving the cursor to the end.
|
||||
|
||||
@@ -3,14 +3,14 @@ package app
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// MessageStore is a thread-safe in-memory store for the conversation history.
|
||||
// On-disk persistence is handled by the TreeManager at the app/SDK layer.
|
||||
type MessageStore struct {
|
||||
mu sync.RWMutex
|
||||
messages []fantasy.Message
|
||||
messages []kit.LLMMessage
|
||||
}
|
||||
|
||||
// NewMessageStore creates an empty MessageStore.
|
||||
@@ -20,14 +20,14 @@ func NewMessageStore() *MessageStore {
|
||||
|
||||
// NewMessageStoreWithMessages creates a MessageStore pre-populated with the
|
||||
// given messages. This is used when loading an existing session at startup.
|
||||
func NewMessageStoreWithMessages(msgs []fantasy.Message) *MessageStore {
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
func NewMessageStoreWithMessages(msgs []kit.LLMMessage) *MessageStore {
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
return &MessageStore{messages: cp}
|
||||
}
|
||||
|
||||
// Add appends a single message to the store.
|
||||
func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
func (s *MessageStore) Add(msg kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.messages = append(s.messages, msg)
|
||||
@@ -36,22 +36,22 @@ func (s *MessageStore) Add(msg fantasy.Message) {
|
||||
// Replace replaces the entire message history with the given slice. This is
|
||||
// used after an agent step returns the full updated conversation (including
|
||||
// tool calls and results).
|
||||
func (s *MessageStore) Replace(msgs []fantasy.Message) {
|
||||
func (s *MessageStore) Replace(msgs []kit.LLMMessage) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(msgs))
|
||||
cp := make([]kit.LLMMessage, len(msgs))
|
||||
copy(cp, msgs)
|
||||
s.messages = cp
|
||||
}
|
||||
|
||||
// GetAll returns a snapshot copy of the current message slice.
|
||||
// The returned slice is safe to modify without affecting the store.
|
||||
func (s *MessageStore) GetAll() []fantasy.Message {
|
||||
func (s *MessageStore) GetAll() []kit.LLMMessage {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
cp := make([]fantasy.Message, len(s.messages))
|
||||
cp := make([]kit.LLMMessage, len(s.messages))
|
||||
copy(cp, s.messages)
|
||||
return cp
|
||||
}
|
||||
|
||||
@@ -3,17 +3,27 @@ package app
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// makeTextMsg builds a minimal fantasy.Message with a single TextPart.
|
||||
func makeTextMsg(role, text string) fantasy.Message {
|
||||
return fantasy.Message{
|
||||
Role: fantasy.MessageRole(role),
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: text}},
|
||||
// makeTextMsg builds a minimal kit.LLMMessage with the given role and text.
|
||||
func makeTextMsg(role, text string) kit.LLMMessage {
|
||||
return kit.LLMMessage{
|
||||
Role: kit.LLMMessageRole(role),
|
||||
Content: []kit.LLMMessagePart{kit.LLMTextPart{Text: text}},
|
||||
}
|
||||
}
|
||||
|
||||
// textOf extracts the plain text from an LLMMessage for assertions.
|
||||
func textOf(msg kit.LLMMessage) string {
|
||||
for _, part := range msg.Content {
|
||||
if tp, ok := part.(kit.LLMTextPart); ok {
|
||||
return tp.Text
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// NewMessageStore / NewMessageStoreWithMessages
|
||||
// --------------------------------------------------------------------------
|
||||
@@ -29,7 +39,7 @@ func TestNewMessageStore_empty(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
msgs := []fantasy.Message{
|
||||
msgs := []kit.LLMMessage{
|
||||
makeTextMsg("user", "hello"),
|
||||
makeTextMsg("assistant", "hi"),
|
||||
}
|
||||
@@ -42,7 +52,7 @@ func TestNewMessageStoreWithMessages_preloaded(t *testing.T) {
|
||||
// NewMessageStoreWithMessages must deep-copy the slice so that external
|
||||
// modifications don't affect the store.
|
||||
func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
msgs := []fantasy.Message{makeTextMsg("user", "hello")}
|
||||
msgs := []kit.LLMMessage{makeTextMsg("user", "hello")}
|
||||
s := NewMessageStoreWithMessages(msgs)
|
||||
|
||||
// Mutate the source slice.
|
||||
@@ -52,9 +62,8 @@ func TestNewMessageStoreWithMessages_isolatesInput(t *testing.T) {
|
||||
if len(got) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(got))
|
||||
}
|
||||
tp, ok := got[0].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", tp.Text)
|
||||
if textOf(got[0]) != "hello" {
|
||||
t.Fatalf("store was mutated by external slice change; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,9 +89,8 @@ func TestAdd_preservesOrder(t *testing.T) {
|
||||
}
|
||||
got := s.GetAll()
|
||||
for i, expected := range texts {
|
||||
tp, ok := got[i].Content[0].(fantasy.TextPart)
|
||||
if !ok || tp.Text != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, tp.Text)
|
||||
if textOf(got[i]) != expected {
|
||||
t.Fatalf("message[%d]: expected %q, got %q", i, expected, textOf(got[i]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -95,7 +103,7 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
s.Add(makeTextMsg("user", "old"))
|
||||
|
||||
replacement := []fantasy.Message{
|
||||
replacement := []kit.LLMMessage{
|
||||
makeTextMsg("user", "new1"),
|
||||
makeTextMsg("assistant", "new2"),
|
||||
}
|
||||
@@ -105,25 +113,22 @@ func TestReplace_swapsHistory(t *testing.T) {
|
||||
t.Fatalf("expected 2 messages after replace, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp0, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
tp1, _ := got[1].Content[0].(fantasy.TextPart)
|
||||
if tp0.Text != "new1" || tp1.Text != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", tp0.Text, tp1.Text)
|
||||
if textOf(got[0]) != "new1" || textOf(got[1]) != "new2" {
|
||||
t.Fatalf("unexpected messages after replace: %q %q", textOf(got[0]), textOf(got[1]))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace must deep-copy the incoming slice.
|
||||
func TestReplace_isolatesInput(t *testing.T) {
|
||||
s := NewMessageStore()
|
||||
replacement := []fantasy.Message{makeTextMsg("user", "original")}
|
||||
replacement := []kit.LLMMessage{makeTextMsg("user", "original")}
|
||||
s.Replace(replacement)
|
||||
|
||||
replacement[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", tp.Text)
|
||||
if textOf(got[0]) != "original" {
|
||||
t.Fatalf("store was mutated by external slice change after Replace; got %q", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,9 +145,8 @@ func TestGetAll_returnsCopy(t *testing.T) {
|
||||
got[0] = makeTextMsg("user", "mutated")
|
||||
|
||||
internal := s.GetAll()
|
||||
tp, _ := internal[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", tp.Text)
|
||||
if textOf(internal[0]) != "hello" {
|
||||
t.Fatalf("GetAll returned non-copy; store was mutated to %q", textOf(internal[0]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +183,8 @@ func TestClear_allowsSubsequentAdds(t *testing.T) {
|
||||
t.Fatalf("expected 1 message after Clear+Add, got %d", s.Len())
|
||||
}
|
||||
got := s.GetAll()
|
||||
tp, _ := got[0].Content[0].(fantasy.TextPart)
|
||||
if tp.Text != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", tp.Text)
|
||||
if textOf(got[0]) != "after" {
|
||||
t.Fatalf("expected %q, got %q", "after", textOf(got[0]))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,10 @@ type UsageUpdater interface {
|
||||
// the provider does not return exact counts.
|
||||
EstimateAndUpdateUsage(inputText, outputText string)
|
||||
// SetContextTokens records the approximate current context window fill
|
||||
// level. This should be the final API call's input+output tokens (from
|
||||
// FinalResponse.Usage), NOT the aggregate TotalUsage.
|
||||
// level. This should be the sum of ALL token categories from the last
|
||||
// API call: InputTokens + CacheReadTokens + CacheCreationTokens +
|
||||
// OutputTokens. With Anthropic prompt caching, InputTokens can be
|
||||
// near-zero while CacheReadTokens holds the bulk of the context.
|
||||
SetContextTokens(tokens int)
|
||||
}
|
||||
|
||||
@@ -67,10 +69,6 @@ type Options struct {
|
||||
// Debug enables verbose debug logging.
|
||||
Debug bool
|
||||
|
||||
// CompactMode selects the compact renderer instead of the block renderer for
|
||||
// message formatting.
|
||||
CompactMode bool
|
||||
|
||||
// UsageTracker is an optional callback for recording token usage after each
|
||||
// agent step. When non-nil, the app layer calls UpdateUsage (or
|
||||
// EstimateAndUpdateUsage as a fallback) using the usage data returned by the
|
||||
|
||||
@@ -43,13 +43,30 @@ type OpenAICredentials struct {
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
// oauthTokenExpired reports whether an OAuth token with the given type and
|
||||
// expiry unix timestamp is past its expiry. Returns false for API key
|
||||
// credentials or when no expiry is set.
|
||||
func oauthTokenExpired(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= expiresAt
|
||||
}
|
||||
|
||||
// oauthTokenNeedsRefresh reports whether an OAuth token will expire within the
|
||||
// next 5 minutes, allowing proactive refresh before it becomes invalid.
|
||||
// Returns false for API key credentials or when no expiry is set.
|
||||
func oauthTokenNeedsRefresh(credType string, expiresAt int64) bool {
|
||||
if credType != "oauth" || expiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (expiresAt - 300) // 5 minutes buffer
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -57,19 +74,13 @@ func (c *AnthropicCredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *AnthropicCredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsExpired checks if the OAuth token is expired based on the ExpiresAt timestamp.
|
||||
// Returns false for API key authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) IsExpired() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= c.ExpiresAt
|
||||
return oauthTokenExpired(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the OAuth token needs refresh, returning true if the token
|
||||
@@ -77,10 +88,7 @@ func (c *OpenAICredentials) IsExpired() bool {
|
||||
// to avoid authentication failures during operations. Returns false for API key
|
||||
// authentication or if no expiration is set.
|
||||
func (c *OpenAICredentials) NeedsRefresh() bool {
|
||||
if c.Type != "oauth" || c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().Unix() >= (c.ExpiresAt - 300) // 5 minutes buffer
|
||||
return oauthTokenNeedsRefresh(c.Type, c.ExpiresAt)
|
||||
}
|
||||
|
||||
// CredentialManager handles secure storage and retrieval of authentication credentials.
|
||||
|
||||
@@ -428,6 +428,10 @@ type PreviousCompaction struct {
|
||||
ModifiedFiles []string
|
||||
}
|
||||
|
||||
// StreamCallback is called for each chunk of text during streaming compaction.
|
||||
// Return a non-nil error to cancel the stream.
|
||||
type StreamCallback func(delta string) error
|
||||
|
||||
// Compact summarises older messages using the LLM, returning the compaction
|
||||
// result and a new message slice (summary message + preserved recent
|
||||
// messages).
|
||||
@@ -442,6 +446,8 @@ type PreviousCompaction struct {
|
||||
//
|
||||
// prev carries file tracking from a previous compaction for cumulative
|
||||
// tracking. Pass nil if there is no prior compaction.
|
||||
// onChunk is an optional callback for streaming summary text. Pass nil for
|
||||
// non-streaming compaction.
|
||||
func Compact(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -449,6 +455,7 @@ func Compact(
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
prev *PreviousCompaction,
|
||||
onChunk StreamCallback,
|
||||
) (*CompactionResult, []fantasy.Message, error) {
|
||||
opts.defaults()
|
||||
|
||||
@@ -487,9 +494,9 @@ func Compact(
|
||||
var err error
|
||||
|
||||
if IsSplitTurn(messages, cutPoint) {
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions)
|
||||
summaryText, err = compactSplitTurn(ctx, model, oldMessages, messages, cutPoint, opts, customInstructions, onChunk)
|
||||
} else {
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions)
|
||||
summaryText, err = compactNormal(ctx, model, oldMessages, opts, customInstructions, onChunk)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -527,15 +534,17 @@ func Compact(
|
||||
}
|
||||
|
||||
// compactNormal generates a summary for a clean turn-boundary cut.
|
||||
// If onChunk is provided, text deltas are streamed to it.
|
||||
func compactNormal(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
oldMessages []fantasy.Message,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
conversationText := serializeMessages(oldMessages)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions)
|
||||
return generateSummary(ctx, model, conversationText, opts, customInstructions, onChunk)
|
||||
}
|
||||
|
||||
// compactSplitTurn handles the case where the cut point lands mid-turn.
|
||||
@@ -546,6 +555,7 @@ func compactNormal(
|
||||
//
|
||||
// The merged result preserves context from both the older history and the
|
||||
// beginning of the current long turn.
|
||||
// If onChunk is provided, both summaries and the separator are streamed.
|
||||
func compactSplitTurn(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
@@ -554,6 +564,7 @@ func compactSplitTurn(
|
||||
cutPoint int,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
// Find where the split turn starts.
|
||||
turnStart := findTurnStart(allMessages, cutPoint)
|
||||
@@ -573,12 +584,19 @@ func compactSplitTurn(
|
||||
// Generate history summary if there are complete turns before the split.
|
||||
if len(historyMessages) >= 2 {
|
||||
historySummary, err = generateSummary(ctx, model,
|
||||
serializeMessages(historyMessages), opts, "")
|
||||
serializeMessages(historyMessages), opts, "", onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn history summary failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stream the separator between history and turn prefix summaries.
|
||||
if onChunk != nil && historySummary != "" {
|
||||
if err := onChunk("\n\n---\n\n## Current Turn (in progress)\n\n"); err != nil {
|
||||
return "", fmt.Errorf("streaming separator failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Generate turn prefix summary.
|
||||
turnPrefixText := serializeMessages(turnPrefixMessages)
|
||||
turnPrefixPrompt := "The messages above are the BEGINNING of a long turn that was split. " +
|
||||
@@ -588,16 +606,10 @@ func compactSplitTurn(
|
||||
turnPrefixPrompt += "\n\nAdditional instructions: " + customInstructions
|
||||
}
|
||||
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: turnPrefixText + "\n\n" + turnPrefixPrompt,
|
||||
})
|
||||
turnPrefixSummary, err := generateSummary(ctx, model, turnPrefixText, opts, turnPrefixPrompt, onChunk)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("split turn prefix summary failed: %w", err)
|
||||
}
|
||||
turnPrefixSummary := result.Response.Content.Text()
|
||||
|
||||
// Merge the two summaries.
|
||||
if historySummary != "" && turnPrefixSummary != "" {
|
||||
@@ -610,12 +622,14 @@ func compactSplitTurn(
|
||||
}
|
||||
|
||||
// generateSummary calls the LLM to produce a structured summary.
|
||||
// If onChunk is provided, the summary is streamed using Agent.Stream().
|
||||
func generateSummary(
|
||||
ctx context.Context,
|
||||
model fantasy.LanguageModel,
|
||||
conversationText string,
|
||||
opts CompactionOptions,
|
||||
customInstructions string,
|
||||
onChunk StreamCallback,
|
||||
) (string, error) {
|
||||
userPrompt := opts.SummaryPrompt
|
||||
if userPrompt == "" {
|
||||
@@ -628,8 +642,31 @@ func generateSummary(
|
||||
summaryAgent := fantasy.NewAgent(model,
|
||||
fantasy.WithSystemPrompt(defaultSystemPrompt),
|
||||
)
|
||||
|
||||
prompt := conversationText + "\n\n" + userPrompt
|
||||
|
||||
// Use streaming if onChunk is provided.
|
||||
if onChunk != nil {
|
||||
var fullText strings.Builder
|
||||
_, err := summaryAgent.Stream(ctx, fantasy.AgentStreamCall{
|
||||
Prompt: prompt,
|
||||
OnTextDelta: func(_, delta string) error {
|
||||
if delta != "" {
|
||||
fullText.WriteString(delta)
|
||||
return onChunk(delta)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation (streaming) failed: %w", err)
|
||||
}
|
||||
return fullText.String(), nil
|
||||
}
|
||||
|
||||
// Non-streaming path.
|
||||
result, err := summaryAgent.Generate(ctx, fantasy.AgentCall{
|
||||
Prompt: conversationText + "\n\n" + userPrompt,
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("compaction summarisation failed: %w", err)
|
||||
|
||||
@@ -243,7 +243,7 @@ func TestCompact_TooFewMessages(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleUser, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -262,7 +262,7 @@ func TestCompact_WithinBudget(t *testing.T) {
|
||||
makeTextMessageN(fantasy.MessageRoleAssistant, 400),
|
||||
}
|
||||
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil)
|
||||
result, newMsgs, err := Compact(context.TODO(), nil, msgs, CompactionOptions{}, "", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
+115
-17
@@ -22,6 +22,20 @@ type MCPServerConfig struct {
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
|
||||
// OAuth configuration for remote servers that don't support dynamic
|
||||
// client registration (e.g. GitHub). When OAuthClientID is set, it is
|
||||
// passed directly to the transport's OAuthConfig instead of relying on
|
||||
// dynamic registration.
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
|
||||
// InProcessServer holds a live *server.MCPServer for in-process transport.
|
||||
// When set (and Type is "inprocess"), the connection pool creates an
|
||||
// in-process client instead of spawning a subprocess or making HTTP calls.
|
||||
// This field is never serialized — it is only used programmatically via the SDK.
|
||||
InProcessServer any `json:"-" yaml:"-"`
|
||||
|
||||
// Legacy fields for backward compatibility
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
@@ -35,13 +49,16 @@ type MCPServerConfig struct {
|
||||
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
// First try to unmarshal as the new format
|
||||
type newFormat struct {
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -66,6 +83,9 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = newConfig.Headers
|
||||
s.AllowedTools = newConfig.AllowedTools
|
||||
s.ExcludedTools = newConfig.ExcludedTools
|
||||
s.OAuthClientID = newConfig.OAuthClientID
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,11 +177,28 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParams defines generation parameter defaults that can be attached
|
||||
// to individual models. These act as model-level defaults — CLI flags and
|
||||
// global config values take precedence when explicitly set.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
@@ -169,6 +206,11 @@ type CustomModelConfig struct {
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
|
||||
// Generation parameter defaults for this model.
|
||||
// These are applied when the user hasn't explicitly set the corresponding
|
||||
// CLI flag or global config value.
|
||||
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
@@ -191,18 +233,19 @@ type Config struct {
|
||||
Model string `json:"model,omitempty" yaml:"model,omitempty"`
|
||||
MaxSteps int `json:"max-steps,omitempty" yaml:"max-steps,omitempty"`
|
||||
Debug bool `json:"debug,omitempty" yaml:"debug,omitempty"`
|
||||
Compact bool `json:"compact,omitempty" yaml:"compact,omitempty"`
|
||||
SystemPrompt string `json:"system-prompt,omitempty" yaml:"system-prompt,omitempty"`
|
||||
ProviderAPIKey string `json:"provider-api-key,omitempty" yaml:"provider-api-key,omitempty"`
|
||||
ProviderURL string `json:"provider-url,omitempty" yaml:"provider-url,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty" yaml:"stream,omitempty"`
|
||||
Theme any `json:"theme" yaml:"theme"`
|
||||
// Model generation parameters
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
MaxTokens int `json:"max-tokens,omitempty" yaml:"max-tokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"top-p,omitempty" yaml:"top-p,omitempty"`
|
||||
TopK *int32 `json:"top-k,omitempty" yaml:"top-k,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequency-penalty,omitempty" yaml:"frequency-penalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presence-penalty,omitempty" yaml:"presence-penalty,omitempty"`
|
||||
StopSequences []string `json:"stop-sequences,omitempty" yaml:"stop-sequences,omitempty"`
|
||||
|
||||
// Thinking / extended reasoning
|
||||
ThinkingLevel string `json:"thinking-level,omitempty" yaml:"thinking-level,omitempty"`
|
||||
@@ -216,6 +259,12 @@ type Config struct {
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
|
||||
// Per-model generation parameter overrides. Keys are "provider/model" strings
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
|
||||
// settings act as model-level defaults — CLI flags and global config values
|
||||
// take precedence when explicitly set.
|
||||
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
@@ -234,11 +283,18 @@ func (s *MCPServerConfig) GetTransportType() string {
|
||||
return "stdio"
|
||||
case "remote":
|
||||
return "streamable"
|
||||
case "inprocess":
|
||||
return "inprocess"
|
||||
default:
|
||||
return s.Type
|
||||
}
|
||||
}
|
||||
|
||||
// Programmatic in-process server detection.
|
||||
if s.InProcessServer != nil {
|
||||
return "inprocess"
|
||||
}
|
||||
|
||||
// Backward compatibility: infer transport type
|
||||
if len(s.Command) > 0 {
|
||||
return "stdio"
|
||||
@@ -269,8 +325,12 @@ func (c *Config) Validate() error {
|
||||
if serverConfig.URL == "" {
|
||||
return fmt.Errorf("server %s: url is required for %s transport", serverName, transport)
|
||||
}
|
||||
case "inprocess":
|
||||
if serverConfig.InProcessServer == nil {
|
||||
return fmt.Errorf("server %s: InProcessServer is required for inprocess transport", serverName)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable", serverName, transport)
|
||||
return fmt.Errorf("server %s: unsupported transport type '%s'. Supported types: stdio, sse, streamable, inprocess", serverName, transport)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -364,16 +424,55 @@ mcpServers:
|
||||
# debug: false # Enable debug logging
|
||||
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
|
||||
|
||||
# Model generation parameters (all optional)
|
||||
# Model generation parameters (all optional, apply globally to all models)
|
||||
# max-tokens: 4096 # Maximum tokens in response
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
# top-k: 40 # Top K sampling
|
||||
# frequency-penalty: 0.0 # Penalize frequent tokens (0.0-2.0)
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# Per-model generation parameter overrides (apply to specific models)
|
||||
# These act as model-level defaults — CLI flags and global settings above take precedence.
|
||||
# Keys are "provider/model" strings matching the model you use.
|
||||
# modelSettings:
|
||||
# anthropic/claude-sonnet-4-5-20250929:
|
||||
# temperature: 0.3
|
||||
# maxTokens: 8192
|
||||
# openai/gpt-4o:
|
||||
# temperature: 0.7
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# frequencyPenalty: 0.1
|
||||
# presencePenalty: 0.1
|
||||
# anthropic/claude-opus-4-6:
|
||||
# thinkingLevel: "high"
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
# Custom model definitions (under custom/ provider)
|
||||
# customModels:
|
||||
# my-local-llama:
|
||||
# name: "Local Llama 3"
|
||||
# baseUrl: "http://localhost:8080/v1"
|
||||
# family: "llama"
|
||||
# temperature: true
|
||||
# cost:
|
||||
# input: 0.0
|
||||
# output: 0.0
|
||||
# limit:
|
||||
# context: 131072
|
||||
# output: 8192
|
||||
# params: # Generation parameter defaults for this model
|
||||
# temperature: 0.8
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# systemPrompt: "You are a helpful local assistant."
|
||||
`
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
@@ -403,10 +502,9 @@ func FilepathOr[T any](key string, value *T) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filepath.Join(home, absPath[2:])
|
||||
absPath = filepath.Join(home, absPath[2:])
|
||||
}
|
||||
if !filepath.IsAbs(absPath) {
|
||||
// base := GetConfigPath()
|
||||
base := configPath
|
||||
if base == "" {
|
||||
fmt.Fprintf(os.Stderr, "unable to build relative path to config.")
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMCPServerConfig_NewFormat(t *testing.T) {
|
||||
@@ -542,3 +544,86 @@ func TestEnsureConfigExistsWhenFileExists(t *testing.T) {
|
||||
t.Error("Existing config file was modified when it shouldn't have been")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_JSON(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://api.githubcopilot.com/mcp/",
|
||||
"oauthClientId": "Ov23liXXXXXXXXXXXXXX",
|
||||
"oauthClientSecret": "secret123",
|
||||
"oauthScopes": ["read:user", "repo"]
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.URL != "https://api.githubcopilot.com/mcp/" {
|
||||
t.Errorf("Expected URL, got %q", cfg.URL)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "secret123" {
|
||||
t.Errorf("Expected OAuthClientSecret 'secret123', got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_YAML(t *testing.T) {
|
||||
yamlData := `
|
||||
type: remote
|
||||
url: https://api.githubcopilot.com/mcp/
|
||||
oauthClientId: "Ov23liXXXXXXXXXXXXXX"
|
||||
oauthScopes:
|
||||
- read:user
|
||||
- repo
|
||||
`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal YAML: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
// Verify that omitting OAuth fields still works (backward compat).
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://example.com/mcp"
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.OAuthClientID != "" {
|
||||
t.Errorf("Expected empty OAuthClientID, got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "" {
|
||||
t.Errorf("Expected empty OAuthClientSecret, got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 0 {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
+181
-24
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,10 +19,18 @@ import (
|
||||
// It receives tool call ID, tool name, output chunk, and whether it's stderr.
|
||||
type ToolOutputCallback func(toolCallID, toolName, chunk string, isStderr bool)
|
||||
|
||||
// PasswordPromptCallback is the signature for password prompts.
|
||||
// It receives a prompt message and returns the password and whether it was cancelled.
|
||||
type PasswordPromptCallback func(prompt string) (password string, cancelled bool)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions.
|
||||
type contextKey string
|
||||
|
||||
const toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
const (
|
||||
toolOutputCallbackKey contextKey = "toolOutputCallback"
|
||||
sudoPasswordKey contextKey = "sudoPassword"
|
||||
passwordPromptKey contextKey = "passwordPrompt"
|
||||
)
|
||||
|
||||
// ContextWithToolOutputCallback returns a new context with the tool output callback set.
|
||||
func ContextWithToolOutputCallback(ctx context.Context, callback ToolOutputCallback) context.Context {
|
||||
@@ -36,23 +45,39 @@ func toolOutputCallbackFromContext(ctx context.Context) ToolOutputCallback {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithPasswordPrompt returns a new context with the password prompt callback set.
|
||||
// This allows the TUI to show a modal password prompt when sudo needs a password.
|
||||
func ContextWithPasswordPrompt(ctx context.Context, callback PasswordPromptCallback) context.Context {
|
||||
return context.WithValue(ctx, passwordPromptKey, callback)
|
||||
}
|
||||
|
||||
// passwordPromptFromContext retrieves the password prompt callback from context.
|
||||
func passwordPromptFromContext(ctx context.Context) PasswordPromptCallback {
|
||||
if cb, ok := ctx.Value(passwordPromptKey).(PasswordPromptCallback); ok {
|
||||
return cb
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContextWithSudoPassword returns a new context with the sudo password set.
|
||||
// When present, the bash tool will use sudo -S to pipe this password to sudo commands.
|
||||
func ContextWithSudoPassword(ctx context.Context, password string) context.Context {
|
||||
return context.WithValue(ctx, sudoPasswordKey, password)
|
||||
}
|
||||
|
||||
// sudoPasswordFromContext retrieves the sudo password from context.
|
||||
func sudoPasswordFromContext(ctx context.Context) string {
|
||||
if pw, ok := ctx.Value(sudoPasswordKey).(string); ok {
|
||||
return pw
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const defaultBashTimeout = 120 * time.Second
|
||||
const maxBashTimeout = 600 * time.Second
|
||||
|
||||
var bannedCommands = []string{
|
||||
"alias ", "bg ", "bind ", "builtin ",
|
||||
"caller ", "command ", "compgen ",
|
||||
"complete ", "compopt ", "coproc ",
|
||||
"dirs ", "disown ", "enable ",
|
||||
"fc ", "fg ", "hash ", "help ",
|
||||
"history ", "jobs ", "kill ",
|
||||
"logout ", "mapfile ", "popd ",
|
||||
"pushd ", "readonly ", "select ",
|
||||
"set ", "shopt ", "source ",
|
||||
"suspend ", "times ", "trap ",
|
||||
"type ", "typeset ", "ulimit ",
|
||||
"umask ", "unalias ", "wait ",
|
||||
}
|
||||
// bannedCmdRe matches bash builtin commands that are not allowed for security reasons.
|
||||
var bannedCmdRe = regexp.MustCompile(`^(alias|bg|bind|builtin|caller|command|compgen|complete|compopt|coproc|dirs|disown|enable|fc|fg|hash|help|history|jobs|kill|logout|mapfile|popd|pushd|readonly|select|set|shopt|source|suspend|times|trap|type|typeset|ulimit|umask|unalias|wait)\s`)
|
||||
|
||||
type bashArgs struct {
|
||||
Command string `json:"command"`
|
||||
@@ -84,6 +109,66 @@ func NewBashTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// sudoCommandRe matches sudo commands that need to be rewritten for -S mode.
|
||||
// It matches "sudo" as a word boundary, optionally preceded by environment variables.
|
||||
var sudoCommandRe = regexp.MustCompile(`(?i)(^|[&|;|]|\|\||&&)\s*(\w+=\S+\s+)?\bsudo\b`)
|
||||
|
||||
// truncateCommand truncates a long command for display.
|
||||
func truncateCommand(cmd string, maxLen int) string {
|
||||
if len(cmd) <= maxLen {
|
||||
return cmd
|
||||
}
|
||||
return cmd[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
// rewriteSudoForStdin rewrites sudo commands to use -S -p ” for stdin password input.
|
||||
// It transforms: sudo cmd → sudo -S -p ” cmd
|
||||
func rewriteSudoForStdin(command string) string {
|
||||
// Find all matches and their positions
|
||||
matches := sudoCommandRe.FindAllStringIndex(command, -1)
|
||||
if matches == nil {
|
||||
return command
|
||||
}
|
||||
|
||||
// Build result from end to start to preserve indices
|
||||
result := command
|
||||
for i := len(matches) - 1; i >= 0; i-- {
|
||||
match := matches[i]
|
||||
start, end := match[0], match[1]
|
||||
matchedText := result[start:end]
|
||||
|
||||
// Extract just the "sudo" part (after any prefix)
|
||||
sudoIdx := strings.Index(strings.ToLower(matchedText), "sudo")
|
||||
if sudoIdx == -1 {
|
||||
continue
|
||||
}
|
||||
prefix := matchedText[:sudoIdx]
|
||||
sudoPart := matchedText[sudoIdx:]
|
||||
|
||||
// Check if the text immediately after "sudo" in the result contains -S
|
||||
afterSudo := result[end:]
|
||||
if strings.HasPrefix(strings.TrimLeft(afterSudo, " \t"), "-S") {
|
||||
// Already has -S flag, skip
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert -S -p '' after "sudo"
|
||||
newSudo := strings.Replace(sudoPart, "sudo", "sudo -S -p ''", 1)
|
||||
result = result[:start] + prefix + newSudo + result[end:]
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// SudoPasswordRequiredResult is a special marker that indicates sudo needs a password.
|
||||
// This is stored in tool response metadata to signal the TUI to prompt for password.
|
||||
const SudoPasswordRequiredMetadata = `{"sudo_password_required":true}`
|
||||
|
||||
// IsSudoPasswordRequiredResult checks if a tool response indicates sudo password is needed.
|
||||
func IsSudoPasswordRequiredResult(resp fantasy.ToolResponse) bool {
|
||||
return resp.Metadata == SudoPasswordRequiredMetadata
|
||||
}
|
||||
|
||||
func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fantasy.ToolResponse, error) {
|
||||
var args bashArgs
|
||||
if err := parseArgs(call.Input, &args); err != nil {
|
||||
@@ -94,10 +179,8 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
// Check for banned commands
|
||||
for _, banned := range bannedCommands {
|
||||
if strings.HasPrefix(args.Command, banned) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
if bannedCmdRe.MatchString(args.Command) {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("command '%s' is not allowed", args.Command)), nil
|
||||
}
|
||||
|
||||
// Determine timeout
|
||||
@@ -110,7 +193,47 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", args.Command)
|
||||
// Check for sudo password in context or environment
|
||||
sudoPassword := sudoPasswordFromContext(ctx)
|
||||
if sudoPassword == "" {
|
||||
sudoPassword = os.Getenv("SUDO_PASSWORD")
|
||||
}
|
||||
command := args.Command
|
||||
|
||||
// If command contains sudo and we don't have a password, check if sudo needs one
|
||||
if sudoPassword == "" && sudoCommandRe.MatchString(command) {
|
||||
// Check if sudo credentials are cached using sudo -n (non-interactive)
|
||||
testCmd := exec.CommandContext(cmdCtx, "sudo", "-n", "true")
|
||||
testCmd.Dir = workDir
|
||||
if err := testCmd.Run(); err != nil {
|
||||
// Sudo needs a password - try to prompt via callback
|
||||
if promptCallback := passwordPromptFromContext(ctx); promptCallback != nil {
|
||||
pw, cancelled := promptCallback("Sudo password required for: " + truncateCommand(args.Command, 60))
|
||||
if cancelled {
|
||||
return fantasy.NewTextErrorResponse("sudo password prompt cancelled"), nil
|
||||
}
|
||||
if pw == "" {
|
||||
return fantasy.NewTextErrorResponse("no sudo password provided"), nil
|
||||
}
|
||||
sudoPassword = pw
|
||||
command = rewriteSudoForStdin(command)
|
||||
} else {
|
||||
// No callback available - return error with helpful message
|
||||
return fantasy.NewTextErrorResponse(
|
||||
"This command requires sudo access. " +
|
||||
"Please run 'sudo -v' in your terminal first to cache credentials, " +
|
||||
"or set the SUDO_PASSWORD environment variable."), nil
|
||||
}
|
||||
}
|
||||
// Credentials are cached or password was provided, proceed
|
||||
}
|
||||
|
||||
// If we have a sudo password, rewrite the command to use sudo -S
|
||||
if sudoPassword != "" && sudoCommandRe.MatchString(command) {
|
||||
command = rewriteSudoForStdin(command)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(cmdCtx, "bash", "-c", command)
|
||||
if workDir != "" {
|
||||
cmd.Dir = workDir
|
||||
}
|
||||
@@ -128,18 +251,18 @@ func executeBash(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
|
||||
if outputCallback != nil {
|
||||
// Streaming mode: use pipes to capture output as it arrives
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback)
|
||||
return executeBashStreaming(cmdCtx, call, cmd, outputCallback, sudoPassword)
|
||||
}
|
||||
|
||||
// Non-streaming mode: collect all output at once (original behavior)
|
||||
return executeBashBuffered(cmdCtx, call, cmd)
|
||||
return executeBashBuffered(cmdCtx, call, cmd, sudoPassword)
|
||||
}
|
||||
|
||||
// executeBashBuffered collects all output before returning (original behavior).
|
||||
// It uses explicit pipes (not cmd.Stdout) so that cmd.WaitDelay can forcibly
|
||||
// close them when grandchild processes hold pipe handles open after the
|
||||
// direct child exits.
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd) (fantasy.ToolResponse, error) {
|
||||
func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -149,10 +272,27 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe and write the password
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Read pipes concurrently
|
||||
var wg sync.WaitGroup
|
||||
var stdout, stderr strings.Builder
|
||||
@@ -194,7 +334,7 @@ func executeBashBuffered(cmdCtx context.Context, call fantasy.ToolCall, cmd *exe
|
||||
}
|
||||
|
||||
// executeBashStreaming streams output as it arrives via the callback.
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback) (fantasy.ToolResponse, error) {
|
||||
func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *exec.Cmd, outputCallback ToolOutputCallback, sudoPassword string) (fantasy.ToolResponse, error) {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdout pipe"), nil
|
||||
@@ -204,11 +344,28 @@ func executeBashStreaming(cmdCtx context.Context, call fantasy.ToolCall, cmd *ex
|
||||
return fantasy.NewTextErrorResponse("failed to create stderr pipe"), nil
|
||||
}
|
||||
|
||||
// If we have a sudo password, create a stdin pipe
|
||||
var stdinPipe io.WriteCloser
|
||||
if sudoPassword != "" {
|
||||
stdinPipe, err = cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse("failed to create stdin pipe"), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Start command execution
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to start command: %v", err)), nil
|
||||
}
|
||||
|
||||
// Write password to stdin if needed, then close stdin
|
||||
if sudoPassword != "" && stdinPipe != nil {
|
||||
go func() {
|
||||
defer func() { _ = stdinPipe.Close() }()
|
||||
_, _ = io.WriteString(stdinPipe, sudoPassword+"\n")
|
||||
}()
|
||||
}
|
||||
|
||||
// Stream stdout and stderr concurrently
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex
|
||||
|
||||
@@ -127,3 +127,72 @@ func TestBash_EmptyCommand(t *testing.T) {
|
||||
t.Fatal("expected error for empty command")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteSudoForStdin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple sudo",
|
||||
input: "sudo apt update",
|
||||
expected: "sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo with env var",
|
||||
input: "DEBIAN_FRONTEND=noninteractive sudo apt update",
|
||||
expected: "DEBIAN_FRONTEND=noninteractive sudo -S -p '' apt update",
|
||||
},
|
||||
{
|
||||
name: "sudo in pipeline",
|
||||
input: "echo test | sudo tee /etc/test.conf",
|
||||
expected: "echo test | sudo -S -p '' tee /etc/test.conf",
|
||||
},
|
||||
{
|
||||
name: "sudo after &&",
|
||||
input: "apt update && sudo apt upgrade",
|
||||
expected: "apt update && sudo -S -p '' apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "already has -S flag",
|
||||
input: "sudo -S apt update",
|
||||
expected: "sudo -S apt update",
|
||||
},
|
||||
{
|
||||
name: "no sudo",
|
||||
input: "apt update && apt upgrade",
|
||||
expected: "apt update && apt upgrade",
|
||||
},
|
||||
{
|
||||
name: "sudo in string (should not match)",
|
||||
input: "echo 'use sudo carefully'",
|
||||
expected: "echo 'use sudo carefully'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := rewriteSudoForStdin(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("rewriteSudoForStdin(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSudoPasswordFromContext(t *testing.T) {
|
||||
// Test with password in context
|
||||
ctx := ContextWithSudoPassword(context.Background(), "secret123")
|
||||
pw := sudoPasswordFromContext(ctx)
|
||||
if pw != "secret123" {
|
||||
t.Errorf("expected password 'secret123', got %q", pw)
|
||||
}
|
||||
|
||||
// Test without password
|
||||
ctx = context.Background()
|
||||
pw = sudoPasswordFromContext(ctx)
|
||||
if pw != "" {
|
||||
t.Errorf("expected empty password, got %q", pw)
|
||||
}
|
||||
}
|
||||
|
||||
+1
-20
@@ -67,7 +67,7 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
}
|
||||
|
||||
if info.IsDir() {
|
||||
return readDirectory(absPath)
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("'%s' is a directory, not a file. Use the ls tool to list directory contents.", args.Path)), nil
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(absPath)
|
||||
@@ -116,25 +116,6 @@ func executeRead(ctx context.Context, call fantasy.ToolCall, workDir string) (fa
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
func readDirectory(absPath string) (fantasy.ToolResponse, error) {
|
||||
entries, err := os.ReadDir(absPath)
|
||||
if err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("failed to read directory: %v", err)), nil
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
if entry.IsDir() {
|
||||
name += "/"
|
||||
}
|
||||
result.WriteString(name + "\n")
|
||||
}
|
||||
|
||||
tr := truncateHead(result.String(), 500, defaultMaxBytes)
|
||||
return fantasy.NewTextResponse(tr.Content), nil
|
||||
}
|
||||
|
||||
// resolvePathWithWorkDir resolves a path to an absolute path relative to the
|
||||
// given workDir. If workDir is empty, os.Getwd() is used.
|
||||
func resolvePathWithWorkDir(path, workDir string) (string, error) {
|
||||
|
||||
+31
-38
@@ -28,14 +28,14 @@ type SubagentSpawnResult struct {
|
||||
// SubagentSpawnFunc is a callback that spawns an in-process subagent. The
|
||||
// parent Kit instance injects this into the context so the core tool can
|
||||
// call back without importing pkg/kit (which would create a cycle).
|
||||
// The toolCallID parameter is the LLM-assigned ID of the spawn_subagent
|
||||
// The toolCallID parameter is the LLM-assigned ID of the subagent
|
||||
// tool call, enabling the parent to correlate subagent events.
|
||||
type SubagentSpawnFunc func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error)
|
||||
|
||||
type subagentCtxKey struct{}
|
||||
|
||||
// WithSubagentSpawner stores a spawn function in the context so that the
|
||||
// spawn_subagent core tool can create in-process subagents.
|
||||
// subagent core tool can create in-process subagents.
|
||||
func WithSubagentSpawner(ctx context.Context, fn SubagentSpawnFunc) context.Context {
|
||||
return context.WithValue(ctx, subagentCtxKey{}, fn)
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func getSubagentSpawner(ctx context.Context) SubagentSpawnFunc {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// spawn_subagent tool
|
||||
// subagent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type subagentArgs struct {
|
||||
@@ -59,11 +59,11 @@ type subagentArgs struct {
|
||||
TimeoutSeconds int `json:"timeout_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// NewSubagentTool creates the spawn_subagent core tool.
|
||||
// NewSubagentTool creates the subagent core tool.
|
||||
func NewSubagentTool(opts ...ToolOption) fantasy.AgentTool {
|
||||
return &coreTool{
|
||||
info: fantasy.ToolInfo{
|
||||
Name: "spawn_subagent",
|
||||
Name: "subagent",
|
||||
Description: `Spawn a subagent to perform a task autonomously.
|
||||
|
||||
The subagent runs as a separate in-process Kit instance with full tool access
|
||||
@@ -130,13 +130,22 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
), fmt.Errorf("no subagent spawner in context")
|
||||
}
|
||||
|
||||
// Detach from the parent's deadline so the subagent gets its own
|
||||
// independent timeout (applied downstream in Kit.Subagent). The parent
|
||||
// context may carry a tight deadline from the LLM generation loop or
|
||||
// other tool timeouts that would prematurely kill the subagent.
|
||||
// We preserve context values (spawner, etc.) and propagate parent
|
||||
// cancellation (e.g. user hits Ctrl-C) without inheriting the deadline.
|
||||
spawnCtx := detachedWithCancel(ctx)
|
||||
// Build a clean context for the subagent that inherits values (e.g. the
|
||||
// spawner callback) but is completely detached from the parent's
|
||||
// deadline AND cancellation. The subagent gets its own independent
|
||||
// timeout (applied downstream in Kit.Subagent).
|
||||
//
|
||||
// Why full detachment instead of propagating parent cancellation?
|
||||
// The parent context may already be done (deadline exceeded or
|
||||
// cancelled) by the time this tool handler executes — for example when
|
||||
// the generation loop context carries a deadline, when the user
|
||||
// double-ESC cancels mid-turn, or when parallel tool execution
|
||||
// encounters a race between stream completion and tool dispatch. Using
|
||||
// context.WithoutCancel (Go 1.21+) ensures the subagent always starts
|
||||
// cleanly with a fresh timeout, following the pattern used by crush for
|
||||
// shutdown-resilient child work. The subagent's own timeout
|
||||
// (defaultSubagentTimeout / user-specified) provides the safety net.
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: ctx})
|
||||
|
||||
// Spawn in-process subagent.
|
||||
result, err := spawner(spawnCtx, call.ID, args.Task, args.Model, args.SystemPrompt, timeout)
|
||||
@@ -173,37 +182,21 @@ func executeSubagent(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolRe
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context detachment
|
||||
// Context helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// detachedContext wraps a parent context, preserving its values but removing
|
||||
// its deadline and cancellation. This allows the subagent to have its own
|
||||
// independent timeout while still accessing context-stored values (e.g. the
|
||||
// subagent spawner function).
|
||||
type detachedContext struct {
|
||||
// valuesContext preserves a parent context's values (e.g. the subagent
|
||||
// spawner callback) while stripping its deadline and cancellation. Combined
|
||||
// with context.WithoutCancel() this gives the subagent a completely clean
|
||||
// context that only inherits value-based dependencies.
|
||||
type valuesContext struct {
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (d detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (d detachedContext) Done() <-chan struct{} { return nil }
|
||||
func (d detachedContext) Err() error { return nil }
|
||||
func (d detachedContext) Value(key any) any { return d.parent.Value(key) }
|
||||
|
||||
// detachedWithCancel creates a new context that inherits values from the
|
||||
// parent but has no deadline. Cancellation of the parent is propagated: when
|
||||
// the parent is cancelled the returned context is also cancelled, but the
|
||||
// parent's deadline does not apply to the child.
|
||||
func detachedWithCancel(parent context.Context) context.Context {
|
||||
child, cancel := context.WithCancel(detachedContext{parent: parent})
|
||||
go func() {
|
||||
select {
|
||||
case <-parent.Done():
|
||||
cancel()
|
||||
case <-child.Done():
|
||||
}
|
||||
}()
|
||||
return child
|
||||
}
|
||||
func (v valuesContext) Deadline() (time.Time, bool) { return time.Time{}, false }
|
||||
func (v valuesContext) Done() <-chan struct{} { return nil }
|
||||
func (v valuesContext) Err() error { return nil }
|
||||
func (v valuesContext) Value(key any) any { return v.parent.Value(key) }
|
||||
|
||||
// truncateResponse limits the response length to avoid overwhelming context windows.
|
||||
func truncateResponse(s string, maxLen int) string {
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestValuesContext_StripsDeadlineAndCancellation(t *testing.T) {
|
||||
// Parent with a tight deadline.
|
||||
parent, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(5 * time.Millisecond) // Let deadline expire.
|
||||
|
||||
if parent.Err() == nil {
|
||||
t.Fatal("expected parent to be expired")
|
||||
}
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
if _, ok := vc.Deadline(); ok {
|
||||
t.Error("valuesContext should report no deadline")
|
||||
}
|
||||
if vc.Done() != nil {
|
||||
t.Error("valuesContext.Done() should return nil")
|
||||
}
|
||||
if vc.Err() != nil {
|
||||
t.Errorf("valuesContext.Err() should be nil, got %v", vc.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuesContext_PreservesValues(t *testing.T) {
|
||||
type testKey struct{}
|
||||
parent := context.WithValue(context.Background(), testKey{}, "hello")
|
||||
|
||||
vc := valuesContext{parent: parent}
|
||||
|
||||
got, ok := vc.Value(testKey{}).(string)
|
||||
if !ok || got != "hello" {
|
||||
t.Errorf("expected value 'hello', got %q (ok=%v)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesCancelledParent(t *testing.T) {
|
||||
// Simulate the exact scenario from the bug: the parent generation
|
||||
// context is already cancelled when the subagent tool handler runs.
|
||||
parent, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancelled before detach.
|
||||
|
||||
// This is what executeSubagent now does:
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
// The spawn context must be alive.
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive, got err: %v", spawnCtx.Err())
|
||||
}
|
||||
|
||||
// Adding a timeout should produce a working context.
|
||||
tCtx, tCancel := context.WithTimeout(spawnCtx, 5*time.Second)
|
||||
defer tCancel()
|
||||
|
||||
if tCtx.Err() != nil {
|
||||
t.Fatalf("timeout context should be alive, got err: %v", tCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_SurvivesDeadlineExceededParent(t *testing.T) {
|
||||
// Simulate: parent had a deadline that already expired.
|
||||
parent, pCancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer pCancel()
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if parent.Err() != context.DeadlineExceeded {
|
||||
t.Fatalf("expected parent deadline exceeded, got: %v", parent.Err())
|
||||
}
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parent})
|
||||
|
||||
if spawnCtx.Err() != nil {
|
||||
t.Fatalf("spawnCtx should be alive after deadline-exceeded parent, got: %v", spawnCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpawnContext_PreservesSpawnerValue(t *testing.T) {
|
||||
// Verify the subagent spawner callback survives context detachment.
|
||||
called := false
|
||||
spawner := SubagentSpawnFunc(func(ctx context.Context, toolCallID, prompt, model, systemPrompt string, timeout time.Duration) (*SubagentSpawnResult, error) {
|
||||
called = true
|
||||
return &SubagentSpawnResult{Response: "ok"}, nil
|
||||
})
|
||||
|
||||
parent := WithSubagentSpawner(context.Background(), spawner)
|
||||
// Cancel the parent.
|
||||
parentCtx, cancel := context.WithCancel(parent)
|
||||
cancel()
|
||||
|
||||
spawnCtx := context.WithoutCancel(valuesContext{parent: parentCtx})
|
||||
|
||||
// Should be able to retrieve the spawner from the detached context.
|
||||
recovered := getSubagentSpawner(spawnCtx)
|
||||
if recovered == nil {
|
||||
t.Fatal("spawner should be recoverable from detached context")
|
||||
}
|
||||
|
||||
result, err := recovered(spawnCtx, "tc1", "test task", "", "", time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("spawner call failed: %v", err)
|
||||
}
|
||||
if !called {
|
||||
t.Error("spawner was not called")
|
||||
}
|
||||
if result.Response != "ok" {
|
||||
t.Errorf("expected 'ok', got %q", result.Response)
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,7 @@ func ReadOnlyTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SubagentTools returns all core tools except spawn_subagent. This prevents
|
||||
// SubagentTools returns all core tools except subagent. This prevents
|
||||
// infinite recursion when a subagent is itself a Kit instance.
|
||||
func SubagentTools(opts ...ToolOption) []fantasy.AgentTool {
|
||||
return []fantasy.AgentTool{
|
||||
|
||||
+344
-6
@@ -77,6 +77,64 @@ type Context struct {
|
||||
// ctx.CancelAndSend("Stop what you're doing and focus on the tests")
|
||||
CancelAndSend func(string)
|
||||
|
||||
// Abort cancels the current agent turn (if running) and clears the
|
||||
// message queue. Unlike CancelAndSend, no new message is injected —
|
||||
// the agent simply stops. Safe to call when idle (no-op).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ctx.Abort() // stop whatever the agent is doing
|
||||
Abort func()
|
||||
|
||||
// IsIdle returns true when the agent is not processing a turn.
|
||||
// Extensions can use this to decide whether to dispatch immediately
|
||||
// or queue work for later.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// if ctx.IsIdle() {
|
||||
// ctx.SendMessage("start new task")
|
||||
// }
|
||||
IsIdle func() bool
|
||||
|
||||
// Compact triggers context compaction, summarising older messages to
|
||||
// free context window space. Returns an error if compaction cannot
|
||||
// start (e.g. agent is busy or app is closed). The actual compaction
|
||||
// runs asynchronously; use OnComplete/OnError callbacks in
|
||||
// CompactConfig to observe the result.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := ctx.Compact(ext.CompactConfig{
|
||||
// OnComplete: func() { ctx.PrintInfo("Compaction done") },
|
||||
// OnError: func(errMsg string) { ctx.PrintError("Compact failed: " + errMsg) },
|
||||
// })
|
||||
Compact func(CompactConfig) error
|
||||
|
||||
// SendMultimodalMessage injects a message with file attachments (images,
|
||||
// documents) into the conversation and triggers a new agent turn. Files
|
||||
// are described by FilePart structs containing the raw bytes, filename,
|
||||
// and MIME type. If the agent is busy the message is queued.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// data, _ := os.ReadFile("photo.jpg")
|
||||
// ctx.SendMultimodalMessage("Describe this image", []ext.FilePart{
|
||||
// {Filename: "photo.jpg", Data: data, MediaType: "image/jpeg"},
|
||||
// })
|
||||
SendMultimodalMessage func(text string, files []FilePart)
|
||||
|
||||
// GetSessionUsage returns aggregated token usage and cost statistics
|
||||
// for the current session. This includes total input/output tokens,
|
||||
// cache read/write tokens, total cost, and request count.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// usage := ctx.GetSessionUsage()
|
||||
// fmt.Sprintf("Tokens: ↑%d ↓%d Cost: $%.3f",
|
||||
// usage.TotalInputTokens, usage.TotalOutputTokens, usage.TotalCost)
|
||||
GetSessionUsage func() SessionUsage
|
||||
|
||||
// SetWidget places or updates a persistent widget in the TUI. Widgets
|
||||
// remain visible across agent turns until explicitly removed. The
|
||||
// widget is identified by WidgetConfig.ID; calling SetWidget with the
|
||||
@@ -572,6 +630,102 @@ type Context struct {
|
||||
// })
|
||||
// // handle.Kill() to cancel, handle.Wait() to block
|
||||
SpawnSubagent func(SubagentConfig) (*SubagentHandle, *SubagentResult, error)
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API (Phase 1 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// GetTreeNode returns a node by ID with full metadata and children.
|
||||
// Returns nil if entry not found.
|
||||
GetTreeNode func(entryID string) *TreeNode
|
||||
|
||||
// GetCurrentBranch returns the path from root to current leaf.
|
||||
// Each node contains full metadata (unlike GetMessages which flattens).
|
||||
GetCurrentBranch func() []TreeNode
|
||||
|
||||
// GetChildren returns direct child IDs of an entry.
|
||||
GetChildren func(entryID string) []string
|
||||
|
||||
// NavigateTo branches/forks the session to the specified entry ID.
|
||||
// Equivalent to SDK's Branch() but for extensions.
|
||||
NavigateTo func(entryID string) TreeNavigationResult
|
||||
|
||||
// SummarizeBranch uses LLM to summarize a branch range.
|
||||
// Returns summary text or error string (empty if success).
|
||||
SummarizeBranch func(fromID, toID string) string
|
||||
|
||||
// CollapseBranch replaces a branch range with a summary entry.
|
||||
// This is the "fresh context" primitive for context window management.
|
||||
CollapseBranch func(fromID, toID, summary string) TreeNavigationResult
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API (Phase 2 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// LoadSkill loads a single skill file from path.
|
||||
// Parses YAML frontmatter, returns skill with content ready for injection.
|
||||
LoadSkill func(path string) (*Skill, string)
|
||||
|
||||
// LoadSkillsFromDir discovers and loads all skills from a directory.
|
||||
LoadSkillsFromDir func(dir string) SkillLoadResult
|
||||
|
||||
// DiscoverSkills finds skills in standard locations.
|
||||
// Checks ~/.config/kit/skills/, .kit/skills/, .agents/skills/
|
||||
DiscoverSkills func() SkillLoadResult
|
||||
|
||||
// InjectSkillAsContext sends a skill's content as a system message.
|
||||
// Looks up skill by name from discovered skills.
|
||||
InjectSkillAsContext func(skillName string) string
|
||||
|
||||
// InjectRawSkillAsContext loads and immediately injects a skill file.
|
||||
InjectRawSkillAsContext func(path string) string
|
||||
|
||||
// GetAvailableSkills returns all currently loaded/discovered skills.
|
||||
GetAvailableSkills func() []Skill
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API (Phase 3 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ParseTemplate extracts {{variables}} from template content.
|
||||
ParseTemplate func(name, content string) PromptTemplate
|
||||
|
||||
// RenderTemplate substitutes variables into template content.
|
||||
RenderTemplate func(tpl PromptTemplate, vars map[string]string) string
|
||||
|
||||
// ParseArguments parses command-line style arguments.
|
||||
ParseArguments func(input string, pattern ArgumentPattern) ParseResult
|
||||
|
||||
// SimpleParseArguments parses $1, $2, $@ style arguments.
|
||||
// Returns slice where [0]=full input, [1]=$1, [2]=$2, ... [n]=$@
|
||||
SimpleParseArguments func(input string, count int) []string
|
||||
|
||||
// EvaluateModelConditional checks if condition matches current model.
|
||||
// Condition supports wildcards: * matches any, ? matches single char.
|
||||
EvaluateModelConditional func(condition string) bool
|
||||
|
||||
// RenderWithModelConditionals processes <if-model> blocks in content.
|
||||
RenderWithModelConditionals func(content string) string
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API (Phase 4 Bridge)
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
// ResolveModelChain attempts each model in order until one is available.
|
||||
ResolveModelChain func(preferences []string) ModelResolutionResult
|
||||
|
||||
// GetModelCapabilities returns capabilities for a specific model.
|
||||
// If model is empty, uses current model.
|
||||
GetModelCapabilities func(model string) (ModelCapabilities, string)
|
||||
|
||||
// CheckModelAvailable verifies if a model string is valid.
|
||||
CheckModelAvailable func(model string) bool
|
||||
|
||||
// GetCurrentProvider returns just the provider part of current model.
|
||||
GetCurrentProvider func() string
|
||||
|
||||
// GetCurrentModelID returns just the model ID part of current model.
|
||||
GetCurrentModelID func() string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -598,6 +752,148 @@ type SessionMessage struct {
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tree navigation types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TreeNode represents a node in the session tree for navigation.
|
||||
// Extensions use this to traverse conversation history and implement
|
||||
// features like "fresh context" loops and branch summarization.
|
||||
type TreeNode struct {
|
||||
// ID is the unique entry identifier.
|
||||
ID string
|
||||
// ParentID links this entry to its parent (empty if root).
|
||||
ParentID string
|
||||
// Type is the entry type: "message", "branch_summary", "model_change", "extension_data", "tool_execution".
|
||||
Type string
|
||||
// Role is the message role for message entries: "user", "assistant", "system", "tool".
|
||||
Role string
|
||||
// Content is the text content or summary.
|
||||
Content string
|
||||
// Model is the model that generated this (for assistant messages).
|
||||
Model string
|
||||
// Provider is the provider used.
|
||||
Provider string
|
||||
// Timestamp is the RFC3339-formatted creation time.
|
||||
Timestamp string
|
||||
// Children is the list of child entry IDs for tree traversal.
|
||||
Children []string
|
||||
}
|
||||
|
||||
// TreeNavigationResult reports success or failure of tree operations.
|
||||
type TreeNavigationResult struct {
|
||||
// Success is true if the operation completed.
|
||||
Success bool
|
||||
// Error describes what went wrong (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Skill types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Skill represents a loaded skill file with parsed YAML frontmatter.
|
||||
type Skill struct {
|
||||
// Name is the human-readable identifier.
|
||||
Name string
|
||||
// Description summarizes what this skill provides.
|
||||
Description string
|
||||
// Content is the markdown body (frontmatter stripped).
|
||||
Content string
|
||||
// Path is the absolute filesystem path.
|
||||
Path string
|
||||
// Tags are optional labels for categorization.
|
||||
Tags []string
|
||||
// When controls automatic inclusion: "always", "on-demand", or file-glob.
|
||||
When string
|
||||
}
|
||||
|
||||
// SkillLoadResult reports skills loaded from a directory.
|
||||
type SkillLoadResult struct {
|
||||
// Skills is the list of loaded skills.
|
||||
Skills []Skill
|
||||
// Error describes loading failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Template parsing types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// PromptTemplate represents a parsed template with variable placeholders.
|
||||
type PromptTemplate struct {
|
||||
// Name is the template identifier.
|
||||
Name string
|
||||
// Content is the original template content.
|
||||
Content string
|
||||
// Variables are the extracted {{variable}} names.
|
||||
Variables []string
|
||||
}
|
||||
|
||||
// ArgumentPattern defines how to parse command arguments.
|
||||
type ArgumentPattern struct {
|
||||
// Positional names for $1, $2, etc.
|
||||
Positional []string
|
||||
// Rest is the variable name for $@ (all remaining).
|
||||
Rest string
|
||||
// Flags maps flag names to variable names (e.g., "--loop" -> "loop").
|
||||
Flags map[string]string
|
||||
}
|
||||
|
||||
// ParseResult reports argument parsing outcome.
|
||||
type ParseResult struct {
|
||||
// Vars maps variable names to values for positional args.
|
||||
Vars map[string]string
|
||||
// Flags maps flag names to values.
|
||||
Flags map[string]string
|
||||
// Rest is remaining unparsed text.
|
||||
Rest string
|
||||
// Error describes parsing failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ModelConditional represents an <if-model> block for evaluation.
|
||||
type ModelConditional struct {
|
||||
// Condition is the model pattern (e.g., "claude-*", "anthropic/*").
|
||||
Condition string
|
||||
// Content is rendered if condition matches.
|
||||
Content string
|
||||
// Else is rendered if condition doesn't match.
|
||||
Else string
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Model resolution types (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// ModelCapabilities describes what a model supports.
|
||||
type ModelCapabilities struct {
|
||||
// Provider is the provider ID (e.g., "anthropic").
|
||||
Provider string
|
||||
// ModelID is the model identifier (e.g., "claude-sonnet-4-20250929").
|
||||
ModelID string
|
||||
// ContextLimit is the maximum context window in tokens.
|
||||
ContextLimit int
|
||||
// OutputLimit is the maximum output tokens.
|
||||
OutputLimit int
|
||||
// Reasoning indicates if the model supports reasoning/thinking.
|
||||
Reasoning bool
|
||||
// Streaming indicates if the model supports streaming.
|
||||
Streaming bool
|
||||
}
|
||||
|
||||
// ModelResolutionResult reports model chain resolution outcome.
|
||||
type ModelResolutionResult struct {
|
||||
// Model is the selected model in "provider/model" format.
|
||||
Model string
|
||||
// Capabilities describes the selected model.
|
||||
Capabilities ModelCapabilities
|
||||
// Attempted lists models tried before success.
|
||||
Attempted []string
|
||||
// Error describes resolution failures (empty if success).
|
||||
Error string
|
||||
}
|
||||
|
||||
// ExtensionEntry represents persisted extension data stored in the session.
|
||||
// Extensions use AppendEntry to save custom state and GetEntries to retrieve
|
||||
// it on session resume.
|
||||
@@ -622,7 +918,7 @@ type ExtensionEntry struct {
|
||||
type ContextMessage struct {
|
||||
// Index is the position of this message in the original context array
|
||||
// (0-based). When returning messages from a ContextPrepareResult,
|
||||
// messages with Index >= 0 reuse the original fantasy.Message at that
|
||||
// messages with Index >= 0 reuse the original LLM message at that
|
||||
// position (preserving tool calls, reasoning, and other complex parts).
|
||||
// Set Index to -1 for newly injected messages (created from Role + Content).
|
||||
Index int
|
||||
@@ -699,6 +995,48 @@ type StatusBarEntry struct {
|
||||
Priority int
|
||||
}
|
||||
|
||||
// CompactConfig configures a programmatic context compaction request.
|
||||
type CompactConfig struct {
|
||||
// CustomInstructions is optional text appended to the summary prompt
|
||||
// (e.g. "Focus on the API design decisions"). Empty uses the default.
|
||||
CustomInstructions string
|
||||
// OnComplete is called when compaction finishes successfully.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnComplete func()
|
||||
// OnError is called when compaction fails. The argument is the error message.
|
||||
// May be nil if the caller doesn't need notification.
|
||||
OnError func(errMsg string)
|
||||
}
|
||||
|
||||
// FilePart describes a file attachment for multimodal messages. Extensions
|
||||
// use this with SendMultimodalMessage to attach images or documents.
|
||||
type FilePart struct {
|
||||
// Filename is the name of the file (e.g. "photo.jpg").
|
||||
Filename string
|
||||
// Data is the raw file content.
|
||||
Data []byte
|
||||
// MediaType is the MIME type (e.g. "image/jpeg", "application/pdf").
|
||||
MediaType string
|
||||
}
|
||||
|
||||
// SessionUsage contains aggregated token usage and cost statistics for
|
||||
// the current session. Extensions use this with GetSessionUsage() to
|
||||
// report usage information.
|
||||
type SessionUsage struct {
|
||||
// TotalInputTokens is the sum of input tokens across all requests.
|
||||
TotalInputTokens int
|
||||
// TotalOutputTokens is the sum of output tokens across all requests.
|
||||
TotalOutputTokens int
|
||||
// TotalCacheReadTokens is the sum of cache read tokens.
|
||||
TotalCacheReadTokens int
|
||||
// TotalCacheWriteTokens is the sum of cache write tokens.
|
||||
TotalCacheWriteTokens int
|
||||
// TotalCost is the total cost in USD across all requests.
|
||||
TotalCost float64
|
||||
// RequestCount is the number of LLM requests made in this session.
|
||||
RequestCount int
|
||||
}
|
||||
|
||||
// PrintBlockOpts configures a custom styled block for PrintBlock.
|
||||
type PrintBlockOpts struct {
|
||||
// Text is the main content to display.
|
||||
@@ -784,7 +1122,7 @@ func (a *API) OnToolResult(handler func(ToolResultEvent, Context) *ToolResultRes
|
||||
a.onToolResult(handler)
|
||||
}
|
||||
|
||||
// OnSubagentStart registers a handler that fires when a spawn_subagent tool
|
||||
// OnSubagentStart registers a handler that fires when a subagent tool
|
||||
// call begins executing. Use the ToolCallID to correlate with subsequent
|
||||
// OnSubagentChunk and OnSubagentEnd events for the same subagent.
|
||||
func (a *API) OnSubagentStart(handler func(SubagentStartEvent, Context)) {
|
||||
@@ -799,7 +1137,7 @@ func (a *API) OnSubagentChunk(handler func(SubagentChunkEvent, Context)) {
|
||||
a.onSubagentChunk(handler)
|
||||
}
|
||||
|
||||
// OnSubagentEnd registers a handler that fires when a spawn_subagent call
|
||||
// OnSubagentEnd registers a handler that fires when a subagent call
|
||||
// completes. ErrorMsg is non-empty when the subagent failed.
|
||||
func (a *API) OnSubagentEnd(handler func(SubagentEndEvent, Context)) {
|
||||
a.onSubagentEnd(handler)
|
||||
@@ -1808,9 +2146,9 @@ func (BeforeCompactResult) isResult() {}
|
||||
// Subagent lifecycle events (exposed to Yaegi — concrete structs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// SubagentStartEvent fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStartEvent fires when a subagent tool call begins executing.
|
||||
type SubagentStartEvent struct {
|
||||
// ToolCallID is the LLM-assigned ID of the spawn_subagent tool call.
|
||||
// ToolCallID is the LLM-assigned ID of the subagent tool call.
|
||||
// Use this to correlate SubagentChunkEvent and SubagentEndEvent.
|
||||
ToolCallID string
|
||||
// Task is the task description passed to the subagent.
|
||||
@@ -1850,7 +2188,7 @@ type SubagentChunkEvent struct {
|
||||
|
||||
func (e SubagentChunkEvent) Type() EventType { return SubagentChunk }
|
||||
|
||||
// SubagentEndEvent fires when a spawn_subagent tool call completes.
|
||||
// SubagentEndEvent fires when a subagent tool call completes.
|
||||
type SubagentEndEvent struct {
|
||||
// ToolCallID matches the SubagentStartEvent.ToolCallID for this subagent.
|
||||
ToolCallID string
|
||||
|
||||
@@ -72,7 +72,7 @@ const (
|
||||
// cancel compaction by returning Cancel=true.
|
||||
BeforeCompact EventType = "before_compact"
|
||||
|
||||
// SubagentStart fires when a spawn_subagent tool call begins executing.
|
||||
// SubagentStart fires when a subagent tool call begins executing.
|
||||
// Carries the tool call ID and the task description.
|
||||
SubagentStart EventType = "subagent_start"
|
||||
|
||||
@@ -80,7 +80,7 @@ const (
|
||||
// subagent: text chunks, tool calls, tool results, etc.
|
||||
SubagentChunk EventType = "subagent_chunk"
|
||||
|
||||
// SubagentEnd fires when a spawn_subagent tool call completes (success
|
||||
// SubagentEnd fires when a subagent tool call completes (success
|
||||
// or error). Carries the final response and any error message.
|
||||
SubagentEnd EventType = "subagent_end"
|
||||
)
|
||||
|
||||
@@ -154,6 +154,11 @@ func NewInstaller(projectDir string) *Installer {
|
||||
|
||||
// Install clones a git repository to the appropriate scope.
|
||||
func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
return i.install(source, scope, nil)
|
||||
}
|
||||
|
||||
// install is the internal implementation that supports optional include paths.
|
||||
func (i *Installer) install(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
targetDir := i.getInstallPath(source, scope)
|
||||
|
||||
// Check if already installed
|
||||
@@ -199,6 +204,7 @@ func (i *Installer) Install(source *GitSource, scope InstallScope) error {
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Include: includePaths,
|
||||
}
|
||||
if err := i.addToManifest(entry, scope); err != nil {
|
||||
// Don't fail the install, just log the error
|
||||
@@ -268,7 +274,22 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
cleanCmd.Dir = targetDir
|
||||
_ = cleanCmd.Run() // Ignore errors - clean is best effort
|
||||
|
||||
// Update manifest timestamp
|
||||
// Update manifest timestamp, preserving existing fields like Include
|
||||
existing, _ := i.loadManifest(scope)
|
||||
var include []string
|
||||
var installed time.Time
|
||||
if existing != nil {
|
||||
for _, p := range existing.Packages {
|
||||
if p.Host+"/"+p.Path == source.Identity() {
|
||||
include = p.Include
|
||||
installed = p.Installed
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if installed.IsZero() {
|
||||
installed = time.Now()
|
||||
}
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
@@ -277,8 +298,9 @@ func (i *Installer) Update(source *GitSource, scope InstallScope) error {
|
||||
Ref: "",
|
||||
Pinned: false,
|
||||
Scope: scope,
|
||||
Installed: time.Now(),
|
||||
Installed: installed,
|
||||
Updated: time.Now(),
|
||||
Include: include,
|
||||
}
|
||||
_ = i.addToManifest(entry, scope) // Best effort - don't fail update if manifest fails
|
||||
|
||||
@@ -503,30 +525,7 @@ func (i *Installer) PreviewExtensions(source *GitSource) ([]ExtensionPreview, st
|
||||
// InstallWithInclude clones a repo and installs only the specified extensions.
|
||||
// includePaths are relative paths like "./git/main.go" - if empty, installs all.
|
||||
func (i *Installer) InstallWithInclude(source *GitSource, scope InstallScope, includePaths []string) error {
|
||||
// First, do a regular install
|
||||
if err := i.Install(source, scope); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If specific includes were requested, update the manifest
|
||||
if len(includePaths) > 0 {
|
||||
entry := ManifestEntry{
|
||||
Source: source.String(),
|
||||
Repo: source.Repo,
|
||||
Host: source.Host,
|
||||
Path: source.Path,
|
||||
Ref: source.Ref,
|
||||
Pinned: source.Pinned,
|
||||
Scope: scope,
|
||||
Include: includePaths,
|
||||
}
|
||||
|
||||
if err := addEntryToManifest(entry, scope); err != nil {
|
||||
return fmt.Errorf("updating manifest with includes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return i.install(source, scope, includePaths)
|
||||
}
|
||||
|
||||
// CleanupTempDir removes a temporary directory used for preview.
|
||||
|
||||
@@ -34,59 +34,64 @@ func LoadExtensions(extraPaths []string) ([]LoadedExtension, error) {
|
||||
for _, p := range paths {
|
||||
ext, err := loadSingleExtension(p)
|
||||
if err != nil {
|
||||
log.Warn("skipping extension", "path", p, "err", err)
|
||||
continue
|
||||
}
|
||||
loaded = append(loaded, *ext)
|
||||
log.Debug("loaded extension", "path", p,
|
||||
"handlers", countHandlers(ext),
|
||||
"tools", len(ext.Tools),
|
||||
"commands", len(ext.Commands),
|
||||
"tool_renderers", len(ext.ToolRenderers))
|
||||
log.Debug("loaded extension", "path", p, "handlers", countHandlers(ext), "tools", len(ext.Tools), "commands", len(ext.Commands), "tool_renderers", len(ext.ToolRenderers))
|
||||
}
|
||||
return loaded, nil
|
||||
}
|
||||
|
||||
// pathSet is a thread-safe helper for deduplicating and ordering file paths.
|
||||
type pathSet struct {
|
||||
m map[string]bool
|
||||
list []string
|
||||
}
|
||||
|
||||
func newPathSet() *pathSet {
|
||||
return &pathSet{m: make(map[string]bool)}
|
||||
}
|
||||
|
||||
func (ps *pathSet) add(p string) bool {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if ps.m[abs] {
|
||||
return false
|
||||
}
|
||||
ps.m[abs] = true
|
||||
ps.list = append(ps.list, abs)
|
||||
return true
|
||||
}
|
||||
|
||||
// discoverExtensionPaths returns deduplicated paths to extension files in
|
||||
// load-order (global first, then project-local, then explicit).
|
||||
func discoverExtensionPaths(extraPaths []string) []string {
|
||||
seen := make(map[string]bool)
|
||||
var paths []string
|
||||
|
||||
add := func(p string) {
|
||||
abs, err := filepath.Abs(p)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
seen[abs] = true
|
||||
paths = append(paths, abs)
|
||||
}
|
||||
ps := newPathSet()
|
||||
|
||||
// Global extensions: $XDG_CONFIG_HOME/kit/extensions/ (default ~/.config/kit/extensions/)
|
||||
globalDir := globalExtensionsDir()
|
||||
for _, p := range findExtensionsInDir(globalDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Global installed git packages: $XDG_DATA_HOME/kit/git/
|
||||
globalGitDir := globalGitInstallRoot()
|
||||
for _, p := range findExtensionsInGitPackages(globalGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local extensions: .kit/extensions/
|
||||
localDir := filepath.Join(".kit", "extensions")
|
||||
for _, p := range findExtensionsInDir(localDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Project-local installed git packages: .kit/git/
|
||||
projectGitDir := filepath.Join(".kit", "git")
|
||||
for _, p := range findExtensionsInGitPackages(projectGitDir) {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
|
||||
// Explicit paths (highest precedence)
|
||||
@@ -97,14 +102,14 @@ func discoverExtensionPaths(extraPaths []string) []string {
|
||||
}
|
||||
if info.IsDir() {
|
||||
for _, found := range findExtensionsInDir(p) {
|
||||
add(found)
|
||||
ps.add(found)
|
||||
}
|
||||
} else if strings.HasSuffix(p, ".go") {
|
||||
add(p)
|
||||
ps.add(p)
|
||||
}
|
||||
}
|
||||
|
||||
return paths
|
||||
return ps.list
|
||||
}
|
||||
|
||||
// findExtensionsInDir returns .go files in dir and main.go in immediate subdirs.
|
||||
@@ -123,7 +128,7 @@ func findExtensionsInDir(dir string) []string {
|
||||
|
||||
for _, entry := range entries {
|
||||
full := filepath.Join(dir, entry.Name())
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") {
|
||||
results = append(results, full)
|
||||
} else if entry.IsDir() {
|
||||
main := filepath.Join(full, "main.go")
|
||||
@@ -180,9 +185,13 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
if relPath == base { // Top-level directory
|
||||
@@ -192,13 +201,6 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
results = append(results, mainPath)
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -217,7 +219,7 @@ func findExtensionsInRepo(repoPath string) []string {
|
||||
}
|
||||
|
||||
// It's a file
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -253,10 +253,13 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
isExtDir := base == "extensions" || base == "ext" ||
|
||||
strings.HasSuffix(base, "-extensions") || strings.HasSuffix(base, "-ext")
|
||||
|
||||
// Or check if it's a subdirectory of examples/ that might contain extensions
|
||||
isExamplesSubdir := relPath == "examples" || strings.HasPrefix(relPath, "examples/")
|
||||
// Allow walking into examples/ so we can reach examples/extensions/ etc,
|
||||
// but don't treat examples/ itself or non-extension subdirs as extension locations.
|
||||
if relPath == "examples" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isExtDir && !isExamplesSubdir {
|
||||
if !isExtDir {
|
||||
// Check for main.go before skipping
|
||||
mainPath := filepath.Join(path, "main.go")
|
||||
if _, err := os.Stat(mainPath); err == nil {
|
||||
@@ -272,18 +275,6 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
// Inside a valid extensions directory
|
||||
if isExamplesSubdir || isExtDir {
|
||||
if !multiFileDirs[relPath] {
|
||||
multiFileDirs[relPath] = true
|
||||
previews = append(previews, ExtensionPreview{
|
||||
Path: "./" + relPath + "/main.go",
|
||||
Name: deriveExtensionName(relPath+"/main.go", true),
|
||||
IsMain: true,
|
||||
})
|
||||
}
|
||||
return filepath.SkipDir
|
||||
}
|
||||
}
|
||||
|
||||
// Not an extension location
|
||||
@@ -309,7 +300,7 @@ func ScanForExtensions(dir string) ([]ExtensionPreview, error) {
|
||||
}
|
||||
|
||||
// It's a file - check if it's a valid extension
|
||||
if !strings.HasSuffix(info.Name(), ".go") {
|
||||
if !strings.HasSuffix(info.Name(), ".go") || strings.HasSuffix(info.Name(), "_test.go") {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+206
-11
@@ -1,21 +1,93 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// reentrantMu — a per-extension mutex that allows the same goroutine to
|
||||
// re-enter (e.g. handler → ctx.EmitCustomEvent → handler in same extension).
|
||||
// Different goroutines are serialized, preventing concurrent state mutation.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type reentrantMu struct {
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
owner int64 // goroutine ID that holds the lock, or 0
|
||||
depth int // re-entrancy depth
|
||||
}
|
||||
|
||||
// initReentrantMu initializes the reentrant mutex in-place. Must be called
|
||||
// after the struct is at its final memory location (not before copying).
|
||||
func (r *reentrantMu) init() {
|
||||
r.cond = sync.NewCond(&r.mu)
|
||||
}
|
||||
|
||||
// lock acquires the mutex. If the calling goroutine already holds it, the
|
||||
// call succeeds immediately (re-entrant). Every call to lock must be paired
|
||||
// with a call to unlock.
|
||||
func (r *reentrantMu) lock() {
|
||||
gid := goroutineID()
|
||||
r.mu.Lock()
|
||||
if r.owner == gid {
|
||||
// Re-entrant: same goroutine already holds the lock.
|
||||
r.depth++
|
||||
r.mu.Unlock()
|
||||
return
|
||||
}
|
||||
// Wait for the current owner to release.
|
||||
for r.owner != 0 {
|
||||
r.cond.Wait() // releases mu, blocks, re-acquires mu on wake
|
||||
}
|
||||
r.owner = gid
|
||||
r.depth = 1
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// unlock releases the mutex (or decrements re-entrancy depth).
|
||||
func (r *reentrantMu) unlock() {
|
||||
r.mu.Lock()
|
||||
r.depth--
|
||||
if r.depth == 0 {
|
||||
r.owner = 0
|
||||
r.cond.Signal()
|
||||
}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// goroutineID extracts the current goroutine's ID from runtime.Stack output.
|
||||
// This is a well-known technique used by Go testing infrastructure.
|
||||
func goroutineID() int64 {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
// Stack output starts with "goroutine NNN ["
|
||||
s := buf[:n]
|
||||
s = s[len("goroutine "):]
|
||||
s = s[:bytes.IndexByte(s, ' ')]
|
||||
id, _ := strconv.ParseInt(string(s), 10, 64)
|
||||
return id
|
||||
}
|
||||
|
||||
// Runner manages loaded extensions and dispatches events to their handlers
|
||||
// sequentially. Handlers execute in extension
|
||||
// load order; for cancellable events the first blocking result wins.
|
||||
//
|
||||
// Each extension has a dedicated reentrant mutex so that handlers for the
|
||||
// same extension are serialized (preventing data races on shared package-level
|
||||
// state), while handlers for different extensions may execute concurrently.
|
||||
type Runner struct {
|
||||
extensions []LoadedExtension
|
||||
extMu []reentrantMu // per-extension reentrant mutex, indexed by extension position
|
||||
ctx Context
|
||||
widgets map[string]WidgetConfig // keyed by widget ID
|
||||
statusEntries map[string]StatusBarEntry // keyed by status key
|
||||
@@ -52,7 +124,11 @@ type LoadedExtension struct {
|
||||
|
||||
// NewRunner creates a Runner from a set of loaded extensions.
|
||||
func NewRunner(exts []LoadedExtension) *Runner {
|
||||
return &Runner{extensions: exts}
|
||||
mus := make([]reentrantMu, len(exts))
|
||||
for i := range mus {
|
||||
mus[i].init()
|
||||
}
|
||||
return &Runner{extensions: exts, extMu: mus}
|
||||
}
|
||||
|
||||
// SetContext updates the runtime context (session ID, model, etc.) that is
|
||||
@@ -86,6 +162,21 @@ func normalizeContext(ctx Context) Context {
|
||||
if ctx.CancelAndSend == nil {
|
||||
ctx.CancelAndSend = func(string) {}
|
||||
}
|
||||
if ctx.Abort == nil {
|
||||
ctx.Abort = func() {}
|
||||
}
|
||||
if ctx.IsIdle == nil {
|
||||
ctx.IsIdle = func() bool { return true }
|
||||
}
|
||||
if ctx.Compact == nil {
|
||||
ctx.Compact = func(CompactConfig) error { return fmt.Errorf("compact not available") }
|
||||
}
|
||||
if ctx.SendMultimodalMessage == nil {
|
||||
ctx.SendMultimodalMessage = func(string, []FilePart) {}
|
||||
}
|
||||
if ctx.GetSessionUsage == nil {
|
||||
ctx.GetSessionUsage = func() SessionUsage { return SessionUsage{} }
|
||||
}
|
||||
if ctx.SetWidget == nil {
|
||||
ctx.SetWidget = func(WidgetConfig) {}
|
||||
}
|
||||
@@ -214,6 +305,102 @@ func normalizeContext(ctx Context) Context {
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Tree Navigation API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.GetTreeNode == nil {
|
||||
ctx.GetTreeNode = func(string) *TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetCurrentBranch == nil {
|
||||
ctx.GetCurrentBranch = func() []TreeNode { return nil }
|
||||
}
|
||||
if ctx.GetChildren == nil {
|
||||
ctx.GetChildren = func(string) []string { return nil }
|
||||
}
|
||||
if ctx.NavigateTo == nil {
|
||||
ctx.NavigateTo = func(string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.SummarizeBranch == nil {
|
||||
ctx.SummarizeBranch = func(string, string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
if ctx.CollapseBranch == nil {
|
||||
ctx.CollapseBranch = func(string, string, string) TreeNavigationResult {
|
||||
return TreeNavigationResult{Success: false, Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Skill Loading API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.LoadSkill == nil {
|
||||
ctx.LoadSkill = func(string) (*Skill, string) { return nil, "" }
|
||||
}
|
||||
if ctx.LoadSkillsFromDir == nil {
|
||||
ctx.LoadSkillsFromDir = func(string) SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.DiscoverSkills == nil {
|
||||
ctx.DiscoverSkills = func() SkillLoadResult { return SkillLoadResult{} }
|
||||
}
|
||||
if ctx.InjectSkillAsContext == nil {
|
||||
ctx.InjectSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.InjectRawSkillAsContext == nil {
|
||||
ctx.InjectRawSkillAsContext = func(string) string { return "" }
|
||||
}
|
||||
if ctx.GetAvailableSkills == nil {
|
||||
ctx.GetAvailableSkills = func() []Skill { return nil }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Template Parsing API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ParseTemplate == nil {
|
||||
ctx.ParseTemplate = func(string, string) PromptTemplate { return PromptTemplate{} }
|
||||
}
|
||||
if ctx.RenderTemplate == nil {
|
||||
ctx.RenderTemplate = func(PromptTemplate, map[string]string) string { return "" }
|
||||
}
|
||||
if ctx.ParseArguments == nil {
|
||||
ctx.ParseArguments = func(string, ArgumentPattern) ParseResult { return ParseResult{} }
|
||||
}
|
||||
if ctx.SimpleParseArguments == nil {
|
||||
ctx.SimpleParseArguments = func(string, int) []string { return nil }
|
||||
}
|
||||
if ctx.EvaluateModelConditional == nil {
|
||||
ctx.EvaluateModelConditional = func(string) bool { return false }
|
||||
}
|
||||
if ctx.RenderWithModelConditionals == nil {
|
||||
ctx.RenderWithModelConditionals = func(string) string { return "" }
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Model Resolution API no-ops
|
||||
// -------------------------------------------------------------------------
|
||||
if ctx.ResolveModelChain == nil {
|
||||
ctx.ResolveModelChain = func([]string) ModelResolutionResult {
|
||||
return ModelResolutionResult{Error: "not implemented"}
|
||||
}
|
||||
}
|
||||
if ctx.GetModelCapabilities == nil {
|
||||
ctx.GetModelCapabilities = func(string) (ModelCapabilities, string) {
|
||||
return ModelCapabilities{}, "not implemented"
|
||||
}
|
||||
}
|
||||
if ctx.CheckModelAvailable == nil {
|
||||
ctx.CheckModelAvailable = func(string) bool { return false }
|
||||
}
|
||||
if ctx.GetCurrentProvider == nil {
|
||||
ctx.GetCurrentProvider = func() string { return "" }
|
||||
}
|
||||
if ctx.GetCurrentModelID == nil {
|
||||
ctx.GetCurrentModelID = func() string { return "" }
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -256,13 +443,15 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
for i := range r.extensions {
|
||||
ext := &r.extensions[i]
|
||||
handlers := ext.Handlers[event.Type()]
|
||||
if len(handlers) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
r.extMu[i].lock()
|
||||
for _, handler := range handlers {
|
||||
result, err := safeCall(handler, event, ctx)
|
||||
if err != nil {
|
||||
log.Warn("extension handler error",
|
||||
"path", ext.Path,
|
||||
"event", event.Type(),
|
||||
"err", err)
|
||||
log.Printf("WARN extension handler error: path=%s event=%s err=%v", ext.Path, event.Type(), err)
|
||||
continue
|
||||
}
|
||||
if result == nil {
|
||||
@@ -271,6 +460,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
|
||||
// Check for blocking/short-circuit results.
|
||||
if isBlocking(result) {
|
||||
r.extMu[i].unlock()
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -278,6 +468,7 @@ func (r *Runner) Emit(event Event) (Result, error) {
|
||||
// the caller is responsible for applying the modifications.
|
||||
accumulated = result
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
return accumulated, nil
|
||||
}
|
||||
@@ -596,9 +787,7 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
safeInvoke := func(h func(string)) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
log.Warn("custom event handler panicked",
|
||||
"event", name,
|
||||
"err", fmt.Sprintf("%v", rec))
|
||||
log.Printf("WARN custom event handler panicked: event=%s err=%v", name, rec)
|
||||
}
|
||||
}()
|
||||
h(data)
|
||||
@@ -606,11 +795,17 @@ func (r *Runner) EmitCustomEvent(name, data string) {
|
||||
|
||||
// Extension-registered handlers first (in load order).
|
||||
for i := range r.extensions {
|
||||
for _, h := range r.extensions[i].CustomEventHandlers[name] {
|
||||
extHandlers := r.extensions[i].CustomEventHandlers[name]
|
||||
if len(extHandlers) == 0 {
|
||||
continue
|
||||
}
|
||||
r.extMu[i].lock()
|
||||
for _, h := range extHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
r.extMu[i].unlock()
|
||||
}
|
||||
// Then dynamic subscriptions.
|
||||
// Then dynamic subscriptions (not extension-scoped, no per-ext lock).
|
||||
for _, h := range dynamicHandlers {
|
||||
safeInvoke(h)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -571,3 +572,142 @@ func TestRunner_ContextPrintNilSafe(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitSameExtension(t *testing.T) {
|
||||
// Verify that concurrent Emit calls for the same extension are serialized
|
||||
// and don't cause data races on shared handler state.
|
||||
var counter int
|
||||
ext := makeHandlerExt("shared-state.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
// Read-modify-write: racy without serialization.
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
SubagentChunk: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter
|
||||
counter = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
_, _ = r.Emit(SubagentChunkEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if counter != goroutines*iterations*2 {
|
||||
t.Errorf("expected counter=%d, got %d (race detected)", goroutines*iterations*2, counter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ConcurrentEmitDifferentExtensions(t *testing.T) {
|
||||
// Two extensions with independent state should not block each other
|
||||
// and should both run correctly under concurrent Emit calls.
|
||||
var counter1, counter2 int
|
||||
ext1 := makeHandlerExt("ext1.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter1
|
||||
counter1 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
ext2 := makeHandlerExt("ext2.go", map[EventType][]HandlerFunc{
|
||||
SubagentStart: {
|
||||
func(e Event, c Context) Result {
|
||||
v := counter2
|
||||
counter2 = v + 1
|
||||
return nil
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
r := makeRunner(ext1, ext2)
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 20
|
||||
const iterations = 50
|
||||
wg.Add(goroutines)
|
||||
for range goroutines {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range iterations {
|
||||
_, _ = r.Emit(SubagentStartEvent{ToolCallID: "x"})
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
expected := goroutines * iterations
|
||||
if counter1 != expected {
|
||||
t.Errorf("ext1 counter: expected %d, got %d", expected, counter1)
|
||||
}
|
||||
if counter2 != expected {
|
||||
t.Errorf("ext2 counter: expected %d, got %d", expected, counter2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunner_ReentrantEmitCustomEvent(t *testing.T) {
|
||||
// Verify that a handler can call EmitCustomEvent (which dispatches to
|
||||
// the same extension's custom event handlers) without deadlocking.
|
||||
var order []string
|
||||
ext := LoadedExtension{
|
||||
Path: "reentrant.go",
|
||||
Handlers: map[EventType][]HandlerFunc{
|
||||
SessionStart: {
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
// This triggers EmitCustomEvent for the same extension
|
||||
// via a direct runner call (simulating ctx.EmitCustomEvent).
|
||||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
CustomEventHandlers: map[string][]func(string){
|
||||
"test-event": {
|
||||
func(data string) {
|
||||
order = append(order, "custom:"+data)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
r := makeRunner(ext)
|
||||
|
||||
// Wire up the handler to call EmitCustomEvent re-entrantly.
|
||||
ext.Handlers[SessionStart] = []HandlerFunc{
|
||||
func(e Event, c Context) Result {
|
||||
order = append(order, "session_start")
|
||||
r.EmitCustomEvent("test-event", "hello")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
r.extensions[0] = ext
|
||||
// Rebuild mutexes after modifying extensions slice.
|
||||
r.extMu = make([]reentrantMu, len(r.extensions))
|
||||
for i := range r.extMu {
|
||||
r.extMu[i].init()
|
||||
}
|
||||
|
||||
_, err := r.Emit(SessionStartEvent{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(order) != 2 || order[0] != "session_start" || order[1] != "custom:hello" {
|
||||
t.Errorf("expected [session_start, custom:hello], got %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,10 +173,10 @@ type subagentJSONOutput struct {
|
||||
} `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
var subagentCounter uint64
|
||||
var subagentCounter atomic.Uint64
|
||||
|
||||
func generateSubagentID() string {
|
||||
n := atomic.AddUint64(&subagentCounter, 1)
|
||||
n := subagentCounter.Add(1)
|
||||
return fmt.Sprintf("sub-%d-%d", time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ func Symbols() interp.Exports {
|
||||
// Session types
|
||||
"SessionMessage": reflect.ValueOf((*SessionMessage)(nil)),
|
||||
"ExtensionEntry": reflect.ValueOf((*ExtensionEntry)(nil)),
|
||||
"SessionUsage": reflect.ValueOf((*SessionUsage)(nil)),
|
||||
|
||||
// Option types
|
||||
"OptionDef": reflect.ValueOf((*OptionDef)(nil)),
|
||||
@@ -44,6 +45,8 @@ func Symbols() interp.Exports {
|
||||
// LLM completion types
|
||||
"CompleteRequest": reflect.ValueOf((*CompleteRequest)(nil)),
|
||||
"CompleteResponse": reflect.ValueOf((*CompleteResponse)(nil)),
|
||||
"CompactConfig": reflect.ValueOf((*CompactConfig)(nil)),
|
||||
"FilePart": reflect.ValueOf((*FilePart)(nil)),
|
||||
|
||||
// Status bar types
|
||||
"StatusBarEntry": reflect.ValueOf((*StatusBarEntry)(nil)),
|
||||
@@ -128,6 +131,24 @@ func Symbols() interp.Exports {
|
||||
"ThemeColor": reflect.ValueOf((*ThemeColor)(nil)),
|
||||
"ThemeColorConfig": reflect.ValueOf((*ThemeColorConfig)(nil)),
|
||||
|
||||
// Tree navigation types
|
||||
"TreeNode": reflect.ValueOf((*TreeNode)(nil)),
|
||||
"TreeNavigationResult": reflect.ValueOf((*TreeNavigationResult)(nil)),
|
||||
|
||||
// Skill types
|
||||
"Skill": reflect.ValueOf((*Skill)(nil)),
|
||||
"SkillLoadResult": reflect.ValueOf((*SkillLoadResult)(nil)),
|
||||
|
||||
// Template parsing types
|
||||
"PromptTemplate": reflect.ValueOf((*PromptTemplate)(nil)),
|
||||
"ArgumentPattern": reflect.ValueOf((*ArgumentPattern)(nil)),
|
||||
"ParseResult": reflect.ValueOf((*ParseResult)(nil)),
|
||||
"ModelConditional": reflect.ValueOf((*ModelConditional)(nil)),
|
||||
|
||||
// Model resolution types
|
||||
"ModelCapabilities": reflect.ValueOf((*ModelCapabilities)(nil)),
|
||||
"ModelResolutionResult": reflect.ValueOf((*ModelResolutionResult)(nil)),
|
||||
|
||||
// Event structs
|
||||
"ToolCallEvent": reflect.ValueOf((*ToolCallEvent)(nil)),
|
||||
"ToolCallResult": reflect.ValueOf((*ToolCallResult)(nil)),
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
// Watcher monitors extension directories for file changes and triggers
|
||||
// a reload callback when .go files are created, modified, or removed.
|
||||
// It uses fsnotify for kernel-level file notifications (inotify on Linux,
|
||||
// kqueue on macOS) with debouncing to coalesce rapid editor writes.
|
||||
type Watcher struct {
|
||||
watcher *fsnotify.Watcher
|
||||
onReload func()
|
||||
debounce time.Duration
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewWatcher creates a file watcher that monitors the given directories
|
||||
// for .go file changes. When a change is detected (after debouncing),
|
||||
// onReload is called. The watcher must be started with Start() and
|
||||
// stopped with Close().
|
||||
func NewWatcher(dirs []string, onReload func()) (*Watcher, error) {
|
||||
fsw, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating file watcher: %w", err)
|
||||
}
|
||||
|
||||
for _, dir := range dirs {
|
||||
// Watch the directory itself.
|
||||
if err := fsw.Add(dir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping directory: dir=%s err=%v", dir, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Also watch immediate subdirectories (for */main.go pattern).
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
subdir := filepath.Join(dir, entry.Name())
|
||||
if err := fsw.Add(subdir); err != nil {
|
||||
log.Printf("DEBUG watcher: skipping subdirectory: dir=%s err=%v", subdir, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
watcher: fsw,
|
||||
onReload: onReload,
|
||||
debounce: 300 * time.Millisecond,
|
||||
done: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching for file changes. It blocks until the context
|
||||
// is cancelled or Close() is called. Typically called in a goroutine.
|
||||
func (w *Watcher) Start(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
ctx, w.cancel = context.WithCancel(ctx)
|
||||
w.mu.Unlock()
|
||||
|
||||
defer close(w.done)
|
||||
|
||||
var timer *time.Timer
|
||||
var timerC <-chan time.Time
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
return
|
||||
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Only care about .go files.
|
||||
if !strings.HasSuffix(event.Name, ".go") {
|
||||
continue
|
||||
}
|
||||
|
||||
// React to write, create, remove, rename events.
|
||||
if event.Op&(fsnotify.Write|fsnotify.Create|fsnotify.Remove|fsnotify.Rename) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("DEBUG watcher: file changed: file=%s op=%s", event.Name, event.Op)
|
||||
|
||||
// Debounce: reset timer on each event.
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
timer = time.NewTimer(w.debounce)
|
||||
timerC = timer.C
|
||||
|
||||
case <-timerC:
|
||||
timerC = nil
|
||||
timer = nil
|
||||
log.Printf("DEBUG watcher: reloading extensions")
|
||||
w.onReload()
|
||||
|
||||
case err, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Printf("WARN watcher: error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the watcher and releases resources.
|
||||
func (w *Watcher) Close() error {
|
||||
w.mu.Lock()
|
||||
cancel := w.cancel
|
||||
w.mu.Unlock()
|
||||
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Wait for the event loop to finish.
|
||||
<-w.done
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
// WatchedDirs returns the directories to watch for extension changes.
|
||||
// This includes the global extensions directory and the project-local
|
||||
// .kit/extensions/ directory (if they exist). Explicit -e paths that
|
||||
// point to directories are also included; explicit file paths cause
|
||||
// their parent directory to be watched instead.
|
||||
func WatchedDirs(extraPaths []string) []string {
|
||||
var dirs []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
add := func(dir string) {
|
||||
abs, err := filepath.Abs(dir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if seen[abs] {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the directory exists.
|
||||
info, err := os.Stat(abs)
|
||||
if err != nil || !info.IsDir() {
|
||||
return
|
||||
}
|
||||
|
||||
seen[abs] = true
|
||||
dirs = append(dirs, abs)
|
||||
}
|
||||
|
||||
// Global extensions dir.
|
||||
add(globalExtensionsDir())
|
||||
|
||||
// Project-local extensions dir.
|
||||
add(filepath.Join(".kit", "extensions"))
|
||||
|
||||
// Explicit paths that are directories.
|
||||
for _, p := range extraPaths {
|
||||
info, err := os.Stat(p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.IsDir() {
|
||||
add(p)
|
||||
} else {
|
||||
// For explicit files, watch the parent directory.
|
||||
add(filepath.Dir(p))
|
||||
}
|
||||
}
|
||||
|
||||
return dirs
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package extensions
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatcher_ReloadsOnGoFileChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// Write an initial extension file.
|
||||
extFile := filepath.Join(dir, "test.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Modify the file.
|
||||
time.Sleep(50 * time.Millisecond) // let watcher settle
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait for debounce (300ms) + margin.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_IgnoresNonGoFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
// Write a non-.go file.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
txtFile := filepath.Join(dir, "notes.txt")
|
||||
if err := os.WriteFile(txtFile, []byte("hello"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Wait past the debounce window.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 0 {
|
||||
t.Errorf("expected 0 reloads for .txt file, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcher_Debounces(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
extFile := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(extFile, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var reloadCount atomic.Int32
|
||||
|
||||
w, err := NewWatcher([]string{dir}, func() {
|
||||
reloadCount.Add(1)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go w.Start(t.Context())
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Rapid-fire writes (simulating editor save: write temp, rename, etc.).
|
||||
for range 5 {
|
||||
if err := os.WriteFile(extFile, []byte("package main\n// changed\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Wait for debounce to fire.
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
if got := reloadCount.Load(); got != 1 {
|
||||
t.Errorf("expected 1 debounced reload, got %d", got)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_Deduplicates(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dirs := WatchedDirs([]string{dir, dir})
|
||||
|
||||
count := 0
|
||||
for _, d := range dirs {
|
||||
abs, _ := filepath.Abs(dir)
|
||||
if d == abs {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected directory to appear once, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatchedDirs_FileParent(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
file := filepath.Join(dir, "ext.go")
|
||||
if err := os.WriteFile(file, []byte("package main\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dirs := WatchedDirs([]string{file})
|
||||
|
||||
abs, _ := filepath.Abs(dir)
|
||||
found := false
|
||||
for _, d := range dirs {
|
||||
if d == abs {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected parent dir %s in watched dirs %v", abs, dirs)
|
||||
}
|
||||
}
|
||||
@@ -28,11 +28,11 @@ func WrapToolsWithExtensions(tools []fantasy.AgentTool, runner *Runner) []fantas
|
||||
return wrapped
|
||||
}
|
||||
|
||||
// ExtensionToolsAsFantasy converts ToolDef values registered by extensions
|
||||
// into fantasy.AgentTool implementations so the LLM can invoke them.
|
||||
// ExtensionToolsAsLLMTools converts ToolDef values registered by extensions
|
||||
// into LLM agent tool implementations so the LLM can invoke them.
|
||||
// The runner is optional; if provided, ToolContext.OnProgress routes
|
||||
// progress messages through the runner's Print function.
|
||||
func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
func ExtensionToolsAsLLMTools(defs []ToolDef, runner *Runner) []fantasy.AgentTool {
|
||||
tools := make([]fantasy.AgentTool, 0, len(defs))
|
||||
for _, def := range defs {
|
||||
tools = append(tools, &extensionTool{def: def, runner: runner})
|
||||
@@ -42,14 +42,14 @@ func ExtensionToolsAsFantasy(defs []ToolDef, runner *Runner) []fantasy.AgentTool
|
||||
|
||||
// coreToolKinds maps built-in tool names to their kind classification.
|
||||
var coreToolKinds = map[string]string{
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"spawn_subagent": "agent",
|
||||
"bash": "execute",
|
||||
"edit": "edit",
|
||||
"write": "edit",
|
||||
"read": "read",
|
||||
"ls": "read",
|
||||
"grep": "search",
|
||||
"find": "search",
|
||||
"subagent": "agent",
|
||||
}
|
||||
|
||||
// toolKindFor returns the ToolKind for a given tool name, defaulting to
|
||||
@@ -154,7 +154,7 @@ func (w *wrappedTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.T
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extensionTool — wraps a ToolDef into a fantasy.AgentTool
|
||||
// extensionTool — wraps a ToolDef into an LLM agent tool
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
type extensionTool struct {
|
||||
@@ -182,7 +182,7 @@ func (t *extensionTool) Info() fantasy.ToolInfo {
|
||||
info.Parameters = props
|
||||
} else {
|
||||
// Schema doesn't have "properties" — use as-is (may be
|
||||
// a flat property map already matching fantasy's format).
|
||||
// a flat property map already matching the expected format).
|
||||
info.Parameters = schema
|
||||
}
|
||||
// Extract required fields if present.
|
||||
|
||||
@@ -192,7 +192,7 @@ func TestWrappedTool_ExecutionStartEnd(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
func TestExtensionToolsAsLLMTools(t *testing.T) {
|
||||
defs := []ToolDef{
|
||||
{
|
||||
Name: "greet",
|
||||
@@ -202,7 +202,7 @@ func TestExtensionToolsAsFantasy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
@@ -232,7 +232,7 @@ func TestExtensionTool_Error(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "x"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
@@ -259,7 +259,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
}
|
||||
|
||||
// Without runner, OnProgress is a no-op.
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: "test"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -285,7 +285,7 @@ func TestExtensionTool_ExecuteWithContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools2 := ExtensionToolsAsFantasy(defs2, runner)
|
||||
tools2 := ExtensionToolsAsLLMTools(defs2, runner)
|
||||
_, err = tools2[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -306,7 +306,7 @@ func TestExtensionTool_ExecuteWithContextPriority(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
resp, err := tools[0].Run(context.Background(), fantasy.ToolCall{Input: ""})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -330,7 +330,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
_, _ = tools[0].Run(ctx, fantasy.ToolCall{Input: ""})
|
||||
if !sawCancelled {
|
||||
t.Error("expected IsCancelled=true for cancelled context")
|
||||
@@ -339,7 +339,7 @@ func TestExtensionTool_CancelledContext(t *testing.T) {
|
||||
|
||||
func TestExtensionTool_ProviderOptions(t *testing.T) {
|
||||
defs := []ToolDef{{Name: "test", Execute: func(string) (string, error) { return "", nil }}}
|
||||
tools := ExtensionToolsAsFantasy(defs, nil)
|
||||
tools := ExtensionToolsAsLLMTools(defs, nil)
|
||||
|
||||
// Initially nil.
|
||||
opts := tools[0].ProviderOptions()
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
// Package fences provides utilities for detecting markdown code regions
|
||||
// (fenced code blocks and inline code spans) and applying transformations
|
||||
// only to text outside those regions.
|
||||
//
|
||||
// This prevents special tokens like $1, $@, or @file from being interpreted
|
||||
// when they appear inside ``` fences, ~~~ fences, or `inline` code spans.
|
||||
package fences
|
||||
|
||||
import "strings"
|
||||
|
||||
// Ranges returns byte ranges [start, end) of fenced code blocks in content.
|
||||
// Recognises both backtick (```) and tilde (~~~) fences, with optional
|
||||
// leading indentation (up to 3 spaces) and optional info strings.
|
||||
// An unclosed fence extends to the end of content.
|
||||
func Ranges(content string) [][2]int {
|
||||
var result [][2]int
|
||||
var inFence bool
|
||||
var fenceChar byte
|
||||
var fenceCount int
|
||||
var fenceStart int
|
||||
|
||||
pos := 0
|
||||
for pos < len(content) {
|
||||
// Find the end of the current line.
|
||||
lineEnd := strings.IndexByte(content[pos:], '\n')
|
||||
var line string
|
||||
var nextPos int
|
||||
if lineEnd < 0 {
|
||||
line = content[pos:]
|
||||
nextPos = len(content)
|
||||
} else {
|
||||
line = content[pos : pos+lineEnd]
|
||||
nextPos = pos + lineEnd + 1
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(line, " ")
|
||||
indent := len(line) - len(trimmed)
|
||||
|
||||
if !inFence {
|
||||
if indent <= 3 {
|
||||
if ch, n := parseFenceOpen(trimmed); n > 0 {
|
||||
inFence = true
|
||||
fenceChar = ch
|
||||
fenceCount = n
|
||||
fenceStart = pos
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if indent <= 3 && isFenceClose(trimmed, fenceChar, fenceCount) {
|
||||
result = append(result, [2]int{fenceStart, nextPos})
|
||||
inFence = false
|
||||
}
|
||||
}
|
||||
|
||||
pos = nextPos
|
||||
}
|
||||
|
||||
// Unclosed fence extends to end of content.
|
||||
if inFence {
|
||||
result = append(result, [2]int{fenceStart, len(content)})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReplaceOutside applies fn to each text segment that is outside fenced code
|
||||
// blocks and inline code spans, leaving code content unchanged. This is the
|
||||
// primary entry point for callers that need to do regex replacement only on
|
||||
// non-code text.
|
||||
func ReplaceOutside(content string, fn func(string) string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return replaceOutsideInline(content, fn)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
// Within non-fenced segments, also skip inline code spans.
|
||||
b.WriteString(replaceOutsideInline(content[pos:r[0]], fn))
|
||||
}
|
||||
// Preserve fenced content verbatim.
|
||||
b.WriteString(content[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(content) {
|
||||
b.WriteString(replaceOutsideInline(content[pos:], fn))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// StripCode returns content with fenced code blocks and inline code spans
|
||||
// removed. Useful for detection/matching where only non-code text matters.
|
||||
func StripCode(content string) string {
|
||||
// First strip fenced blocks.
|
||||
stripped := StripFenced(content)
|
||||
// Then strip inline code spans from what remains.
|
||||
return stripInlineCode(stripped)
|
||||
}
|
||||
|
||||
// StripFenced returns content with fenced code block regions removed.
|
||||
// Useful for detection/matching where only non-fenced text matters.
|
||||
// NOTE: this does NOT strip inline code spans; use StripCode for both.
|
||||
func StripFenced(content string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(content[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(content[pos:])
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// parseFenceOpen checks whether trimmed (leading spaces already removed)
|
||||
// starts a fenced code block. Returns the fence character and count, or
|
||||
// (0, 0) if it is not a fence opener.
|
||||
func parseFenceOpen(trimmed string) (byte, int) {
|
||||
if len(trimmed) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
ch := trimmed[0]
|
||||
if ch != '`' && ch != '~' {
|
||||
return 0, 0
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == ch {
|
||||
count++
|
||||
}
|
||||
if count < 3 {
|
||||
return 0, 0
|
||||
}
|
||||
// Per CommonMark: backtick fences cannot have backticks in the info string.
|
||||
if ch == '`' && strings.ContainsRune(trimmed[count:], '`') {
|
||||
return 0, 0
|
||||
}
|
||||
return ch, count
|
||||
}
|
||||
|
||||
// isFenceClose checks whether trimmed is a closing fence matching fenceChar
|
||||
// with at least minCount characters. A closing fence line contains only the
|
||||
// fence characters and optional trailing spaces.
|
||||
func isFenceClose(trimmed string, fenceChar byte, minCount int) bool {
|
||||
if len(trimmed) == 0 || trimmed[0] != fenceChar {
|
||||
return false
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == fenceChar {
|
||||
count++
|
||||
}
|
||||
if count < minCount {
|
||||
return false
|
||||
}
|
||||
// Closing fence must contain only fence chars (and optional trailing spaces).
|
||||
return strings.TrimRight(trimmed[count:], " ") == ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Inline code span handling
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// inlineCodeRanges returns byte ranges [start, end) of inline code spans
|
||||
// in segment. Per CommonMark, a code span opens with N backticks and closes
|
||||
// with exactly N backticks.
|
||||
func inlineCodeRanges(s string) [][2]int {
|
||||
var result [][2]int
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
if s[i] != '`' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
// Count opening backticks.
|
||||
start := i
|
||||
n := 0
|
||||
for i < len(s) && s[i] == '`' {
|
||||
n++
|
||||
i++
|
||||
}
|
||||
// Scan for a closing run of exactly n backticks.
|
||||
for j := i; j < len(s); {
|
||||
if s[j] != '`' {
|
||||
j++
|
||||
continue
|
||||
}
|
||||
m := 0
|
||||
for j < len(s) && s[j] == '`' {
|
||||
m++
|
||||
j++
|
||||
}
|
||||
if m == n {
|
||||
result = append(result, [2]int{start, j})
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no closing run was found, i is already past the opening
|
||||
// backticks so the outer loop advances naturally.
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// replaceOutsideInline applies fn only to text outside inline code spans.
|
||||
func replaceOutsideInline(segment string, fn func(string) string) string {
|
||||
ranges := inlineCodeRanges(segment)
|
||||
if len(ranges) == 0 {
|
||||
return fn(segment)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(segment))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
b.WriteString(fn(segment[pos:r[0]]))
|
||||
}
|
||||
b.WriteString(segment[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(segment) {
|
||||
b.WriteString(fn(segment[pos:]))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// stripInlineCode removes inline code spans from s.
|
||||
func stripInlineCode(s string) string {
|
||||
ranges := inlineCodeRanges(s)
|
||||
if len(ranges) == 0 {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(s[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(s[pos:])
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package fences
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want [][2]int
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world\nno code here",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single backtick fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "single tilde fence",
|
||||
content: "before\n~~~\ncode\n~~~\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "fence with info string",
|
||||
content: "before\n```go\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: [][2]int{{2, 12}, {14, 24}},
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\ncode\nmore code",
|
||||
want: [][2]int{{7, 25}},
|
||||
},
|
||||
{
|
||||
name: "longer closing fence",
|
||||
content: "before\n```\ncode\n`````\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "shorter closing fence ignored",
|
||||
content: "before\n`````\ncode\n```\nmore\n`````\nafter",
|
||||
want: [][2]int{{7, 33}},
|
||||
},
|
||||
{
|
||||
name: "indented fence up to 3 spaces",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: [][2]int{{7, 26}},
|
||||
},
|
||||
{
|
||||
name: "4 space indent is not a fence",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "backtick in info string rejects open",
|
||||
// The ```foo`bar line is not a valid opener (backtick in info).
|
||||
// The standalone ``` becomes an opener with no close.
|
||||
content: "before\n```foo`bar\ncode\n```\nafter",
|
||||
want: [][2]int{{23, 32}},
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "fence only",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "fence at end without trailing newline",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "tilde fence does not close with backticks",
|
||||
content: "~~~\ncode\n```\nmore\n~~~\nafter",
|
||||
want: [][2]int{{0, 22}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Ranges(tt.content)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("Ranges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("Ranges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world",
|
||||
want: "HELLO WORLD",
|
||||
},
|
||||
{
|
||||
name: "text around fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: "BEFORE\n```\ncode\n```\nAFTER",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "aaa\n```\nxxx\n```\nbbb\n~~~\nyyy\n~~~\nccc",
|
||||
want: "AAA\n```\nxxx\n```\nBBB\n~~~\nyyy\n~~~\nCCC",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence preserves code",
|
||||
content: "before\n```\ncode",
|
||||
want: "BEFORE\n```\ncode",
|
||||
},
|
||||
{
|
||||
name: "only fenced content",
|
||||
content: "```\ncode\n```",
|
||||
want: "```\ncode\n```",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripFenced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips fenced code",
|
||||
content: "before $1\n```\n$2 inside\n```\nafter $3",
|
||||
want: "before $1\nafter $3",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: "a\nb\nc",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\n$1 inside",
|
||||
want: "before\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripFenced(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripFenced() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineCodeRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want [][2]int
|
||||
}{
|
||||
{"no backticks", "hello world", nil},
|
||||
{"single backtick span", "use `$1` here", [][2]int{{4, 8}}},
|
||||
{"double backtick span", "use ``$1`` here", [][2]int{{4, 10}}},
|
||||
{"multiple spans", "`$1` and `$2`", [][2]int{{0, 4}, {9, 13}}},
|
||||
{"unmatched backtick", "use `$1 here", nil},
|
||||
{"mismatched backtick counts", "use ``$1` here", nil},
|
||||
{"empty inline content", "use `` `` here", [][2]int{{4, 9}}},
|
||||
{"backticks inside double", "use ``foo`bar`` here", [][2]int{{4, 15}}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := inlineCodeRanges(tt.s)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("inlineCodeRanges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("inlineCodeRanges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside_InlineCode(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "inline code preserved",
|
||||
content: "use `code` here",
|
||||
want: "USE `code` HERE",
|
||||
},
|
||||
{
|
||||
name: "double backtick inline code",
|
||||
content: "use ``co`de`` here",
|
||||
want: "USE ``co`de`` HERE",
|
||||
},
|
||||
{
|
||||
name: "mixed fenced and inline",
|
||||
content: "before `x` mid\n```\nfenced\n```\nafter `y` end",
|
||||
want: "BEFORE `x` MID\n```\nfenced\n```\nAFTER `y` END",
|
||||
},
|
||||
{
|
||||
name: "only inline code",
|
||||
content: "`code`",
|
||||
want: "`code`",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no code",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips inline code",
|
||||
content: "use `$1` and `$2` for positional args",
|
||||
want: "use and for positional args",
|
||||
},
|
||||
{
|
||||
name: "strips fenced and inline",
|
||||
content: "before `$1`\n```\n$2 inside\n```\nafter",
|
||||
want: "before \nafter",
|
||||
},
|
||||
{
|
||||
name: "real world prompt template",
|
||||
content: "Use $@ for all args.\n`$1`, `$2` for positional.\n```bash\necho $1\n```\n",
|
||||
want: "Use $@ for all args.\n, for positional.\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripCode(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripCode() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+104
-24
@@ -33,6 +33,10 @@ type AgentSetupOptions struct {
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. Allows SDK users to pass custom tools (e.g. with WithWorkDir).
|
||||
CoreTools []fantasy.AgentTool
|
||||
// DisableCoreTools, when true, prevents loading any core tools.
|
||||
// If both DisableCoreTools is true and CoreTools is empty, the agent
|
||||
// will have no tools (useful for simple chat completions).
|
||||
DisableCoreTools bool
|
||||
// ExtraTools are additional tools added alongside core, MCP, and extension
|
||||
// tools. They do not replace the defaults — they extend them.
|
||||
ExtraTools []fantasy.AgentTool
|
||||
@@ -40,6 +44,34 @@ type AgentSetupOptions struct {
|
||||
// wrapping. Used by the SDK hook system. Both wrappers compose:
|
||||
// extension wrapper runs first (inner), then this wrapper (outer).
|
||||
ToolWrapper func([]fantasy.AgentTool) []fantasy.AgentTool
|
||||
|
||||
// ProviderConfig, when non-nil, is used directly instead of calling
|
||||
// BuildProviderConfig(). Callers that already hold viperInitMu can
|
||||
// pre-build this and release the lock before calling SetupAgent, so the
|
||||
// slow agent/MCP initialisation runs concurrently with other New() calls.
|
||||
ProviderConfig *models.ProviderConfig
|
||||
// Debug enables debug logging. When zero-value, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
Debug bool
|
||||
// NoExtensions skips extension loading. When false, viper is consulted.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
NoExtensions bool
|
||||
// MaxSteps overrides the agent step limit. 0 means use viper value.
|
||||
// Only meaningful when ProviderConfig is also set.
|
||||
MaxSteps int
|
||||
// StreamingEnabled controls streaming. Only meaningful when ProviderConfig
|
||||
// is also set.
|
||||
StreamingEnabled bool
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
}
|
||||
|
||||
// AgentSetupResult bundles the created agent and any debug logger so the caller
|
||||
@@ -54,15 +86,17 @@ type AgentSetupResult struct {
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
@@ -72,9 +106,6 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
@@ -82,21 +113,66 @@ func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
}
|
||||
|
||||
// SetupAgent creates an agent from the current viper state + the provided
|
||||
// options. It wraps BuildProviderConfig and agent.CreateAgent.
|
||||
func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult, error) {
|
||||
modelConfig, systemPrompt, err := BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var modelConfig *models.ProviderConfig
|
||||
var systemPrompt string
|
||||
|
||||
if opts.ProviderConfig != nil {
|
||||
// Pre-built config supplied by caller (e.g. Kit.New after releasing
|
||||
// viperInitMu). Use it directly — no viper reads needed here.
|
||||
modelConfig = opts.ProviderConfig
|
||||
systemPrompt = modelConfig.SystemPrompt
|
||||
} else {
|
||||
var err error
|
||||
modelConfig, systemPrompt, err = BuildProviderConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve debug / no-extensions / max-steps / streaming: prefer explicit
|
||||
// fields (set when ProviderConfig was pre-built) over viper fallback.
|
||||
debugEnabled := opts.Debug || viper.GetBool("debug")
|
||||
noExtensions := opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps := opts.MaxSteps
|
||||
if maxSteps == 0 {
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
}
|
||||
streamingEnabled := opts.StreamingEnabled || viper.GetBool("stream")
|
||||
|
||||
// Create the appropriate debug logger.
|
||||
var debugLogger tools.DebugLogger
|
||||
var bufferedLogger *tools.BufferedDebugLogger
|
||||
if viper.GetBool("debug") {
|
||||
if debugEnabled {
|
||||
if opts.UseBufferedLogger {
|
||||
bufferedLogger = tools.NewBufferedDebugLogger(true)
|
||||
debugLogger = bufferedLogger
|
||||
@@ -108,7 +184,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
// Load extensions unless --no-extensions is set.
|
||||
var extRunner *extensions.Runner
|
||||
var extCreationOpts extensionCreationOpts
|
||||
if !viper.GetBool("no-extensions") {
|
||||
if !noExtensions {
|
||||
var extErr error
|
||||
extRunner, extCreationOpts, extErr = loadExtensions()
|
||||
if extErr != nil {
|
||||
@@ -137,18 +213,22 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
}
|
||||
|
||||
a, err := agent.CreateAgent(ctx, &agent.AgentCreationOptions{
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: viper.GetInt("max-steps"),
|
||||
StreamingEnabled: viper.GetBool("stream"),
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
CoreTools: opts.CoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
ModelConfig: modelConfig,
|
||||
MCPConfig: opts.MCPConfig,
|
||||
SystemPrompt: systemPrompt,
|
||||
MaxSteps: maxSteps,
|
||||
StreamingEnabled: streamingEnabled,
|
||||
ShowSpinner: opts.ShowSpinner,
|
||||
Quiet: opts.Quiet,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
ExtraTools: extraTools,
|
||||
OnMCPServerLoaded: opts.OnMCPServerLoaded,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create agent: %w", err)
|
||||
@@ -187,7 +267,7 @@ func loadExtensions() (*extensions.Runner, extensionCreationOpts, error) {
|
||||
return extensions.WrapToolsWithExtensions(tools, runner)
|
||||
}
|
||||
|
||||
extTools := extensions.ExtensionToolsAsFantasy(runner.RegisteredTools(), runner)
|
||||
extTools := extensions.ExtensionToolsAsLLMTools(runner.RegisteredTools(), runner)
|
||||
|
||||
return runner, extensionCreationOpts{
|
||||
toolWrapper: wrapper,
|
||||
|
||||
+20
-10
@@ -4,12 +4,18 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
// thinkTagRegex matches ... tags that some models (Qwen, DeepSeek) wrap
|
||||
// reasoning content in. Used to strip these tags from text content.
|
||||
// The (?s) flag makes . match newlines.
|
||||
var thinkTagRegex = regexp.MustCompile(`(?s)` + `` + `think` + `` + `(.*?)` + `` + `/think` + ``)
|
||||
|
||||
// sanitizeToolCallID ensures the ID matches Anthropic's required pattern:
|
||||
// ^[a-zA-Z0-9_-]+$ (alphanumeric, underscores, and hyphens only).
|
||||
// Invalid characters are replaced with underscores.
|
||||
@@ -115,9 +121,9 @@ const (
|
||||
)
|
||||
|
||||
// Message is a single conversation message containing a heterogeneous slice
|
||||
// of ContentPart blocks. This design (borrowed from crush) enables a single
|
||||
// assistant message to carry text, reasoning, and multiple tool calls as
|
||||
// discrete, typed blocks rather than flattening everything into strings.
|
||||
// of ContentPart blocks. This design enables a single assistant message to
|
||||
// carry text, reasoning, and multiple tool calls as discrete, typed blocks
|
||||
// rather than flattening everything into strings.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Role MessageRole `json:"role"`
|
||||
@@ -312,13 +318,13 @@ func UnmarshalParts(data []byte) ([]ContentPart, error) {
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
// --- Fantasy bridge ---
|
||||
// --- LLM bridge ---
|
||||
|
||||
// ToFantasyMessages converts a Message to one or more fantasy.Message values.
|
||||
// An assistant message with tool calls produces a single fantasy message with
|
||||
// ToLLMMessages converts a Message to one or more LLM message values.
|
||||
// An assistant message with tool calls produces a single message with
|
||||
// mixed TextPart and ToolCallPart content. Tool-role messages produce
|
||||
// ToolResultPart entries.
|
||||
func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
func (m *Message) ToLLMMessages() []fantasy.Message {
|
||||
switch m.Role {
|
||||
case RoleAssistant:
|
||||
var parts []fantasy.MessagePart
|
||||
@@ -416,9 +422,9 @@ func (m *Message) ToFantasyMessages() []fantasy.Message {
|
||||
}
|
||||
}
|
||||
|
||||
// FromFantasyMessage converts a fantasy.Message into our Message type,
|
||||
// FromLLMMessage converts an LLM message into our Message type,
|
||||
// extracting all content parts into the appropriate block types.
|
||||
func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
func FromLLMMessage(msg fantasy.Message) Message {
|
||||
m := Message{
|
||||
Role: MessageRole(msg.Role),
|
||||
Parts: make([]ContentPart, 0),
|
||||
@@ -430,7 +436,11 @@ func FromFantasyMessage(msg fantasy.Message) Message {
|
||||
switch p := part.(type) {
|
||||
case fantasy.TextPart:
|
||||
if p.Text != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: p.Text})
|
||||
// Strip ... tags that some models wrap reasoning in
|
||||
cleanedText := thinkTagRegex.ReplaceAllString(p.Text, "")
|
||||
if cleanedText != "" {
|
||||
m.Parts = append(m.Parts, TextContent{Text: cleanedText})
|
||||
}
|
||||
}
|
||||
case fantasy.ToolCallPart:
|
||||
m.Parts = append(m.Parts, ToolCall{
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"maps"
|
||||
"os"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"charm.land/fantasy/providers/openai"
|
||||
)
|
||||
|
||||
// buildCacheProviderOptions returns caching options for supported models.
|
||||
// Caching is enabled by default for all supported models to reduce costs.
|
||||
// Set KIT_DISABLE_CACHE=1 or ProviderConfig.DisableCaching=true to opt out.
|
||||
func buildCacheProviderOptions(modelInfo *ModelInfo, config *ProviderConfig) fantasy.ProviderOptions {
|
||||
// Check explicit opt-out via config
|
||||
if config.DisableCaching {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check global opt-out via environment
|
||||
if os.Getenv("KIT_DISABLE_CACHE") != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if model supports caching
|
||||
if modelInfo == nil || !modelInfo.SupportsCaching() {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch modelInfo.CacheType() {
|
||||
case "anthropic-ephemeral":
|
||||
// Provider-level Anthropic caching disabled - use message-level caching instead.
|
||||
return nil
|
||||
case "openai-prompt-cache":
|
||||
return buildOpenAICacheOptions(config, modelInfo.ID)
|
||||
case "google-cached-content":
|
||||
// Google caching not yet implemented.
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// buildOpenAICacheOptions enables prompt caching for OpenAI models.
|
||||
// Uses a deterministic cache key based on system prompt and model ID.
|
||||
func buildOpenAICacheOptions(config *ProviderConfig, modelID string) fantasy.ProviderOptions {
|
||||
cacheKey := generateCacheKey(config.SystemPrompt, modelID)
|
||||
|
||||
return fantasy.ProviderOptions{
|
||||
openai.Name: &openai.ProviderOptions{
|
||||
PromptCacheKey: &cacheKey,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// generateCacheKey creates a deterministic cache key from system prompt and model.
|
||||
// This ensures the same system prompt + model combination gets cache hits.
|
||||
func generateCacheKey(systemPrompt, modelID string) string {
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = "default"
|
||||
}
|
||||
|
||||
h := sha256.New()
|
||||
h.Write([]byte(systemPrompt))
|
||||
h.Write([]byte(modelID))
|
||||
|
||||
// Prefix with "kit-" to identify KIT-generated cache keys
|
||||
return "kit-" + hex.EncodeToString(h.Sum(nil))[:24]
|
||||
}
|
||||
|
||||
// mergeProviderOptions merges multiple ProviderOptions maps.
|
||||
// Later maps take precedence over earlier ones.
|
||||
func mergeProviderOptions(opts ...fantasy.ProviderOptions) fantasy.ProviderOptions {
|
||||
result := make(fantasy.ProviderOptions)
|
||||
|
||||
for _, opt := range opts {
|
||||
maps.Copy(result, opt)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
)
|
||||
|
||||
func TestModelInfo_SupportsCaching(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected bool
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", true},
|
||||
{"Claude 4 model", "claude-4-opus", true},
|
||||
{"GPT model", "gpt-4", true},
|
||||
{"GPT-5 model", "gpt-5", true},
|
||||
{"O1 model", "o1", true},
|
||||
{"O3 model", "o3", true},
|
||||
{"O4 model", "o4-mini", true},
|
||||
{"Codex model", "codex", true},
|
||||
{"Gemini model", "gemini-2.5-pro", true},
|
||||
{"Gemini 1.5 model", "gemini-1.5-flash", true},
|
||||
{"Llama model", "llama-3", false},
|
||||
{"Unknown model", "unknown", false},
|
||||
{"Empty family", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.SupportsCaching(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.SupportsCaching() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelInfo_CacheType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
family string
|
||||
expected string
|
||||
}{
|
||||
{"Claude model", "claude-3-5-sonnet", "anthropic-ephemeral"},
|
||||
{"GPT model", "gpt-4", "openai-prompt-cache"},
|
||||
{"O1 model", "o1", "openai-prompt-cache"},
|
||||
{"Gemini model", "gemini-2.5-pro", "google-cached-content"},
|
||||
{"Unknown model", "llama-3", ""},
|
||||
{"Empty family", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := &ModelInfo{Family: tt.family}
|
||||
if got := m.CacheType(); got != tt.expected {
|
||||
t.Errorf("ModelInfo.CacheType() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCacheKey(t *testing.T) {
|
||||
key1 := generateCacheKey("system prompt", "model-id")
|
||||
key2 := generateCacheKey("system prompt", "model-id")
|
||||
if key1 != key2 {
|
||||
t.Errorf("generateCacheKey should be deterministic: got %q and %q", key1, key2)
|
||||
}
|
||||
|
||||
key3 := generateCacheKey("different prompt", "model-id")
|
||||
if key1 == key3 {
|
||||
t.Errorf("generateCacheKey should produce different keys for different inputs")
|
||||
}
|
||||
|
||||
key4 := generateCacheKey("", "model-id")
|
||||
key5 := generateCacheKey("default", "model-id")
|
||||
if key4 != key5 {
|
||||
t.Errorf("generateCacheKey should treat empty prompt as 'default'")
|
||||
}
|
||||
|
||||
if len(key1) < 4 || key1[:4] != "kit-" {
|
||||
t.Errorf("generateCacheKey should produce keys with 'kit-' prefix, got %q", key1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Disabled(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: true}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when DisableCaching=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_EnvironmentVariable(t *testing.T) {
|
||||
_ = os.Setenv("KIT_DISABLE_CACHE", "1")
|
||||
defer func() { _ = os.Unsetenv("KIT_DISABLE_CACHE") }()
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when KIT_DISABLE_CACHE is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_UnsupportedModel(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "llama-3", ID: "llama-3-70b"}
|
||||
|
||||
if opts := buildCacheProviderOptions(modelInfo, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil for unsupported model families")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_NilModelInfo(t *testing.T) {
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
|
||||
if opts := buildCacheProviderOptions(nil, config); opts != nil {
|
||||
t.Errorf("buildCacheProviderOptions should return nil when modelInfo is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_Anthropic(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{DisableCaching: false}
|
||||
modelInfo := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
// Provider-level Anthropic caching is disabled; message-level caching is used instead
|
||||
if opts != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCacheProviderOptions_OpenAI(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
config := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test system prompt",
|
||||
}
|
||||
modelInfo := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
|
||||
opts := buildCacheProviderOptions(modelInfo, config)
|
||||
if opts == nil {
|
||||
t.Fatalf("buildCacheProviderOptions should return options for OpenAI models")
|
||||
}
|
||||
|
||||
if _, ok := opts["openai"]; !ok {
|
||||
t.Errorf("buildCacheProviderOptions should include 'openai' key for GPT models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingPriorityOverThinking(t *testing.T) {
|
||||
_ = os.Unsetenv("KIT_DISABLE_CACHE")
|
||||
|
||||
// Anthropic uses message-level caching; provider-level returns nil
|
||||
config1 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
modelInfo1 := &ModelInfo{Family: "claude-3", ID: "claude-3-opus"}
|
||||
opts1 := buildCacheProviderOptions(modelInfo1, config1)
|
||||
if opts1 != nil {
|
||||
t.Logf("Provider-level Anthropic caching disabled; using message-level caching")
|
||||
}
|
||||
|
||||
// OpenAI provider-level caching works with thinking enabled
|
||||
config2 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingMedium,
|
||||
}
|
||||
modelInfo2 := &ModelInfo{Family: "gpt-4", ID: "gpt-4o"}
|
||||
opts2 := buildCacheProviderOptions(modelInfo2, config2)
|
||||
if opts2 == nil {
|
||||
t.Errorf("OpenAI caching should work with thinking enabled")
|
||||
}
|
||||
|
||||
// OpenAI caching also works with thinking disabled
|
||||
config3 := &ProviderConfig{
|
||||
DisableCaching: false,
|
||||
SystemPrompt: "test prompt",
|
||||
ThinkingLevel: ThinkingOff,
|
||||
}
|
||||
opts3 := buildCacheProviderOptions(modelInfo2, config3)
|
||||
if opts3 == nil {
|
||||
t.Errorf("OpenAI caching should work when thinking is OFF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeProviderOptions(t *testing.T) {
|
||||
opts1 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "value1"},
|
||||
}
|
||||
opts2 := fantasy.ProviderOptions{
|
||||
"provider2": &testProviderData{value: "value2"},
|
||||
}
|
||||
|
||||
merged := mergeProviderOptions(opts1, opts2)
|
||||
|
||||
if len(merged) != 2 {
|
||||
t.Errorf("mergeProviderOptions should combine options from multiple maps, got %d items", len(merged))
|
||||
}
|
||||
|
||||
if _, ok := merged["provider1"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider1' key")
|
||||
}
|
||||
|
||||
if _, ok := merged["provider2"]; !ok {
|
||||
t.Errorf("merged options should contain 'provider2' key")
|
||||
}
|
||||
|
||||
// Later options should override earlier ones
|
||||
opts3 := fantasy.ProviderOptions{
|
||||
"provider1": &testProviderData{value: "overridden"},
|
||||
}
|
||||
merged2 := mergeProviderOptions(opts1, opts3)
|
||||
|
||||
if data, ok := merged2["provider1"].(*testProviderData); ok {
|
||||
if data.value != "overridden" {
|
||||
t.Errorf("later options should override earlier ones, got %q", data.value)
|
||||
}
|
||||
}
|
||||
|
||||
if mergeProviderOptions() != nil {
|
||||
t.Errorf("mergeProviderOptions with no args should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
// testProviderData is a simple implementation of ProviderOptionsData for testing
|
||||
type testProviderData struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (t *testProviderData) Options() {}
|
||||
|
||||
func (t *testProviderData) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + t.value + `"`), nil
|
||||
}
|
||||
|
||||
func (t *testProviderData) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
+236
-9
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -31,12 +33,14 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
info := ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
Reasoning: cfg.Reasoning,
|
||||
Temperature: cfg.Temperature,
|
||||
BaseURL: cfg.BaseURL,
|
||||
APIKey: cfg.APIKey,
|
||||
Cost: Cost{
|
||||
Input: cfg.Cost.Input,
|
||||
Output: cfg.Cost.Output,
|
||||
@@ -46,19 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert custom model generation params if any are set.
|
||||
if p := convertGenerationParams(cfg.Params); p != nil {
|
||||
info.Params = p
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]*GenerationParams, len(settings))
|
||||
for modelKey, cfg := range settings {
|
||||
if p := convertGenerationParams(cfg); p != nil {
|
||||
result[modelKey] = p
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
|
||||
// Returns nil if no parameters are set.
|
||||
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
|
||||
p := &GenerationParams{}
|
||||
any := false
|
||||
|
||||
if cfg.MaxTokens != nil {
|
||||
p.MaxTokens = cfg.MaxTokens
|
||||
any = true
|
||||
}
|
||||
if cfg.Temperature != nil {
|
||||
p.Temperature = cfg.Temperature
|
||||
any = true
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
p.TopP = cfg.TopP
|
||||
any = true
|
||||
}
|
||||
if cfg.TopK != nil {
|
||||
p.TopK = cfg.TopK
|
||||
any = true
|
||||
}
|
||||
if cfg.FrequencyPenalty != nil {
|
||||
p.FrequencyPenalty = cfg.FrequencyPenalty
|
||||
any = true
|
||||
}
|
||||
if cfg.PresencePenalty != nil {
|
||||
p.PresencePenalty = cfg.PresencePenalty
|
||||
any = true
|
||||
}
|
||||
if len(cfg.StopSequences) > 0 {
|
||||
p.StopSequences = cfg.StopSequences
|
||||
any = true
|
||||
}
|
||||
if cfg.ThinkingLevel != "" {
|
||||
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
|
||||
any = true
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
p.SystemPrompt = cfg.SystemPrompt
|
||||
any = true
|
||||
}
|
||||
|
||||
if !any {
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ApplyModelSettings merges per-model generation parameter defaults from the
|
||||
// registry into a ProviderConfig. Model-level params are only applied for
|
||||
// fields where the user has not explicitly set a value (i.e., the
|
||||
// corresponding viper key is not set via CLI flag or global config).
|
||||
//
|
||||
// The lookup order is:
|
||||
// 1. modelSettings["provider/model"] from config (highest model-level priority)
|
||||
// 2. ModelInfo.Params from custom model definitions
|
||||
//
|
||||
// Both are overridden by explicit CLI flags / global config values.
|
||||
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ModelInfo.Params (from custom model definitions).
|
||||
if params == nil && modelInfo != nil && modelInfo.Params != nil {
|
||||
params = modelInfo.Params
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply each parameter only when the user hasn't explicitly set it.
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
// Resolve file paths: if the value points to an existing file, read it.
|
||||
// We check config.SystemPrompt == "" rather than isExplicitlySet because
|
||||
// viper.BindPFlag causes IsSet to return true even for unset flags.
|
||||
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSystemPromptValue resolves a system prompt value that may be either
|
||||
// inline text or a file path. If the value is a path to an existing file,
|
||||
// its contents are read and returned. Otherwise the string is returned as-is.
|
||||
// This mirrors config.LoadSystemPrompt but lives in the models package to
|
||||
// avoid circular dependencies.
|
||||
func LoadSystemPromptValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
if info, err := os.Stat(input); err == nil && !info.IsDir() {
|
||||
content, err := os.ReadFile(input)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
|
||||
return input
|
||||
}
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
// as a default). For generation params, the global config keys use
|
||||
// hyphenated names (e.g. "max-tokens", "top-p").
|
||||
//
|
||||
// Since viper merges all sources, IsSet returns true even for config
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
// These are stored on ModelInfo and applied during provider creation.
|
||||
// Nil pointer fields mean "no model-level default" — the global config
|
||||
// or CLI flag value (if any) will be used instead.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
ThinkingLevel ThinkingLevel
|
||||
SystemPrompt string // Per-model system prompt (inline text or file path)
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
|
||||
// parameter defaults. Used in both customModels[].params and modelSettings[].
|
||||
type GenerationParamsConfig struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestConvertGenerationParams(t *testing.T) {
|
||||
t.Run("empty config returns nil", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("temperature only", func(t *testing.T) {
|
||||
temp := float32(0.7)
|
||||
cfg := GenerationParamsConfig{Temperature: &temp}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP != nil {
|
||||
t.Errorf("expected nil TopP, got %v", p.TopP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all params set", func(t *testing.T) {
|
||||
maxTokens := 8192
|
||||
temp := float32(0.5)
|
||||
topP := float32(0.9)
|
||||
topK := int32(50)
|
||||
freqPenalty := float32(0.1)
|
||||
presPenalty := float32(0.2)
|
||||
cfg := GenerationParamsConfig{
|
||||
MaxTokens: &maxTokens,
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &freqPenalty,
|
||||
PresencePenalty: &presPenalty,
|
||||
StopSequences: []string{"STOP"},
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
|
||||
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP == nil || *p.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", p.TopP)
|
||||
}
|
||||
if p.TopK == nil || *p.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", p.TopK)
|
||||
}
|
||||
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
|
||||
}
|
||||
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
|
||||
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
|
||||
}
|
||||
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level parsing", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingMedium {
|
||||
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
t.Run("system prompt only", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.SystemPrompt != "You are helpful." {
|
||||
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithParams(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
topP := float32(0.95)
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
Temperature: true,
|
||||
Params: GenerationParamsConfig{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
},
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params == nil {
|
||||
t.Fatal("expected non-nil Params")
|
||||
}
|
||||
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
|
||||
}
|
||||
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params != nil {
|
||||
t.Errorf("expected nil Params, got %+v", info.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyModelSettings(t *testing.T) {
|
||||
// Save and restore viper state.
|
||||
originalViper := viper.AllSettings()
|
||||
defer func() {
|
||||
viper.Reset()
|
||||
for k, v := range originalViper {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("applies model params when not explicitly set", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
maxTokens := 4096
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
MaxTokens: &maxTokens,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature == nil || *config.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", config.TopK)
|
||||
}
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit viper values take precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.Set("temperature", 0.3)
|
||||
|
||||
temp := float32(0.8)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
explicitTemp := float32(0.3)
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
Temperature: &explicitTemp,
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Temperature should NOT be overridden because it's explicitly set in viper
|
||||
if config.Temperature == nil || *config.Temperature != 0.3 {
|
||||
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model info is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
ApplyModelSettings(config, nil)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model info without params is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{ID: "test-model"}
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Set up modelSettings in viper (simulating config file)
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"temperature": 0.5,
|
||||
"topK": 30,
|
||||
},
|
||||
})
|
||||
|
||||
// ModelInfo has different params
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// modelSettings should win over ModelInfo.Params
|
||||
if config.Temperature == nil || *config.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 30 {
|
||||
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop sequences applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
ThinkingLevel: ThinkingHigh,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "You are a coding assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "You are a coding assistant." {
|
||||
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "Model-specific prompt",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
SystemPrompt: "Global prompt",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Global system prompt should NOT be overridden because config
|
||||
// already has a non-empty SystemPrompt.
|
||||
if config.SystemPrompt != "Global prompt" {
|
||||
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt from file path", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Create a temp file with a system prompt
|
||||
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "Prompt from file" {
|
||||
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"systemPrompt": "From modelSettings",
|
||||
},
|
||||
})
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "From custom model",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "From modelSettings" {
|
||||
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -48,10 +48,10 @@ type modelsDBLimit struct {
|
||||
Output int `json:"output"`
|
||||
}
|
||||
|
||||
// npmToFantasyProvider maps npm package names from models.dev to fantasy
|
||||
// npmToLLMProvider maps npm package names from models.dev to LLM
|
||||
// provider identifiers. Providers not in this map but with an api URL
|
||||
// can be auto-routed through openaicompat.
|
||||
var npmToFantasyProvider = map[string]string{
|
||||
var npmToLLMProvider = map[string]string{
|
||||
"@ai-sdk/anthropic": "anthropic",
|
||||
"@ai-sdk/openai": "openai",
|
||||
"@ai-sdk/google": "google",
|
||||
|
||||
+125
-53
@@ -22,9 +22,9 @@ import (
|
||||
"charm.land/fantasy/providers/openaicompat"
|
||||
"charm.land/fantasy/providers/openrouter"
|
||||
"charm.land/fantasy/providers/vercel"
|
||||
openaisdk "github.com/charmbracelet/openai-go"
|
||||
|
||||
"github.com/mark3labs/kit/internal/auth"
|
||||
"github.com/mark3labs/kit/internal/ui/progress"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -142,19 +142,28 @@ func ParseThinkingLevel(s string) ThinkingLevel {
|
||||
|
||||
// ProviderConfig holds configuration for creating LLM providers.
|
||||
type ProviderConfig struct {
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
ModelString string
|
||||
SystemPrompt string
|
||||
ProviderAPIKey string
|
||||
ProviderURL string
|
||||
MaxTokens int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
NumGPU *int32
|
||||
MainGPU *int32
|
||||
TLSSkipVerify bool
|
||||
ThinkingLevel ThinkingLevel
|
||||
DisableCaching bool // Opt-out: set to true to disable automatic prompt caching
|
||||
|
||||
// ProgressReaderFunc, when set, wraps an io.Reader with progress display
|
||||
// for long operations like Ollama model pulls. The returned io.ReadCloser
|
||||
// must be closed when done. When nil, the raw reader is consumed directly
|
||||
// with no progress UI.
|
||||
ProgressReaderFunc func(io.Reader) io.ReadCloser
|
||||
}
|
||||
|
||||
// ProviderResult contains the result of provider creation.
|
||||
@@ -237,30 +246,64 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Apply per-model generation parameter defaults. Model-level params are
|
||||
// only applied for fields where the user hasn't explicitly set a value
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
switch provider {
|
||||
case "anthropic":
|
||||
return createAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createAnthropicProvider(ctx, config, modelName)
|
||||
case "openai":
|
||||
return createOpenAIProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenAIProvider(ctx, config, modelName)
|
||||
case "google", "gemini":
|
||||
return createGoogleProvider(ctx, config, modelName)
|
||||
result, createErr = createGoogleProvider(ctx, config, modelName)
|
||||
case "ollama":
|
||||
return createOllamaProvider(ctx, config, modelName)
|
||||
result, createErr = createOllamaProvider(ctx, config, modelName)
|
||||
case "azure":
|
||||
return createAzureProvider(ctx, config, modelName)
|
||||
result, createErr = createAzureProvider(ctx, config, modelName)
|
||||
case "google-vertex-anthropic":
|
||||
return createVertexAnthropicProvider(ctx, config, modelName)
|
||||
result, createErr = createVertexAnthropicProvider(ctx, config, modelName)
|
||||
case "openrouter":
|
||||
return createOpenRouterProvider(ctx, config, modelName)
|
||||
result, createErr = createOpenRouterProvider(ctx, config, modelName)
|
||||
case "bedrock":
|
||||
return createBedrockProvider(ctx, config, modelName)
|
||||
result, createErr = createBedrockProvider(ctx, config, modelName)
|
||||
case "vercel":
|
||||
return createVercelProvider(ctx, config, modelName)
|
||||
result, createErr = createVercelProvider(ctx, config, modelName)
|
||||
case "custom":
|
||||
return createCustomProvider(ctx, config, modelName)
|
||||
result, createErr = createCustomProvider(ctx, config, modelName)
|
||||
default:
|
||||
return autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
result, createErr = autoRouteProvider(ctx, config, provider, modelName, registry)
|
||||
}
|
||||
|
||||
if createErr != nil {
|
||||
return nil, createErr
|
||||
}
|
||||
|
||||
// AUTOMATICALLY ENABLE CACHING for supported models (unless disabled).
|
||||
// This works for BOTH native and auto-routed providers by detecting
|
||||
// the model family from the model metadata.
|
||||
if cacheOpts := buildCacheProviderOptions(modelInfo, config); cacheOpts != nil {
|
||||
if result.ProviderOptions == nil {
|
||||
result.ProviderOptions = cacheOpts
|
||||
} else {
|
||||
// Merge cache options with existing provider options.
|
||||
// Only add cache options for providers that don't already have
|
||||
// options set, to avoid type conflicts (e.g., Anthropic has
|
||||
// different types for regular options vs cache control options).
|
||||
for k, v := range cacheOpts {
|
||||
if _, exists := result.ProviderOptions[k]; !exists {
|
||||
result.ProviderOptions[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// autoRouteProvider attempts to create a provider by looking up its npm package
|
||||
@@ -280,14 +323,14 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
npmPackage = modelInfo.ProviderNPM
|
||||
}
|
||||
|
||||
// Determine the fantasy provider for this npm package
|
||||
fantasyProvider := npmToFantasyProvider[npmPackage]
|
||||
if fantasyProvider == "" && providerInfo.API != "" {
|
||||
// Determine the LLM provider for this npm package
|
||||
llmProvider := npmToLLMProvider[npmPackage]
|
||||
if llmProvider == "" && providerInfo.API != "" {
|
||||
// Unknown npm but has API URL → route through openaicompat
|
||||
fantasyProvider = "openaicompat"
|
||||
llmProvider = "openaicompat"
|
||||
}
|
||||
|
||||
switch fantasyProvider {
|
||||
switch llmProvider {
|
||||
case "openaicompat":
|
||||
return createAutoRoutedOpenAICompatProvider(ctx, config, modelName, providerInfo)
|
||||
case "anthropic":
|
||||
@@ -301,7 +344,7 @@ func autoRouteProvider(ctx context.Context, config *ProviderConfig, provider, mo
|
||||
}
|
||||
return createAutoRoutedOpenAIProvider(ctx, config, modelName, providerInfo)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no fantasy mapping)", provider, npmPackage)
|
||||
return nil, fmt.Errorf("unsupported provider: %s (npm: %s has no LLM provider mapping)", provider, npmPackage)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -493,13 +536,13 @@ func buildOpenAIProviderOptions(config *ProviderConfig, modelName string) fantas
|
||||
func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort {
|
||||
switch level {
|
||||
case ThinkingMinimal:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMinimal)
|
||||
return new(openai.ReasoningEffortMinimal)
|
||||
case ThinkingLow:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortLow)
|
||||
return new(openai.ReasoningEffortLow)
|
||||
case ThinkingMedium:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortMedium)
|
||||
return new(openai.ReasoningEffortMedium)
|
||||
case ThinkingHigh:
|
||||
return openai.ReasoningEffortOption(openai.ReasoningEffortHigh)
|
||||
return new(openai.ReasoningEffortHigh)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -510,10 +553,15 @@ func thinkingLevelToReasoningEffort(level ThinkingLevel) *openai.ReasoningEffort
|
||||
// SendReasoning to true and configures the thinking budget. For thinking-off
|
||||
// or non-reasoning models the returned map is nil.
|
||||
//
|
||||
// NOTE: With message-level caching, thinking and caching can work together.
|
||||
// Message-level cache control (ProviderCacheControlOptions) doesn't conflict
|
||||
// with provider-level thinking options (ProviderOptions).
|
||||
//
|
||||
// Anthropic requires max_tokens > thinking.budget_tokens. If the configured
|
||||
// MaxTokens is too low, it is bumped to budget + 4096 to leave room for the
|
||||
// actual response.
|
||||
func buildAnthropicProviderOptions(config *ProviderConfig, modelName string) fantasy.ProviderOptions {
|
||||
// Thinking is OFF by default. If user hasn't explicitly enabled it, return nil.
|
||||
if config.ThinkingLevel == "" || config.ThinkingLevel == ThinkingOff {
|
||||
return nil
|
||||
}
|
||||
@@ -963,12 +1011,29 @@ func createVercelProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
return &ProviderResult{Model: model}, nil
|
||||
}
|
||||
|
||||
// customToPromptFunc converts prompts to OpenAI format using the default conversion.
|
||||
func customToPromptFunc(prompt fantasy.Prompt, systemPrompt, user string) ([]openaisdk.ChatCompletionMessageParamUnion, []fantasy.CallWarning) {
|
||||
return openai.DefaultToPrompt(prompt, systemPrompt, user)
|
||||
}
|
||||
|
||||
func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName string) (*ProviderResult, error) {
|
||||
if config.ProviderURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url")
|
||||
// Resolve base URL: per-model override > global provider-url flag/config
|
||||
registry := GetGlobalRegistry()
|
||||
modelInfo := registry.LookupModel("custom", modelName)
|
||||
|
||||
baseURL := config.ProviderURL
|
||||
if modelInfo != nil && modelInfo.BaseURL != "" {
|
||||
baseURL = modelInfo.BaseURL
|
||||
}
|
||||
|
||||
if baseURL == "" {
|
||||
return nil, fmt.Errorf("custom provider requires --provider-url or a baseUrl in the model config")
|
||||
}
|
||||
|
||||
apiKey := config.ProviderAPIKey
|
||||
if modelInfo != nil && modelInfo.APIKey != "" {
|
||||
apiKey = modelInfo.APIKey
|
||||
}
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("CUSTOM_API_KEY")
|
||||
}
|
||||
@@ -977,16 +1042,21 @@ func createCustomProvider(ctx context.Context, config *ProviderConfig, modelName
|
||||
apiKey = "custom"
|
||||
}
|
||||
|
||||
var opts []openaicompat.Option
|
||||
opts = append(opts, openaicompat.WithBaseURL(config.ProviderURL))
|
||||
opts = append(opts, openaicompat.WithAPIKey(apiKey))
|
||||
opts = append(opts, openaicompat.WithName("custom"))
|
||||
// <think> tag extraction is handled transparently at the agent layer,
|
||||
// so no provider-level hooks are needed here.
|
||||
var opts []openai.Option
|
||||
opts = append(opts, openai.WithBaseURL(baseURL))
|
||||
opts = append(opts, openai.WithAPIKey(apiKey))
|
||||
opts = append(opts, openai.WithName("custom"))
|
||||
opts = append(opts, openai.WithLanguageModelOptions(
|
||||
openai.WithLanguageModelToPromptFunc(customToPromptFunc),
|
||||
))
|
||||
|
||||
if config.TLSSkipVerify {
|
||||
opts = append(opts, openaicompat.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
opts = append(opts, openai.WithHTTPClient(createHTTPClientWithTLSConfig(true)))
|
||||
}
|
||||
|
||||
p, err := openaicompat.New(opts...)
|
||||
p, err := openai.New(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create custom provider: %w", err)
|
||||
}
|
||||
@@ -1063,7 +1133,7 @@ func loadOllamaModelWithFallback(ctx context.Context, baseURL, modelName string,
|
||||
// Phase 1: Check if model exists locally
|
||||
if err := checkOllamaModelExists(client, baseURL, modelName); err != nil {
|
||||
// Phase 2: Pull model if not found
|
||||
if err := pullOllamaModel(ctx, client, baseURL, modelName); err != nil {
|
||||
if err := pullOllamaModel(ctx, client, baseURL, modelName, config.ProgressReaderFunc); err != nil {
|
||||
return nil, fmt.Errorf("failed to pull model %s: %v", modelName, err)
|
||||
}
|
||||
}
|
||||
@@ -1106,6 +1176,12 @@ func buildOllamaOptions(config *ProviderConfig) map[string]any {
|
||||
if config.TopK != nil {
|
||||
options["top_k"] = int(*config.TopK)
|
||||
}
|
||||
if config.FrequencyPenalty != nil {
|
||||
options["frequency_penalty"] = *config.FrequencyPenalty
|
||||
}
|
||||
if config.PresencePenalty != nil {
|
||||
options["presence_penalty"] = *config.PresencePenalty
|
||||
}
|
||||
if len(config.StopSequences) > 0 {
|
||||
options["stop"] = config.StopSequences
|
||||
}
|
||||
@@ -1146,11 +1222,7 @@ func checkOllamaModelExists(client *http.Client, baseURL, modelName string) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string) error {
|
||||
return pullOllamaModelWithProgress(ctx, client, baseURL, modelName, true)
|
||||
}
|
||||
|
||||
func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseURL, modelName string, showProgress bool) error {
|
||||
func pullOllamaModel(ctx context.Context, client *http.Client, baseURL, modelName string, progressFn func(io.Reader) io.ReadCloser) error {
|
||||
reqBody := map[string]string{"name": modelName}
|
||||
jsonBody, _ := json.Marshal(reqBody)
|
||||
|
||||
@@ -1174,10 +1246,10 @@ func pullOllamaModelWithProgress(ctx context.Context, client *http.Client, baseU
|
||||
return fmt.Errorf("failed to pull model (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
if showProgress {
|
||||
progressReader := progress.NewProgressReader(resp.Body)
|
||||
defer func() { _ = progressReader.Close() }()
|
||||
_, err = io.ReadAll(progressReader)
|
||||
if progressFn != nil {
|
||||
pr := progressFn(resp.Body)
|
||||
defer func() { _ = pr.Close() }()
|
||||
_, err = io.ReadAll(pr)
|
||||
} else {
|
||||
_, err = io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
+114
-9
@@ -17,12 +17,58 @@ var embeddedModelsJSON []byte
|
||||
type ModelInfo struct {
|
||||
ID string
|
||||
Name string
|
||||
Family string // Model family (e.g., "claude", "gpt", "gemini")
|
||||
Attachment bool
|
||||
Reasoning bool
|
||||
Temperature bool
|
||||
Cost Cost
|
||||
Limit Limit
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
|
||||
// Params holds per-model generation parameter defaults. These are applied
|
||||
// when the user hasn't explicitly set the corresponding CLI flag or global
|
||||
// config value. Nil pointer fields mean "no model-level default".
|
||||
Params *GenerationParams
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
// This enables automatic cost savings for supported models regardless of provider.
|
||||
func (m *ModelInfo) SupportsCaching() bool {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return true
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// CacheType returns the appropriate cache mechanism for this model family.
|
||||
// Returns empty string if caching is not supported.
|
||||
func (m *ModelInfo) CacheType() string {
|
||||
switch {
|
||||
case strings.HasPrefix(m.Family, "claude"):
|
||||
return "anthropic-ephemeral"
|
||||
case strings.HasPrefix(m.Family, "gpt"),
|
||||
strings.HasPrefix(m.Family, "o1"),
|
||||
strings.HasPrefix(m.Family, "o3"),
|
||||
strings.HasPrefix(m.Family, "o4"),
|
||||
strings.HasPrefix(m.Family, "codex"):
|
||||
return "openai-prompt-cache"
|
||||
case strings.HasPrefix(m.Family, "gemini"):
|
||||
return "google-cached-content"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Cost represents the pricing information for a model.
|
||||
@@ -86,6 +132,7 @@ func buildFromModelsDB() map[string]ProviderInfo {
|
||||
modelsMap[modelID] = ModelInfo{
|
||||
ID: dm.ID,
|
||||
Name: dm.Name,
|
||||
Family: dm.Family,
|
||||
Attachment: dm.Attachment,
|
||||
Reasoning: dm.Reasoning,
|
||||
Temperature: dm.Temperature,
|
||||
@@ -194,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// LookupModelForSettings is a convenience function that parses a
|
||||
// "provider/model" string and looks up the ModelInfo in the global registry.
|
||||
// Returns nil when the model string is invalid or the model is unknown.
|
||||
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
||||
func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return GetGlobalRegistry().LookupModel(provider, modelName)
|
||||
}
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
@@ -308,27 +367,27 @@ func (r *ModelsRegistry) GetSupportedProviders() []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// GetFantasyProviders returns provider IDs that can be used with fantasy,
|
||||
// GetLLMProviders returns provider IDs that have LLM support,
|
||||
// either through a native provider or via openaicompat auto-routing.
|
||||
func (r *ModelsRegistry) GetFantasyProviders() []string {
|
||||
func (r *ModelsRegistry) GetLLMProviders() []string {
|
||||
var providers []string
|
||||
for providerID, info := range r.providers {
|
||||
if isProviderFantasySupported(providerID, &info) {
|
||||
if isProviderLLMSupported(providerID, &info) {
|
||||
providers = append(providers, providerID)
|
||||
}
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// isProviderFantasySupported checks if a provider can be used with fantasy.
|
||||
func isProviderFantasySupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama is always supported (via openaicompat pointed at localhost)
|
||||
if providerID == "ollama" {
|
||||
// isProviderLLMSupported checks if a provider can be used with the LLM layer.
|
||||
func isProviderLLMSupported(providerID string, info *ProviderInfo) bool {
|
||||
// Ollama and custom are always supported (model names are user-defined).
|
||||
if providerID == "ollama" || providerID == "custom" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if npm maps to a fantasy provider
|
||||
if _, ok := npmToFantasyProvider[info.NPM]; ok {
|
||||
// Check if npm maps to an LLM provider
|
||||
if _, ok := npmToLLMProvider[info.NPM]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -355,6 +414,52 @@ func (r *ModelsRegistry) GetProviderInfo(provider string) *ProviderInfo {
|
||||
return &info
|
||||
}
|
||||
|
||||
// ValidateModelString checks whether a model string is well-formed and refers
|
||||
// to a known provider. It returns a user-friendly error with suggestions when
|
||||
// the model or provider is unrecognised. Passing validation does not guarantee
|
||||
// that API authentication will succeed — it only catches obvious mistakes
|
||||
// (typos, missing provider prefix, non-existent provider names) early so that
|
||||
// callers such as subagent spawning can return fast feedback.
|
||||
//
|
||||
// Unknown models under a known provider are allowed (the provider API is the
|
||||
// authority), but a completely unknown provider is rejected.
|
||||
func (r *ModelsRegistry) ValidateModelString(modelString string) error {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ollama and custom are always valid — model names are user-defined.
|
||||
if provider == "ollama" || provider == "custom" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the provider exists in the registry.
|
||||
providerInfo := r.GetProviderInfo(provider)
|
||||
if providerInfo == nil {
|
||||
known := r.GetSupportedProviders()
|
||||
return fmt.Errorf(
|
||||
"unknown provider %q in model string %q. Known providers: %s",
|
||||
provider, modelString, strings.Join(known, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// Provider exists — check if the model is known. An unknown model is
|
||||
// only a warning (the provider API decides), but we surface suggestions
|
||||
// so the caller can self-correct.
|
||||
if r.LookupModel(provider, modelName) == nil {
|
||||
if suggestions := r.SuggestModels(provider, modelName); len(suggestions) > 0 {
|
||||
return fmt.Errorf(
|
||||
"model %q not found for provider %s. Did you mean one of: %s",
|
||||
modelName, provider, strings.Join(suggestions, ", "),
|
||||
)
|
||||
}
|
||||
// No suggestions — let it through; the provider API is the authority.
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
var globalRegistry = NewModelsRegistry()
|
||||
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidateModelString(t *testing.T) {
|
||||
registry := GetGlobalRegistry()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
wantErr bool
|
||||
errSubstr string // expected substring in error message (empty = don't check)
|
||||
}{
|
||||
{
|
||||
name: "valid anthropic model",
|
||||
model: "anthropic/claude-sonnet-4-6",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider prefix",
|
||||
model: "claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
model: "",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown provider",
|
||||
model: "fakeprovider/some-model",
|
||||
wantErr: true,
|
||||
errSubstr: "unknown provider",
|
||||
},
|
||||
{
|
||||
name: "ollama always valid",
|
||||
model: "ollama/llama3",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "custom always valid",
|
||||
model: "custom/my-fine-tune",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty provider",
|
||||
model: "/claude-sonnet-4-6",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "empty model name",
|
||||
model: "anthropic/",
|
||||
wantErr: true,
|
||||
errSubstr: "invalid model format",
|
||||
},
|
||||
{
|
||||
name: "unknown model under known provider (no suggestions)",
|
||||
model: "anthropic/totally-unknown-xyz-999",
|
||||
wantErr: false, // no suggestions → passes through
|
||||
},
|
||||
{
|
||||
name: "typo model under known provider with suggestions",
|
||||
model: "anthropic/claude-sonet", // misspelled "sonnet"
|
||||
wantErr: true,
|
||||
errSubstr: "Did you mean",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := registry.ValidateModelString(tt.model)
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("ValidateModelString(%q) = nil, want error", tt.model)
|
||||
}
|
||||
if !tt.wantErr && err != nil {
|
||||
t.Errorf("ValidateModelString(%q) = %v, want nil", tt.model, err)
|
||||
}
|
||||
if tt.errSubstr != "" && err != nil {
|
||||
if !strings.Contains(err.Error(), tt.errSubstr) {
|
||||
t.Errorf("ValidateModelString(%q) error = %q, want substring %q",
|
||||
tt.model, err.Error(), tt.errSubstr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,10 @@ package prompts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
// LoadOptions configures how templates are discovered and loaded.
|
||||
@@ -74,10 +73,7 @@ func LoadAll(opts LoadOptions) ([]*PromptTemplate, []Diagnostic, error) {
|
||||
DroppedPath: tpl.FilePath,
|
||||
Reason: fmt.Sprintf("template from %s overridden by %s", source, existing.Source),
|
||||
})
|
||||
log.Debug("template collision",
|
||||
"name", tpl.Name,
|
||||
"dropped", tpl.FilePath,
|
||||
"kept", existing.FilePath)
|
||||
log.Printf("DEBUG template collision: name=%s dropped=%s kept=%s", tpl.Name, tpl.FilePath, existing.FilePath)
|
||||
} else {
|
||||
tpl.Source = source
|
||||
seen[tpl.Name] = tpl
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/fences"
|
||||
)
|
||||
|
||||
// PromptTemplate is a named prompt template with shell-style argument placeholders.
|
||||
// It supports Pi-style $1, $2, $@, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
// It supports Pi-style $1, $2, $@, $+, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
type PromptTemplate struct {
|
||||
// Name is the human-readable identifier for this template.
|
||||
Name string
|
||||
@@ -120,19 +122,28 @@ func ParseCommandArgs(input string) []string {
|
||||
|
||||
// argPlaceholder matches shell-style argument placeholders:
|
||||
// - $1, $2, etc. - positional arguments
|
||||
// - $@ - all arguments
|
||||
// - $@ - all arguments (zero or more)
|
||||
// - $+ - all arguments (one or more required)
|
||||
// - $ARGUMENTS - all arguments (alias for $@)
|
||||
// - ${@:N} - arguments from N onwards
|
||||
// - ${@:N:L} - L arguments starting from N
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$ARGUMENTS`)
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$\+|\$ARGUMENTS`)
|
||||
|
||||
// SubstituteArgs replaces argument placeholders in content with values from args.
|
||||
// Supported placeholders:
|
||||
// - $N, ${N} - the Nth argument (1-indexed)
|
||||
// - $@, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - $@, $+, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - ${@:N} - arguments from index N onwards (0-indexed)
|
||||
// - ${@:N:L} - L arguments starting from index N (0-indexed)
|
||||
func SubstituteArgs(content string, args []string) string {
|
||||
return fences.ReplaceOutside(content, func(segment string) string {
|
||||
return substituteArgsInSegment(segment, args)
|
||||
})
|
||||
}
|
||||
|
||||
// substituteArgsInSegment performs argument substitution on a single text
|
||||
// segment that is known to be outside fenced code blocks.
|
||||
func substituteArgsInSegment(content string, args []string) string {
|
||||
return argPlaceholder.ReplaceAllStringFunc(content, func(match string) string {
|
||||
// Check for ${N} or ${N:M} format
|
||||
if strings.HasPrefix(match, "${") && strings.Contains(match, "}") {
|
||||
@@ -191,8 +202,8 @@ func SubstituteArgs(content string, args []string) string {
|
||||
if strings.HasPrefix(match, "$") && !strings.HasPrefix(match, "${") {
|
||||
suffix := match[1:]
|
||||
|
||||
// $@ or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "ARGUMENTS" {
|
||||
// $@, $+, or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "+" || suffix == "ARGUMENTS" {
|
||||
return strings.Join(args, " ")
|
||||
}
|
||||
|
||||
@@ -266,6 +277,48 @@ func joinArgsRange(args []string, start, length int) string {
|
||||
return strings.Join(args[start:end], " ")
|
||||
}
|
||||
|
||||
// HasArgPlaceholders reports whether the template content contains any
|
||||
// argument placeholders ($1, $@, $ARGUMENTS, ${@:...}, etc.).
|
||||
// Placeholders inside fenced code blocks and inline code spans are ignored.
|
||||
func (t *PromptTemplate) HasArgPlaceholders() bool {
|
||||
return argPlaceholder.MatchString(fences.StripCode(t.Content))
|
||||
}
|
||||
|
||||
// RequiredArgs returns the number of positional arguments the template
|
||||
// expects. This is determined by the highest $N or ${N} placeholder found
|
||||
// in the content (1-indexed, so $2 means 2 args required). The $+
|
||||
// placeholder (required variadic) ensures at least 1. Optional wildcards
|
||||
// ($@, $ARGUMENTS) do not contribute to the count.
|
||||
func (t *PromptTemplate) RequiredArgs() int {
|
||||
content := fences.StripCode(t.Content)
|
||||
maxN := 0
|
||||
hasRequiredVariadic := strings.Contains(content, "$+")
|
||||
for _, match := range argPlaceholder.FindAllStringSubmatch(content, -1) {
|
||||
// Group 1: ${N} format — the N value.
|
||||
if match[1] != "" {
|
||||
if n, err := strconv.Atoi(match[1]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 2: ${N:M} format — the N value (start index).
|
||||
if match[2] != "" {
|
||||
if n, err := strconv.Atoi(match[2]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 6: $N format (no braces) — the N value.
|
||||
if match[6] != "" {
|
||||
if n, err := strconv.Atoi(match[6]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasRequiredVariadic && maxN < 1 {
|
||||
maxN = 1
|
||||
}
|
||||
return maxN
|
||||
}
|
||||
|
||||
// Expand substitutes arguments into the template content and returns the result.
|
||||
// It first parses args from the input string, then substitutes them into the template.
|
||||
func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
|
||||
@@ -129,6 +129,48 @@ func TestSubstituteArgs(t *testing.T) {
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "$1 inside code block preserved",
|
||||
content: "Use $1 here\n```bash\necho $1\n```\ndone",
|
||||
args: []string{"foo"},
|
||||
expected: "Use foo here\n```bash\necho $1\n```\ndone",
|
||||
},
|
||||
{
|
||||
name: "$@ inside code block preserved",
|
||||
content: "Run $@\n```\necho $@\n```\n",
|
||||
args: []string{"a", "b"},
|
||||
expected: "Run a b\n```\necho $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "all placeholders inside code block",
|
||||
content: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
args: []string{"x"},
|
||||
expected: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "$1 inside inline code preserved",
|
||||
content: "Use `$1` here and $1 outside",
|
||||
args: []string{"foo"},
|
||||
expected: "Use `$1` here and foo outside",
|
||||
},
|
||||
{
|
||||
name: "$+ required variadic",
|
||||
content: "Args: $+",
|
||||
args: []string{"a", "b", "c"},
|
||||
expected: "Args: a b c",
|
||||
},
|
||||
{
|
||||
name: "$+ with empty args",
|
||||
content: "Args: $+",
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "all placeholders in inline code",
|
||||
content: "Use `$1` and `$@` for args",
|
||||
args: []string{"x"},
|
||||
expected: "Use `$1` and `$@` for args",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -213,3 +255,78 @@ func TestPromptTemplateExpand(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasArgPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want bool
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt with no args", false},
|
||||
{"$1 placeholder", "Create a $1 component", true},
|
||||
{"$@ placeholder", "Run with args: $@", true},
|
||||
{"$ARGUMENTS placeholder", "Features: $ARGUMENTS", true},
|
||||
{"${1} placeholder", "Name: ${1}", true},
|
||||
{"${ARGUMENTS} placeholder", "All: ${ARGUMENTS}", true},
|
||||
{"${@:1} placeholder", "Rest: ${@:1}", true},
|
||||
{"${@:1:2} placeholder", "Slice: ${@:1:2}", true},
|
||||
{"dollar in text", "Cost is one hundred dollars", false},
|
||||
{"empty content", "", false},
|
||||
{"$1 inside code block only", "Prompt\n```\necho $1\n```\n", false},
|
||||
{"$1 outside and inside code block", "Use $1 here\n```\necho $1\n```\n", true},
|
||||
{"$@ inside code block only", "Prompt\n```bash\necho $@\n```\n", false},
|
||||
{"$+ placeholder", "Run with args: $+", true},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", false},
|
||||
{"$1 inside inline code only", "Use `$1` for positional args", false},
|
||||
{"$1 outside and in inline code", "Create $1 (see `$1` syntax)", true},
|
||||
{"$@ outside $1 in inline code", "Run $@ with `$1` syntax", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.HasArgPlaceholders(); got != tt.want {
|
||||
t.Errorf("HasArgPlaceholders() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequiredArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want int
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt", 0},
|
||||
{"$1 only", "Create a $1 component", 1},
|
||||
{"$1 and $2", "Create $1 with $2", 2},
|
||||
{"$3 skipping $2", "Use $1 and $3", 3},
|
||||
{"${1} braced", "Name: ${1}", 1},
|
||||
{"${2} braced", "Name: ${1} Desc: ${2}", 2},
|
||||
{"$@ only", "Run with: $@", 0},
|
||||
{"$ARGUMENTS only", "Features: $ARGUMENTS", 0},
|
||||
{"${ARGUMENTS} only", "All: ${ARGUMENTS}", 0},
|
||||
{"$1 and $@", "Create $1 with extras: $@", 1},
|
||||
{"${@:1} slice only", "Rest: ${@:1}", 0},
|
||||
{"${@:1:2} slice only", "Slice: ${@:1:2}", 0},
|
||||
{"mixed $1 $2 and $@", "Create $1 named $2: $@", 2},
|
||||
{"empty content", "", 0},
|
||||
{"$2 inside code block only", "Prompt\n```\n$1 $2\n```\n", 0},
|
||||
{"$1 outside $2 inside code block", "Use $1\n```\n$2 inside\n```\n", 1},
|
||||
{"$+ only", "Run with: $+", 1},
|
||||
{"$+ and $2", "Create $2 with: $+", 2},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", 0},
|
||||
{"$1 and $2 in inline code only", "Use `$1` and `$2` for args", 0},
|
||||
{"$1 outside $2 in inline code", "Create $1 (see `$2`)", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.RequiredArgs(); got != tt.want {
|
||||
t.Errorf("RequiredArgs() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
)
|
||||
|
||||
// TestCompactionCreatesNewLeaf verifies that after compaction, the compaction
|
||||
// entry has no parent (creating a new root), and BuildContext returns only
|
||||
// the summary and kept messages, not the old compacted messages.
|
||||
func TestCompactionCreatesNewLeaf(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add some messages: M1, M2 (old, will be compacted), M3, M4 (kept)
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1 - old"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2 - old"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Verify initial state - all messages should be in context
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages before compaction, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs before compaction, got %d", len(entryIDs))
|
||||
}
|
||||
|
||||
// Now add a compaction entry, simulating that M3 is the first kept entry
|
||||
summary := "Summary of old messages"
|
||||
compactionID, err := tm.AppendCompaction(summary, id3, 1000, 500, 2, []string{}, []string{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to append compaction: %v", err)
|
||||
}
|
||||
|
||||
// Verify the compaction entry has no parent (empty ParentID)
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// Verify the leaf is now the compaction entry
|
||||
if tm.GetLeafID() != compactionID {
|
||||
t.Errorf("leaf should be compaction entry %q, got %q", compactionID, tm.GetLeafID())
|
||||
}
|
||||
|
||||
// Now BuildContext should return: [summary] + [M3, M4]
|
||||
messages, _, _ = tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages after compaction (summary + 2 kept), got %d", len(messages))
|
||||
}
|
||||
|
||||
// First message should be the summary
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\n"+summary {
|
||||
t.Errorf("unexpected summary text: %s", summaryText)
|
||||
}
|
||||
|
||||
// Second message should be M3 (kept)
|
||||
if messages[1].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("second message should be user (M3), got %s", messages[1].Role)
|
||||
}
|
||||
m3Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m3Text != "Message 3 - kept" {
|
||||
t.Errorf("unexpected M3 text: %s", m3Text)
|
||||
}
|
||||
|
||||
// Third message should be M4 (kept)
|
||||
if messages[2].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("third message should be assistant (M4), got %s", messages[2].Role)
|
||||
}
|
||||
m4Text := messages[2].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - kept" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
|
||||
// Verify GetContextEntryIDs returns correct IDs
|
||||
entryIDs = tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 3 {
|
||||
t.Fatalf("expected 3 entry IDs after compaction (empty for summary + 2 kept), got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First entry ID should be empty (summary has no entry)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Second and third should be id3 and id4 (the kept messages)
|
||||
if entryIDs[1] != id3 {
|
||||
t.Errorf("second entry ID should be %q (M3), got %q", id3, entryIDs[1])
|
||||
}
|
||||
if entryIDs[2] != id4 {
|
||||
t.Errorf("third entry ID should be %q (M4), got %q", id4, entryIDs[2])
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNewMessagesAfterCompaction verifies that messages appended
|
||||
// after compaction are correctly included in the context.
|
||||
func TestCompactionWithNewMessagesAfterCompaction(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add initial messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 3 - kept"}}}
|
||||
|
||||
_, _ = tm.AppendMessage(msg1)
|
||||
_, _ = tm.AppendMessage(msg2)
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
|
||||
// Compact, keeping only M3
|
||||
_, _ = tm.AppendCompaction("Summary", id3, 1000, 500, 2, []string{}, []string{})
|
||||
|
||||
// Add a new message after compaction
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 4 - after compaction"}}}
|
||||
_, _ = tm.AppendMessage(msg4)
|
||||
|
||||
// BuildContext should return: [summary] + [M4 (new after compaction)] + [M3 (kept)]
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("expected 3 messages (summary + M4 + M3), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// Verify order: summary, M4 (new), M3 (kept)
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be summary, got %s", messages[0].Role)
|
||||
}
|
||||
if messages[1].Role != fantasy.MessageRoleAssistant {
|
||||
t.Errorf("second message should be assistant (M4), got %s", messages[1].Role)
|
||||
}
|
||||
m4Text := messages[1].Content[0].(fantasy.TextPart).Text
|
||||
if m4Text != "Message 4 - after compaction" {
|
||||
t.Errorf("unexpected M4 text: %s", m4Text)
|
||||
}
|
||||
if messages[2].Role != fantasy.MessageRoleUser {
|
||||
t.Errorf("third message should be user (M3), got %s", messages[2].Role)
|
||||
}
|
||||
|
||||
// Verify that M1 is NOT in the context
|
||||
for i, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Message 1" {
|
||||
t.Errorf("Message 1 (compacted) should not be in context at index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCompactionWithNoKeptMessages verifies compaction when all messages are compacted.
|
||||
func TestCompactionWithNoKeptMessages(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// Add messages that will all be compacted
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Message 1"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Message 2"}}}
|
||||
|
||||
if _, err := tm.AppendMessage(msg1); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
if _, err := tm.AppendMessage(msg2); err != nil {
|
||||
t.Fatalf("failed to append message: %v", err)
|
||||
}
|
||||
|
||||
// Compact with no kept messages (empty firstKeptEntryID)
|
||||
summary := "All messages summarized"
|
||||
compactionID, _ := tm.AppendCompaction(summary, "", 1000, 100, 2, []string{}, []string{})
|
||||
|
||||
// Verify the compaction entry has no parent
|
||||
compactionEntry := tm.GetEntry(compactionID).(*CompactionEntry)
|
||||
if compactionEntry.ParentID != "" {
|
||||
t.Errorf("compaction entry should have no parent, got %q", compactionEntry.ParentID)
|
||||
}
|
||||
|
||||
// BuildContext should return only the summary
|
||||
messages, _, _ := tm.BuildContext()
|
||||
if len(messages) != 1 {
|
||||
t.Fatalf("expected 1 message (summary only), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("message should be system summary, got %s", messages[0].Role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultipleCompactions verifies that multiple compactions work correctly.
|
||||
func TestMultipleCompactions(t *testing.T) {
|
||||
tm := InMemoryTreeSession("/test")
|
||||
|
||||
// First batch of messages
|
||||
msg1 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - User"}}}
|
||||
msg2 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 1 - Assistant"}}}
|
||||
id1, _ := tm.AppendMessage(msg1)
|
||||
id2, _ := tm.AppendMessage(msg2)
|
||||
|
||||
// First compaction
|
||||
_, _ = tm.AppendCompaction("Summary 1", id1, 1000, 500, 1, []string{}, []string{})
|
||||
|
||||
// Second batch
|
||||
msg3 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - User"}}}
|
||||
msg4 := message.Message{Role: message.RoleAssistant, Parts: []message.ContentPart{message.TextContent{Text: "Batch 2 - Assistant"}}}
|
||||
id3, _ := tm.AppendMessage(msg3)
|
||||
id4, _ := tm.AppendMessage(msg4)
|
||||
|
||||
// Second compaction (compacting the first compaction + batch 2)
|
||||
// Note: id3 is the first kept entry, so id3 and id4 should be preserved
|
||||
compactionID2, _ := tm.AppendCompaction("Summary 2", id3, 1000, 500, 3, []string{}, []string{})
|
||||
|
||||
// Verify second compaction has no parent
|
||||
compactionEntry2 := tm.GetEntry(compactionID2).(*CompactionEntry)
|
||||
if compactionEntry2.ParentID != "" {
|
||||
t.Errorf("second compaction entry should have no parent, got %q", compactionEntry2.ParentID)
|
||||
}
|
||||
|
||||
// Add final message
|
||||
msg5 := message.Message{Role: message.RoleUser, Parts: []message.ContentPart{message.TextContent{Text: "Final message"}}}
|
||||
id5, _ := tm.AppendMessage(msg5)
|
||||
|
||||
// BuildContext should include:
|
||||
// - Summary 2 (from second compaction)
|
||||
// - msg5 (final message)
|
||||
// - msg3, msg4 (kept from second compaction)
|
||||
// But NOT Summary 1 or msg1, msg2 (they're before the first kept entry of compaction 2)
|
||||
messages, _, _ := tm.BuildContext()
|
||||
|
||||
// Should have: Summary 2 + msg5 + msg3 + msg4 = 4 messages
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("expected 4 messages (Summary 2 + msg5 + msg3 + msg4), got %d: %+v", len(messages), messages)
|
||||
}
|
||||
|
||||
// First should be Summary 2
|
||||
if messages[0].Role != fantasy.MessageRoleSystem {
|
||||
t.Errorf("first message should be system (Summary 2), got %s", messages[0].Role)
|
||||
}
|
||||
summaryText := messages[0].Content[0].(fantasy.TextPart).Text
|
||||
if summaryText != "[Conversation summary — earlier messages were compacted]\n\nSummary 2" {
|
||||
t.Errorf("unexpected summary: %s", summaryText)
|
||||
}
|
||||
|
||||
// Verify msg5 is included
|
||||
foundFinal := false
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Final message" {
|
||||
foundFinal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundFinal {
|
||||
t.Error("Final message (msg5) should be in context")
|
||||
}
|
||||
|
||||
// Verify msg1, msg2 are NOT included (compacted by first compaction, then second)
|
||||
for _, msg := range messages {
|
||||
if msg.Role == fantasy.MessageRoleUser || msg.Role == fantasy.MessageRoleAssistant {
|
||||
text := msg.Content[0].(fantasy.TextPart).Text
|
||||
if text == "Batch 1 - User" || text == "Batch 1 - Assistant" {
|
||||
t.Errorf("Batch 1 messages should not be in context, found: %s", text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify entry IDs
|
||||
entryIDs := tm.GetContextEntryIDs()
|
||||
if len(entryIDs) != 4 {
|
||||
t.Fatalf("expected 4 entry IDs, got %d: %v", len(entryIDs), entryIDs)
|
||||
}
|
||||
|
||||
// First should be empty (summary)
|
||||
if entryIDs[0] != "" {
|
||||
t.Errorf("first entry ID should be empty (summary), got %q", entryIDs[0])
|
||||
}
|
||||
|
||||
// Check that id5 is in the list
|
||||
if !slices.Contains(entryIDs, id5) {
|
||||
t.Errorf("id5 (final message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id3 and id4 ARE in the list (they were kept)
|
||||
foundID3, foundID4 := false, false
|
||||
for _, id := range entryIDs {
|
||||
if id == id3 {
|
||||
foundID3 = true
|
||||
}
|
||||
if id == id4 {
|
||||
foundID4 = true
|
||||
}
|
||||
}
|
||||
if !foundID3 {
|
||||
t.Errorf("id3 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
if !foundID4 {
|
||||
t.Errorf("id4 (kept message) should be in entry IDs, got %v", entryIDs)
|
||||
}
|
||||
|
||||
// Verify id1 and id2 are NOT in the list (they were compacted away)
|
||||
for _, id := range entryIDs {
|
||||
if id == id1 || id == id2 {
|
||||
t.Errorf("id1 or id2 (compacted) should not be in entry IDs, found %q in %v", id, entryIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
EntryTypeSessionInfo EntryType = "session_info"
|
||||
EntryTypeExtensionData EntryType = "extension_data"
|
||||
EntryTypeCompaction EntryType = "compaction"
|
||||
EntryTypeSystemPrompt EntryType = "system_prompt"
|
||||
)
|
||||
|
||||
// CurrentVersion is the session format version for JSONL tree sessions.
|
||||
@@ -117,6 +118,19 @@ type CompactionEntry struct {
|
||||
ModifiedFiles []string `json:"modified_files,omitempty"`
|
||||
}
|
||||
|
||||
// SystemPromptEntry records the system prompt and model used for the session.
|
||||
// This is primarily for sharing/debugging to see what instructions were
|
||||
// active during the conversation. It does NOT participate in the tree
|
||||
// structure (no ParentID) and is not used when building LLM context.
|
||||
type SystemPromptEntry struct {
|
||||
Type EntryType `json:"type"` // always "system_prompt"
|
||||
ID string `json:"id"` // unique entry ID
|
||||
Timestamp time.Time `json:"timestamp"` // when captured
|
||||
Content string `json:"content"` // the system prompt text
|
||||
Model string `json:"model"` // the model used (e.g., "claude-sonnet-4-5")
|
||||
Provider string `json:"provider"` // the provider used (e.g., "anthropic")
|
||||
}
|
||||
|
||||
// GenerateEntryID creates a unique entry identifier (16 hex chars).
|
||||
func GenerateEntryID() string {
|
||||
bytes := make([]byte, 8)
|
||||
@@ -217,6 +231,18 @@ func NewCompactionEntry(parentID, summary, firstKeptEntryID string, tokensBefore
|
||||
}
|
||||
}
|
||||
|
||||
// NewSystemPromptEntry creates a SystemPromptEntry.
|
||||
func NewSystemPromptEntry(content, model, provider string) *SystemPromptEntry {
|
||||
return &SystemPromptEntry{
|
||||
Type: EntryTypeSystemPrompt,
|
||||
ID: GenerateEntryID(),
|
||||
Timestamp: time.Now(),
|
||||
Content: content,
|
||||
Model: model,
|
||||
Provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// --- JSONL marshaling helpers ---
|
||||
|
||||
// MarshalEntry serializes any entry to a JSON line (no trailing newline).
|
||||
@@ -295,6 +321,13 @@ func UnmarshalEntry(data []byte) (any, error) {
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
case EntryTypeSystemPrompt:
|
||||
var e SystemPromptEntry
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal system_prompt entry: %w", err)
|
||||
}
|
||||
return &e, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown entry type: %q", env.Type)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSystemPromptEntry(t *testing.T) {
|
||||
// Test creation
|
||||
content := "You are a helpful coding assistant."
|
||||
model := "claude-sonnet-4-5"
|
||||
provider := "anthropic"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
if entry.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Expected type %q, got %q", EntryTypeSystemPrompt, entry.Type)
|
||||
}
|
||||
|
||||
if entry.Content != content {
|
||||
t.Errorf("Expected content %q, got %q", content, entry.Content)
|
||||
}
|
||||
|
||||
if entry.Model != model {
|
||||
t.Errorf("Expected model %q, got %q", model, entry.Model)
|
||||
}
|
||||
|
||||
if entry.Provider != provider {
|
||||
t.Errorf("Expected provider %q, got %q", provider, entry.Provider)
|
||||
}
|
||||
|
||||
if entry.ID == "" {
|
||||
t.Error("Expected non-empty ID")
|
||||
}
|
||||
|
||||
// Test marshaling
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Test unmarshaling
|
||||
unmarshaled, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
sysPrompt, ok := unmarshaled.(*SystemPromptEntry)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *SystemPromptEntry, got %T", unmarshaled)
|
||||
}
|
||||
|
||||
if sysPrompt.Type != EntryTypeSystemPrompt {
|
||||
t.Errorf("Unmarshaled: expected type %q, got %q", EntryTypeSystemPrompt, sysPrompt.Type)
|
||||
}
|
||||
|
||||
if sysPrompt.Content != content {
|
||||
t.Errorf("Unmarshaled: expected content %q, got %q", content, sysPrompt.Content)
|
||||
}
|
||||
|
||||
if sysPrompt.Model != model {
|
||||
t.Errorf("Unmarshaled: expected model %q, got %q", model, sysPrompt.Model)
|
||||
}
|
||||
|
||||
if sysPrompt.Provider != provider {
|
||||
t.Errorf("Unmarshaled: expected provider %q, got %q", provider, sysPrompt.Provider)
|
||||
}
|
||||
|
||||
if sysPrompt.ID != entry.ID {
|
||||
t.Errorf("Unmarshaled: expected ID %q, got %q", entry.ID, sysPrompt.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemPromptEntryJSONStructure(t *testing.T) {
|
||||
content := "Test system prompt content"
|
||||
model := "gpt-4o"
|
||||
provider := "openai"
|
||||
entry := NewSystemPromptEntry(content, model, provider)
|
||||
|
||||
data, err := MarshalEntry(entry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("Failed to unmarshal to raw map: %v", err)
|
||||
}
|
||||
|
||||
if raw["type"] != "system_prompt" {
|
||||
t.Errorf("Expected type 'system_prompt', got %v", raw["type"])
|
||||
}
|
||||
|
||||
if raw["content"] != content {
|
||||
t.Errorf("Expected content %q, got %v", content, raw["content"])
|
||||
}
|
||||
|
||||
if raw["model"] != model {
|
||||
t.Errorf("Expected model %q, got %v", model, raw["model"])
|
||||
}
|
||||
|
||||
if raw["provider"] != provider {
|
||||
t.Errorf("Expected provider %q, got %v", provider, raw["provider"])
|
||||
}
|
||||
|
||||
if raw["id"] == "" || raw["id"] == nil {
|
||||
t.Error("Expected non-empty id field")
|
||||
}
|
||||
|
||||
if raw["timestamp"] == "" || raw["timestamp"] == nil {
|
||||
t.Error("Expected non-empty timestamp field")
|
||||
}
|
||||
}
|
||||
@@ -114,6 +114,187 @@ func CreateTreeSession(cwd string) (*TreeManager, error) {
|
||||
return tm, nil
|
||||
}
|
||||
|
||||
// ForkToNewSession creates a new session file containing the history up to and
|
||||
// including the target entry ID. This matches Pi's /fork behavior: it creates
|
||||
// a completely new session file with a parent_session reference, copying all
|
||||
// entries from the root to the target point.
|
||||
func (tm *TreeManager) ForkToNewSession(cwd string, targetID string) (*TreeManager, error) {
|
||||
tm.mu.RLock()
|
||||
defer tm.mu.RUnlock()
|
||||
|
||||
// Get the branch from root to target (root-to-leaf order).
|
||||
branch := tm.getBranchLocked(targetID)
|
||||
if len(branch) == 0 {
|
||||
return nil, fmt.Errorf("target entry %q not found", targetID)
|
||||
}
|
||||
|
||||
// Create a new session file.
|
||||
newTm, err := CreateTreeSession(cwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set the parent session reference in the header.
|
||||
newTm.header.ParentSession = tm.filePath
|
||||
newTm.header.ParentSessionID = tm.header.ID
|
||||
|
||||
// Rewrite the header with the parent reference.
|
||||
// We need to close and recreate the file to rewrite the header.
|
||||
if err := newTm.file.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close new session file: %w", err)
|
||||
}
|
||||
|
||||
// Recreate the file and write the updated header.
|
||||
f, err := os.Create(newTm.filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to recreate session file: %w", err)
|
||||
}
|
||||
newTm.file = f
|
||||
|
||||
if err := newTm.writeEntry(&newTm.header); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to write session header: %w", err)
|
||||
}
|
||||
|
||||
// Copy entries from the branch to the new session.
|
||||
// We need to remap IDs since the new session is independent.
|
||||
idMap := make(map[string]string) // old ID -> new ID
|
||||
var prevNewID string
|
||||
|
||||
for _, entry := range branch {
|
||||
oldID := tm.EntryID(entry)
|
||||
newID := GenerateEntryID()
|
||||
idMap[oldID] = newID
|
||||
|
||||
// Create a copy of the entry with the new ID and remapped parent.
|
||||
var newEntry any
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
newEntry = &MessageEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeMessage,
|
||||
ID: newID,
|
||||
ParentID: prevNewID, // Chain sequentially in new session
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Role: e.Role,
|
||||
Parts: e.Parts,
|
||||
Model: e.Model,
|
||||
Provider: e.Provider,
|
||||
}
|
||||
// Copy label if present.
|
||||
if label, ok := tm.labels[oldID]; ok {
|
||||
newTm.labels[newID] = label
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
newEntry = &ModelChangeEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeModelChange,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Provider: e.Provider,
|
||||
ModelID: e.ModelID,
|
||||
}
|
||||
|
||||
case *LabelEntry:
|
||||
// Remap the target ID if it's in our copied branch.
|
||||
newTargetID := e.TargetID
|
||||
if mapped, ok := idMap[e.TargetID]; ok {
|
||||
newTargetID = mapped
|
||||
}
|
||||
newEntry = &LabelEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeLabel,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
TargetID: newTargetID,
|
||||
Label: e.Label,
|
||||
}
|
||||
|
||||
case *SessionInfoEntry:
|
||||
newEntry = &SessionInfoEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeSessionInfo,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Name: e.Name,
|
||||
}
|
||||
newTm.sessionName = e.Name
|
||||
|
||||
case *ExtensionDataEntry:
|
||||
newEntry = &ExtensionDataEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeExtensionData,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
ExtType: e.ExtType,
|
||||
Data: e.Data,
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Remap the from ID if it's in our copied branch.
|
||||
newFromID := e.FromID
|
||||
if mapped, ok := idMap[e.FromID]; ok {
|
||||
newFromID = mapped
|
||||
}
|
||||
newEntry = &BranchSummaryEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeBranchSummary,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
FromID: newFromID,
|
||||
Summary: e.Summary,
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
// Remap the first kept entry ID if it's in our copied branch.
|
||||
newFirstKeptID := e.FirstKeptEntryID
|
||||
if mapped, ok := idMap[e.FirstKeptEntryID]; ok {
|
||||
newFirstKeptID = mapped
|
||||
}
|
||||
newEntry = &CompactionEntry{
|
||||
Entry: Entry{
|
||||
Type: EntryTypeCompaction,
|
||||
ID: newID,
|
||||
ParentID: prevNewID,
|
||||
Timestamp: e.Timestamp,
|
||||
},
|
||||
Summary: e.Summary,
|
||||
FirstKeptEntryID: newFirstKeptID,
|
||||
TokensBefore: e.TokensBefore,
|
||||
TokensAfter: e.TokensAfter,
|
||||
MessagesRemoved: e.MessagesRemoved,
|
||||
ReadFiles: e.ReadFiles,
|
||||
ModifiedFiles: e.ModifiedFiles,
|
||||
}
|
||||
}
|
||||
|
||||
if newEntry != nil {
|
||||
if err := newTm.appendAndPersist(newEntry); err != nil {
|
||||
_ = f.Close()
|
||||
return nil, fmt.Errorf("failed to copy entry: %w", err)
|
||||
}
|
||||
prevNewID = newID
|
||||
}
|
||||
}
|
||||
|
||||
// Set the leaf to the last entry in the new session.
|
||||
newTm.leafID = prevNewID
|
||||
|
||||
return newTm, nil
|
||||
}
|
||||
|
||||
// OpenTreeSession opens an existing JSONL session file.
|
||||
func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
@@ -181,7 +362,7 @@ func OpenTreeSession(path string) (*TreeManager, error) {
|
||||
|
||||
// Set leaf to the last entry.
|
||||
if len(tm.entries) > 0 {
|
||||
tm.leafID = tm.entryID(tm.entries[len(tm.entries)-1])
|
||||
tm.leafID = tm.EntryID(tm.entries[len(tm.entries)-1])
|
||||
}
|
||||
|
||||
// Open file for appending.
|
||||
@@ -242,9 +423,14 @@ func (tm *TreeManager) AppendMessage(msg message.Message) (string, error) {
|
||||
return entry.ID, nil
|
||||
}
|
||||
|
||||
// AppendFantasyMessage converts a fantasy.Message and appends it.
|
||||
// AppendLLMMessage converts an LLM message and appends it.
|
||||
func (tm *TreeManager) AppendLLMMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromLLMMessage(msg))
|
||||
}
|
||||
|
||||
// Deprecated: Use AppendLLMMessage instead.
|
||||
func (tm *TreeManager) AppendFantasyMessage(msg fantasy.Message) (string, error) {
|
||||
return tm.AppendMessage(message.FromFantasyMessage(msg))
|
||||
return tm.AppendLLMMessage(msg)
|
||||
}
|
||||
|
||||
// AppendModelChange records a model/provider change.
|
||||
@@ -323,11 +509,19 @@ func (tm *TreeManager) AppendExtensionData(extType, data string) (string, error)
|
||||
// AppendCompaction adds a compaction entry to the tree. The entry records
|
||||
// the summary and the ID of the first entry that should be preserved in the
|
||||
// LLM context. Messages before that entry are replaced by the summary.
|
||||
//
|
||||
// The compaction entry becomes a new "root" for the post-compaction branch
|
||||
// with no parent (empty ParentID). This breaks the parent chain so that old
|
||||
// compacted messages are no longer traversed when building context. The kept
|
||||
// messages are explicitly collected via FirstKeptEntryID in BuildContext.
|
||||
func (tm *TreeManager) AppendCompaction(summary, firstKeptEntryID string, tokensBefore, tokensAfter, messagesRemoved int, readFiles, modifiedFiles []string) (string, error) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
|
||||
entry := NewCompactionEntry(tm.leafID, summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
// The compaction entry has no parent, making it a new "root" for the
|
||||
// post-compaction branch. This ensures old compacted messages are not
|
||||
// traversed when walking from the current leaf.
|
||||
entry := NewCompactionEntry("", summary, firstKeptEntryID, tokensBefore, tokensAfter, messagesRemoved, readFiles, modifiedFiles)
|
||||
if err := tm.appendAndPersist(entry); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -497,14 +691,18 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
// Find the last compaction entry on this branch — it determines
|
||||
// which older messages are replaced by the summary.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If there is a compaction, inject the summary first.
|
||||
// If there is a compaction, inject the summary first and collect
|
||||
// the kept messages starting from FirstKeptEntryID (since the
|
||||
// compaction entry's parent chain doesn't include them).
|
||||
if lastCompaction != nil {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -514,28 +712,111 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Determine whether to skip entries (everything before firstKeptEntryID).
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
// Once we reach the first kept entry, stop skipping.
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
// Collect entries from the compaction entry itself (at compactionIndex)
|
||||
// and any entries before it in the branch (newer messages).
|
||||
for i := compactionIndex; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
// Convert branch summary to a user message for context.
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (summary injected).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Now collect the kept messages starting from FirstKeptEntryID.
|
||||
// These are not in the current branch because the compaction entry
|
||||
// is parented to the first kept entry's parent, not the first kept entry.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting
|
||||
// messages that were added after the compaction.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
// Messages after the compaction are collected from the branch walk above.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
// Process this kept entry.
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
messages = append(messages, fantasy.Message{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{
|
||||
fantasy.TextPart{
|
||||
Text: fmt.Sprintf("[Branch context: %s]", e.Summary),
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages, provider, modelID
|
||||
}
|
||||
|
||||
// No compaction - process the entire branch normally.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue // skip malformed entries
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
messages = append(messages, msgs...)
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
@@ -554,10 +835,6 @@ func (tm *TreeManager) BuildContext() (messages []fantasy.Message, provider stri
|
||||
case *ModelChangeEntry:
|
||||
provider = e.Provider
|
||||
modelID = e.ModelID
|
||||
|
||||
case *CompactionEntry:
|
||||
// Already handled above (the last one on the branch).
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -667,38 +944,99 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
|
||||
// Find the last compaction entry for skip logic.
|
||||
var lastCompaction *CompactionEntry
|
||||
var compactionIndex = -1
|
||||
for i := len(branch) - 1; i >= 0; i-- {
|
||||
if c, ok := branch[i].(*CompactionEntry); ok {
|
||||
lastCompaction = c
|
||||
compactionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var ids []string
|
||||
|
||||
// If there's a compaction summary injected, it has no entry ID.
|
||||
// If there's a compaction, we need to collect IDs from:
|
||||
// 1. Entries after the compaction entry in the branch (newer messages)
|
||||
// 2. Entries from FirstKeptEntryID onwards (kept messages)
|
||||
if lastCompaction != nil {
|
||||
ids = append(ids, "") // placeholder for the summary system message
|
||||
}
|
||||
// Placeholder for the summary system message (no entry ID).
|
||||
ids = append(ids, "")
|
||||
|
||||
skipping := lastCompaction != nil
|
||||
for _, entry := range branch {
|
||||
if skipping {
|
||||
entryID := tm.entryID(entry)
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
skipping = false
|
||||
} else {
|
||||
continue
|
||||
// Collect IDs from entries after the compaction entry (newer messages).
|
||||
for i := compactionIndex + 1; i < len(branch); i++ {
|
||||
entry := branch[i]
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect IDs from the kept messages starting at FirstKeptEntryID.
|
||||
// We iterate through entries in order (not using getBranchLocked) to avoid
|
||||
// walking back to old compacted messages.
|
||||
// We stop when we reach the compaction entry to avoid double-counting.
|
||||
if lastCompaction.FirstKeptEntryID != "" {
|
||||
found := false
|
||||
for _, entry := range tm.entries {
|
||||
entryID := tm.EntryID(entry)
|
||||
|
||||
// Skip entries until we reach the first kept entry.
|
||||
if !found {
|
||||
if entryID == lastCompaction.FirstKeptEntryID {
|
||||
found = true
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Stop when we reach the compaction entry itself.
|
||||
if entryID == lastCompaction.ID {
|
||||
break
|
||||
}
|
||||
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *BranchSummaryEntry:
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// No compaction - collect IDs from the entire branch.
|
||||
for _, entry := range branch {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
msg, err := e.ToMessage()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
msgs := msg.ToFantasyMessages()
|
||||
msgs := msg.ToLLMMessages()
|
||||
for range msgs {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
@@ -707,9 +1045,6 @@ func (tm *TreeManager) GetContextEntryIDs() []string {
|
||||
if e.Summary != "" {
|
||||
ids = append(ids, e.ID)
|
||||
}
|
||||
|
||||
case *CompactionEntry:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -737,31 +1072,41 @@ func (tm *TreeManager) GetLastCompaction() *CompactionEntry {
|
||||
|
||||
// --- Legacy bridge ---
|
||||
|
||||
// AddFantasyMessages appends multiple fantasy messages as entries. This is
|
||||
// AddLLMMessages appends multiple LLM messages as entries. This is
|
||||
// used when syncing from the agent's ConversationMessages after a step.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
func (tm *TreeManager) AddLLMMessages(msgs []fantasy.Message) error {
|
||||
for _, msg := range msgs {
|
||||
if _, err := tm.AppendFantasyMessage(msg); err != nil {
|
||||
if _, err := tm.AppendLLMMessage(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFantasyMessages builds the context and returns just the messages.
|
||||
// Deprecated: Use AddLLMMessages instead.
|
||||
func (tm *TreeManager) AddFantasyMessages(msgs []fantasy.Message) error {
|
||||
return tm.AddLLMMessages(msgs)
|
||||
}
|
||||
|
||||
// GetLLMMessages builds the context and returns just the messages.
|
||||
// This satisfies the same conceptual role as the old Manager.GetMessages().
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
func (tm *TreeManager) GetLLMMessages() []fantasy.Message {
|
||||
msgs, _, _ := tm.BuildContext()
|
||||
return msgs
|
||||
}
|
||||
|
||||
// Deprecated: Use GetLLMMessages instead.
|
||||
func (tm *TreeManager) GetFantasyMessages() []fantasy.Message {
|
||||
return tm.GetLLMMessages()
|
||||
}
|
||||
|
||||
// --- Internal helpers ---
|
||||
|
||||
// addEntryToIndex adds an entry to the in-memory indices.
|
||||
func (tm *TreeManager) addEntryToIndex(entry any) {
|
||||
tm.entries = append(tm.entries, entry)
|
||||
|
||||
id := tm.entryID(entry)
|
||||
id := tm.EntryID(entry)
|
||||
parentID := tm.entryParentID(entry)
|
||||
|
||||
if id != "" {
|
||||
@@ -798,8 +1143,8 @@ func (tm *TreeManager) writeEntry(entry any) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// entryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) entryID(entry any) string {
|
||||
// EntryID extracts the ID from any entry type.
|
||||
func (tm *TreeManager) EntryID(entry any) string {
|
||||
switch e := entry.(type) {
|
||||
case *MessageEntry:
|
||||
return e.ID
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// ConnectionPoolConfig defines configuration parameters for the MCP connection pool.
|
||||
@@ -60,34 +60,38 @@ type MCPConnection struct {
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
// If config is nil, default configuration values will be used. The pool starts a background
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
tokenStoreFactory: tokenStoreFactory,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
pool.oauthFlow = NewOAuthFlowRunner(authHandler)
|
||||
}
|
||||
|
||||
go pool.startHealthCheck()
|
||||
@@ -103,6 +107,15 @@ func (p *MCPConnectionPool) SetDebugLogger(logger DebugLogger) {
|
||||
p.debugLogger = logger
|
||||
}
|
||||
|
||||
// SetOAuthFlow sets the OAuth flow runner for the connection pool.
|
||||
// When set, the pool can trigger OAuth re-authorization when a tool call fails
|
||||
// with an OAuth error (e.g. expired token). Thread-safe and can be called at any time.
|
||||
func (p *MCPConnectionPool) SetOAuthFlow(flow *OAuthFlowRunner) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oauthFlow = flow
|
||||
}
|
||||
|
||||
// GetConnection retrieves or creates a connection for the specified MCP server.
|
||||
// If a healthy, non-idle connection exists in the pool, it will be reused.
|
||||
// Otherwise, a new connection is created and added to the pool.
|
||||
@@ -127,9 +140,7 @@ func (p *MCPConnectionPool) GetConnection(ctx context.Context, serverName string
|
||||
return conn, nil
|
||||
} else {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Connection %s unhealthy, removing", serverName))
|
||||
}
|
||||
_ = conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
@@ -232,18 +243,43 @@ func (p *MCPConnectionPool) performHealthCheck(ctx context.Context, conn *MCPCon
|
||||
|
||||
// createConnection creates a new connection
|
||||
func (p *MCPConnectionPool) createConnection(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (*MCPConnection, error) {
|
||||
client, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
mcpClient, err := p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// SSE transport can return OAuth error during Start()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry after successful auth
|
||||
mcpClient, err = p.createMCPClient(ctx, serverName, serverConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.initializeClient(ctx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
// Streamable HTTP transport returns OAuth error during Initialize()
|
||||
if p.oauthFlow != nil && IsOAuthError(err) {
|
||||
if flowErr := p.oauthFlow.RunAuthFlow(ctx, serverName, err); flowErr != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, fmt.Errorf("OAuth authorization failed: %w", flowErr)
|
||||
}
|
||||
// Retry initialization after successful auth
|
||||
if err := p.initializeClient(ctx, mcpClient); err != nil {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
_ = mcpClient.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
client: mcpClient,
|
||||
serverName: serverName,
|
||||
serverConfig: serverConfig,
|
||||
lastUsed: time.Now(),
|
||||
@@ -269,6 +305,8 @@ func (p *MCPConnectionPool) createMCPClient(ctx context.Context, serverName stri
|
||||
return p.createSSEClient(ctx, serverConfig)
|
||||
case "streamable":
|
||||
return p.createStreamableClient(ctx, serverConfig)
|
||||
case "inprocess":
|
||||
return p.createInProcessClient(serverConfig)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
|
||||
}
|
||||
@@ -325,13 +363,39 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := sseClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start SSE client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start SSE client: %w", err)
|
||||
}
|
||||
|
||||
return sseClient, nil
|
||||
@@ -356,18 +420,70 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := streamableClient.Start(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
|
||||
return nil, fmt.Errorf("failed to start streamable HTTP client: %w", err)
|
||||
}
|
||||
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createInProcessClient creates an in-process MCP client that communicates
|
||||
// directly with an *server.MCPServer in the same process. No subprocess is
|
||||
// spawned and no network I/O occurs — calls go through JSON marshal →
|
||||
// MCPServer.HandleMessage → JSON unmarshal, all in-memory.
|
||||
func (p *MCPConnectionPool) createInProcessClient(serverConfig config.MCPServerConfig) (client.MCPClient, error) {
|
||||
srv, ok := serverConfig.InProcessServer.(*server.MCPServer)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("InProcessServer must be *server.MCPServer, got %T", serverConfig.InProcessServer)
|
||||
}
|
||||
inProcessClient, err := client.NewInProcessClient(srv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create in-process client: %w", err)
|
||||
}
|
||||
return inProcessClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
|
||||
if p.tokenStoreFactory != nil {
|
||||
return p.tokenStoreFactory(serverURL)
|
||||
}
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
@@ -383,7 +499,7 @@ func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.
|
||||
|
||||
_, err := client.Initialize(initCtx, initRequest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialization timeout or failed: %v", err)
|
||||
return fmt.Errorf("initialization timeout or failed: %w", err)
|
||||
}
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
@@ -514,6 +630,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// RemoveConnection closes and removes a single connection from the pool.
|
||||
// Returns an error if the connection does not exist or if closing fails.
|
||||
// Thread-safe for concurrent use.
|
||||
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conn, exists := p.connections[serverName]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection %q not found in pool", serverName)
|
||||
}
|
||||
|
||||
err := conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
@@ -541,6 +678,9 @@ func (p *MCPConnectionPool) Close() error {
|
||||
|
||||
// isConnectionError checks if the error is connection-related
|
||||
func isConnectionError(err error) bool {
|
||||
if IsOAuthError(err) {
|
||||
return false // OAuth errors are recoverable, not connection failures
|
||||
}
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "Connection not found") ||
|
||||
strings.Contains(errStr, "transport error") ||
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
// mcpFantasyTool adapts an MCP tool to the fantasy.AgentTool interface.
|
||||
// It bridges the MCP tool protocol with fantasy's agent tool system, handling
|
||||
// name prefixing, schema conversion, connection pooling, and result marshaling.
|
||||
type mcpFantasyTool struct {
|
||||
toolInfo fantasy.ToolInfo
|
||||
mapping *toolMapping
|
||||
providerOptions fantasy.ProviderOptions
|
||||
}
|
||||
|
||||
// Info returns the fantasy tool info including name, description, and parameter schema.
|
||||
func (t *mcpFantasyTool) Info() fantasy.ToolInfo {
|
||||
return t.toolInfo
|
||||
}
|
||||
|
||||
// Run executes the MCP tool by routing through the connection pool.
|
||||
// It maps the prefixed tool name back to the original name, retrieves a healthy
|
||||
// connection, invokes the tool, and converts the MCP result to a fantasy ToolResponse.
|
||||
func (t *mcpFantasyTool) Run(ctx context.Context, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
||||
// Parse and validate JSON arguments
|
||||
var arguments any
|
||||
input := call.Input
|
||||
if input == "" || input == "{}" {
|
||||
arguments = nil
|
||||
} else {
|
||||
var temp any
|
||||
if err := json.Unmarshal([]byte(input), &temp); err != nil {
|
||||
return fantasy.NewTextErrorResponse(fmt.Sprintf("invalid JSON arguments: %v", err)), nil
|
||||
}
|
||||
arguments = json.RawMessage(input)
|
||||
}
|
||||
|
||||
// Get connection from pool with health check
|
||||
conn, err := t.mapping.manager.connectionPool.GetConnectionWithHealthCheck(
|
||||
ctx, t.mapping.serverName, t.mapping.serverConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to get healthy connection from pool: %w", err)
|
||||
}
|
||||
|
||||
// Call the MCP tool using the original (unprefixed) name
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: t.mapping.originalName,
|
||||
Arguments: arguments,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
// Mark connection as unhealthy for automatic recovery
|
||||
t.mapping.manager.connectionPool.HandleConnectionError(t.mapping.serverName, err)
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to call mcp tool: %w", err)
|
||||
}
|
||||
|
||||
// Marshal the MCP result to JSON string
|
||||
marshaledResult, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fantasy.ToolResponse{}, fmt.Errorf("failed to marshal mcp tool result: %w", err)
|
||||
}
|
||||
|
||||
// Return as text response, preserving error status from MCP
|
||||
if result.IsError {
|
||||
return fantasy.NewTextErrorResponse(string(marshaledResult)), nil
|
||||
}
|
||||
return fantasy.NewTextResponse(string(marshaledResult)), nil
|
||||
}
|
||||
|
||||
// ProviderOptions returns provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) ProviderOptions() fantasy.ProviderOptions {
|
||||
return t.providerOptions
|
||||
}
|
||||
|
||||
// SetProviderOptions sets provider-specific options for this tool.
|
||||
func (t *mcpFantasyTool) SetProviderOptions(opts fantasy.ProviderOptions) {
|
||||
t.providerOptions = opts
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTestInProcessServer creates a simple MCP server with one tool for testing.
|
||||
func newTestInProcessServer() *server.MCPServer {
|
||||
srv := server.NewMCPServer("test-server", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
srv.AddTool(
|
||||
mcp.NewTool("greet",
|
||||
mcp.WithDescription("Say hello"),
|
||||
mcp.WithString("name", mcp.Required(), mcp.Description("Name to greet")),
|
||||
),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
name, _ := req.GetArguments()["name"].(string)
|
||||
return mcp.NewToolResultText("Hello, " + name + "!"), nil
|
||||
},
|
||||
)
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestInProcessTransportType(t *testing.T) {
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessTransportTypeInferred(t *testing.T) {
|
||||
// When Type is empty but InProcessServer is set, infer "inprocess".
|
||||
cfg := config.MCPServerConfig{
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
}
|
||||
if got := cfg.GetTransportType(); got != "inprocess" {
|
||||
t.Errorf("GetTransportType() = %q, want %q", got, "inprocess")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInProcessValidation(t *testing.T) {
|
||||
// Valid: InProcessServer is set.
|
||||
validCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
InProcessServer: newTestInProcessServer(),
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := validCfg.Validate(); err != nil {
|
||||
t.Errorf("expected valid config, got error: %v", err)
|
||||
}
|
||||
|
||||
// Invalid: type is inprocess but InProcessServer is nil.
|
||||
invalidCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test": {
|
||||
Type: "inprocess",
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := invalidCfg.Validate(); err == nil {
|
||||
t.Error("expected validation error for nil InProcessServer, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessClient(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the connection is healthy and functional.
|
||||
if !conn.isHealthy {
|
||||
t.Error("expected connection to be healthy")
|
||||
}
|
||||
|
||||
// List tools to verify the connection works end-to-end.
|
||||
toolsResp, err := conn.client.ListTools(ctx, mcp.ListToolsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListTools failed: %v", err)
|
||||
}
|
||||
if len(toolsResp.Tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(toolsResp.Tools))
|
||||
}
|
||||
if toolsResp.Tools[0].Name != "greet" {
|
||||
t.Errorf("expected tool name 'greet', got %q", toolsResp.Tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessToolExecution(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
conn, err := pool.GetConnection(ctx, "test-inproc", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("GetConnection failed: %v", err)
|
||||
}
|
||||
|
||||
// Call the tool.
|
||||
result, err := conn.client.CallTool(ctx, mcp.CallToolRequest{
|
||||
Request: mcp.Request{Method: "tools/call"},
|
||||
Params: mcp.CallToolParams{
|
||||
Name: "greet",
|
||||
Arguments: map[string]any{"name": "World"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CallTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if len(result.Content) == 0 {
|
||||
t.Fatal("expected at least one content block")
|
||||
}
|
||||
text, ok := result.Content[0].(mcp.TextContent)
|
||||
if !ok {
|
||||
t.Fatalf("expected TextContent, got %T", result.Content[0])
|
||||
}
|
||||
if text.Text != "Hello, World!" {
|
||||
t.Errorf("expected 'Hello, World!', got %q", text.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPToolManagerInProcess(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
|
||||
mgr := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
count, err := mgr.AddServer(ctx, "myserver", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", count)
|
||||
}
|
||||
|
||||
tools := mgr.GetTools()
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
if tools[0].Name != "myserver__greet" {
|
||||
t.Errorf("expected tool name 'myserver__greet', got %q", tools[0].Name)
|
||||
}
|
||||
|
||||
// Execute the tool.
|
||||
input, _ := json.Marshal(map[string]any{"name": "SDK"})
|
||||
result, err := mgr.ExecuteTool(ctx, "myserver__greet", string(input))
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteTool failed: %v", err)
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("expected non-error result")
|
||||
}
|
||||
if result.Content == "" {
|
||||
t.Error("expected non-empty result content")
|
||||
}
|
||||
|
||||
// Verify result contains our greeting.
|
||||
if !strings.Contains(result.Content, "Hello, SDK!") {
|
||||
t.Errorf("expected 'Hello, SDK!' in result, got %q", result.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessInvalidServer(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Pass a non-*server.MCPServer value.
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: "not a server",
|
||||
}
|
||||
|
||||
_, err := pool.GetConnection(ctx, "bad", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid InProcessServer type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionPoolInProcessReuse(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
ctx := context.Background()
|
||||
srv := newTestInProcessServer()
|
||||
cfg := config.MCPServerConfig{
|
||||
Type: "inprocess",
|
||||
InProcessServer: srv,
|
||||
}
|
||||
|
||||
// Get connection twice — should reuse.
|
||||
conn1, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("first GetConnection failed: %v", err)
|
||||
}
|
||||
conn2, err := pool.GetConnection(ctx, "reuse-test", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("second GetConnection failed: %v", err)
|
||||
}
|
||||
if conn1 != conn2 {
|
||||
t.Error("expected same connection object on reuse")
|
||||
}
|
||||
}
|
||||
+901
-59
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,323 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// testdataDir returns the absolute path to the testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_Integration tests adding a real MCP server
|
||||
// at runtime and verifying tools are loaded.
|
||||
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Track callbacks.
|
||||
var mu sync.Mutex
|
||||
var loadedServer string
|
||||
var loadedCount int
|
||||
toolsChangedCount := 0
|
||||
|
||||
manager.SetOnServerLoaded(func(name string, count int, err error) {
|
||||
mu.Lock()
|
||||
loadedServer = name
|
||||
loadedCount = count
|
||||
mu.Unlock()
|
||||
})
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Add the server.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools from echo server, got %d", count)
|
||||
}
|
||||
|
||||
// Verify callbacks fired.
|
||||
mu.Lock()
|
||||
if loadedServer != "echo" {
|
||||
t.Errorf("Expected onServerLoaded for 'echo', got %q", loadedServer)
|
||||
}
|
||||
if loadedCount != 2 {
|
||||
t.Errorf("Expected onServerLoaded count=2, got %d", loadedCount)
|
||||
}
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify tools are accessible.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify tool names are prefixed.
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo'")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet'")
|
||||
}
|
||||
|
||||
// Verify server appears in loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
if !slices.Contains(names, "echo") {
|
||||
t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
|
||||
// and verifying tools are cleaned up.
|
||||
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server first.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
toolsChangedCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Remove the server.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools are gone.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removal, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify callback fired.
|
||||
mu.Lock()
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify server is gone from loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
t.Error("Server 'echo' should not appear in loaded names after removal")
|
||||
}
|
||||
}
|
||||
|
||||
// Removing again should error.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error removing already-removed server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddRemoveMultiple_Integration tests adding and removing
|
||||
// multiple servers, verifying tool isolation.
|
||||
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add two servers with the same binary but different names.
|
||||
count1, err := manager.AddServer(ctx, "server-a", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-a failed: %v", err)
|
||||
}
|
||||
count2, err := manager.AddServer(ctx, "server-b", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
totalTools := count1 + count2
|
||||
if totalTools != 4 {
|
||||
t.Fatalf("Expected 4 total tools (2+2), got %d", totalTools)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 4 {
|
||||
t.Fatalf("Expected 4 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remove server-a, verify server-b tools remain.
|
||||
err = manager.RemoveServer("server-a")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-a failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools after removing server-a, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remaining tools should all be from server-b.
|
||||
for _, tool := range tools {
|
||||
if !strings.HasPrefix(tool.Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove server-b.
|
||||
err = manager.RemoveServer("server-b")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removing all servers, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateDetection_Integration tests that
|
||||
// adding a server with the same name as an already loaded server errors.
|
||||
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to add again with the same name.
|
||||
_, err = manager.AddServer(ctx, "echo", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error adding duplicate server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already loaded") {
|
||||
t.Errorf("Expected 'already loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddAfterRemove_Integration tests that a server can be
|
||||
// re-added after being removed.
|
||||
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add, remove, re-add.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("Expected 2 tools after re-add, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateName verifies that adding a server
|
||||
// with a name that already exists returns an error.
|
||||
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// First add will fail (bad command), but let's test the duplicate detection
|
||||
// by simulating a loaded server via LoadTools first.
|
||||
loadCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test-server": cfg,
|
||||
},
|
||||
}
|
||||
// This will fail to load but creates the connection pool.
|
||||
_ = manager.LoadTools(ctx, loadCfg)
|
||||
|
||||
// Now try to add the same server name — the tools didn't load (bad command),
|
||||
// so AddServer should not find a duplicate and should fail with connection error.
|
||||
_, err := manager.AddServer(ctx, "test-server", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding server with bad command, got nil")
|
||||
}
|
||||
// It should be a connection error, not a duplicate error.
|
||||
if strings.Contains(err.Error(), "already loaded") {
|
||||
t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_NotLoaded verifies that removing a server
|
||||
// that doesn't exist returns an appropriate error.
|
||||
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
err := manager.RemoveServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when removing non-existent server, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
|
||||
// lazily creates a connection pool when LoadTools was never called.
|
||||
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Connection pool should be nil initially.
|
||||
if manager.connectionPool != nil {
|
||||
t.Fatal("Expected nil connection pool before any operation")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// AddServer with a bad command — should fail, but the pool should be created.
|
||||
_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for bad command")
|
||||
}
|
||||
|
||||
// Connection pool should have been created.
|
||||
if manager.connectionPool == nil {
|
||||
t.Fatal("Expected connection pool to be created lazily by AddServer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_OnToolsChanged_Callback verifies that the onToolsChanged
|
||||
// callback fires on RemoveServer (we can't easily test AddServer with a real
|
||||
// MCP server, but we can test the callback wiring).
|
||||
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
callCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// RemoveServer on non-existent should NOT fire callback.
|
||||
_ = manager.RemoveServer("nonexistent")
|
||||
|
||||
mu.Lock()
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected 0 callback calls for failed remove, got %d", callCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestMCPToolManager_Close_NilPool verifies Close is safe when the connection
|
||||
// pool was never initialized.
|
||||
func TestMCPToolManager_Close_NilPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
err := manager.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil error from Close with nil pool, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
|
||||
// non-existent connection returns an error.
|
||||
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
err := pool.RemoveConnection("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for non-existent connection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
|
||||
// ensureConnectionPool doesn't recreate an existing pool.
|
||||
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// First call creates the pool.
|
||||
manager.ensureConnectionPool()
|
||||
pool1 := manager.connectionPool
|
||||
if pool1 == nil {
|
||||
t.Fatal("Expected pool to be created")
|
||||
}
|
||||
|
||||
// Second call should be a no-op.
|
||||
manager.ensureConnectionPool()
|
||||
pool2 := manager.connectionPool
|
||||
if pool1 != pool2 {
|
||||
t.Fatal("Expected ensureConnectionPool to be idempotent")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,691 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
mcpclient "github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
// newTestPromptServer creates an in-process MCP server with prompt capabilities
|
||||
// and the specified prompts + handlers. Returns an initialized MCPClient.
|
||||
func newTestPromptServer(t *testing.T, prompts ...server.ServerPrompt) mcpclient.MCPClient {
|
||||
t.Helper()
|
||||
|
||||
mcpServer := server.NewMCPServer(
|
||||
"test-prompt-server", "1.0.0",
|
||||
server.WithPromptCapabilities(true),
|
||||
server.WithToolCapabilities(true),
|
||||
)
|
||||
|
||||
if len(prompts) > 0 {
|
||||
mcpServer.AddPrompts(prompts...)
|
||||
}
|
||||
|
||||
// Add a dummy tool so loadServerTools has something to list.
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool("noop", mcp.WithDescription("no-op tool")),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
)
|
||||
|
||||
client, err := mcpclient.NewInProcessClient(mcpServer)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInProcessClient: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if err := client.Start(ctx); err != nil {
|
||||
t.Fatalf("client.Start: %v", err)
|
||||
}
|
||||
|
||||
initReq := mcp.InitializeRequest{}
|
||||
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
|
||||
if _, err := client.Initialize(ctx, initReq); err != nil {
|
||||
t.Fatalf("client.Initialize: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
return client
|
||||
}
|
||||
|
||||
// injectClientIntoManager sets up an MCPToolManager with a pre-connected
|
||||
// in-process client, bypassing the normal connection pool flow.
|
||||
func injectClientIntoManager(t *testing.T, serverName string, client mcpclient.MCPClient) *MCPToolManager {
|
||||
t.Helper()
|
||||
|
||||
m := NewMCPToolManager()
|
||||
|
||||
// Create a minimal connection pool and inject our client.
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
pool.mu.Lock()
|
||||
pool.connections[serverName] = &MCPConnection{
|
||||
client: client,
|
||||
serverName: serverName,
|
||||
isHealthy: true,
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
m.connectionPool = pool
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func TestLoadServerPrompts_Basic(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("review-pr",
|
||||
mcp.WithPromptDescription("Review a pull request"),
|
||||
mcp.WithArgument("pr_number",
|
||||
mcp.ArgumentDescription("The PR number to review"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
mcp.WithArgument("focus",
|
||||
mcp.ArgumentDescription("Area to focus on"),
|
||||
),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
prNum := req.Params.Arguments["pr_number"]
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "PR review prompt",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Please review PR #%s", prNum),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("explain-code",
|
||||
mcp.WithPromptDescription("Explain a piece of code"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: "Please explain the following code.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "github", client)
|
||||
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
serverName: "github",
|
||||
isHealthy: true,
|
||||
}
|
||||
m.loadServerPrompts(ctx, "github", conn)
|
||||
|
||||
prompts := m.GetPrompts()
|
||||
if len(prompts) != 2 {
|
||||
t.Fatalf("expected 2 prompts, got %d", len(prompts))
|
||||
}
|
||||
|
||||
// Find review-pr prompt.
|
||||
var reviewPR *MCPPrompt
|
||||
for i := range prompts {
|
||||
if prompts[i].Name == "review-pr" {
|
||||
reviewPR = &prompts[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if reviewPR == nil {
|
||||
t.Fatal("review-pr prompt not found")
|
||||
}
|
||||
if reviewPR.Description != "Review a pull request" {
|
||||
t.Errorf("unexpected description: %q", reviewPR.Description)
|
||||
}
|
||||
if reviewPR.ServerName != "github" {
|
||||
t.Errorf("unexpected server name: %q", reviewPR.ServerName)
|
||||
}
|
||||
if len(reviewPR.Arguments) != 2 {
|
||||
t.Fatalf("expected 2 arguments, got %d", len(reviewPR.Arguments))
|
||||
}
|
||||
|
||||
// Verify argument metadata.
|
||||
arg0 := reviewPR.Arguments[0]
|
||||
if arg0.Name != "pr_number" {
|
||||
t.Errorf("expected first arg name 'pr_number', got %q", arg0.Name)
|
||||
}
|
||||
if !arg0.Required {
|
||||
t.Error("expected first arg to be required")
|
||||
}
|
||||
arg1 := reviewPR.Arguments[1]
|
||||
if arg1.Name != "focus" {
|
||||
t.Errorf("expected second arg name 'focus', got %q", arg1.Name)
|
||||
}
|
||||
if arg1.Required {
|
||||
t.Error("expected second arg to be optional")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_ExpandsWithArgs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("greet",
|
||||
mcp.WithPromptDescription("Greet someone"),
|
||||
mcp.WithArgument("name", mcp.RequiredArgument()),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
name := req.Params.Arguments["name"]
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Greeting",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{
|
||||
Type: "text",
|
||||
Text: fmt.Sprintf("Hello, %s!", name),
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "myserver", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "myserver", "greet", map[string]string{"name": "World"})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if result.Description != "Greeting" {
|
||||
t.Errorf("unexpected description: %q", result.Description)
|
||||
}
|
||||
if len(result.Messages) != 1 {
|
||||
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Role != "user" {
|
||||
t.Errorf("unexpected role: %q", result.Messages[0].Role)
|
||||
}
|
||||
if result.Messages[0].Content != "Hello, World!" {
|
||||
t.Errorf("unexpected content: %q", result.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_MultipleMessages(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("chat-starter"),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "What is Go?"},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleAssistant,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Go is a programming language."},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Tell me more."},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "server", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "server", "chat-starter", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if len(result.Messages) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(result.Messages))
|
||||
}
|
||||
if result.Messages[0].Role != "user" {
|
||||
t.Errorf("msg[0] role: got %q, want 'user'", result.Messages[0].Role)
|
||||
}
|
||||
if result.Messages[1].Role != "assistant" {
|
||||
t.Errorf("msg[1] role: got %q, want 'assistant'", result.Messages[1].Role)
|
||||
}
|
||||
if result.Messages[2].Content != "Tell me more." {
|
||||
t.Errorf("msg[2] content: got %q, want 'Tell me more.'", result.Messages[2].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_ServerNotFound(t *testing.T) {
|
||||
m := NewMCPToolManager()
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), false, nil, nil)
|
||||
m.connectionPool = pool
|
||||
|
||||
_, err := m.GetPrompt(context.Background(), "nonexistent", "foo", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_NoPool(t *testing.T) {
|
||||
m := NewMCPToolManager()
|
||||
|
||||
_, err := m.GetPrompt(context.Background(), "any", "foo", nil)
|
||||
if err == nil {
|
||||
t.Fatal("expected error with no pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveServer_RemovesPrompts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("my-prompt",
|
||||
mcp.WithPromptDescription("A test prompt"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Messages: []mcp.PromptMessage{
|
||||
{Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "hi"}},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "testsvr", client)
|
||||
|
||||
// Manually populate tools and prompts as loadServerTools would.
|
||||
conn := m.connectionPool.connections["testsvr"]
|
||||
m.loadServerPrompts(ctx, "testsvr", conn)
|
||||
|
||||
// Also add a fake tool mapping so RemoveServer finds the server.
|
||||
m.toolMap["testsvr__noop"] = &toolMapping{
|
||||
serverName: "testsvr",
|
||||
originalName: "noop",
|
||||
}
|
||||
m.tools = append(m.tools, MCPTool{
|
||||
Name: "testsvr__noop",
|
||||
ServerName: "testsvr",
|
||||
})
|
||||
|
||||
// Verify prompts exist before removal.
|
||||
if got := len(m.GetPrompts()); got != 1 {
|
||||
t.Fatalf("expected 1 prompt before removal, got %d", got)
|
||||
}
|
||||
|
||||
// Remove the server.
|
||||
err := m.RemoveServer("testsvr")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer error: %v", err)
|
||||
}
|
||||
|
||||
// Verify prompts are gone.
|
||||
if got := len(m.GetPrompts()); got != 0 {
|
||||
t.Fatalf("expected 0 prompts after removal, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadServerPrompts_NoPromptCapability(t *testing.T) {
|
||||
// Server without prompt capabilities — ListPrompts should fail gracefully.
|
||||
mcpServer := server.NewMCPServer("no-prompts", "1.0.0",
|
||||
server.WithToolCapabilities(true),
|
||||
// No WithPromptCapabilities
|
||||
)
|
||||
mcpServer.AddTool(
|
||||
mcp.NewTool("noop"),
|
||||
func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
return mcp.NewToolResultText("ok"), nil
|
||||
},
|
||||
)
|
||||
|
||||
client, err := mcpclient.NewInProcessClient(mcpServer)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInProcessClient: %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
_ = client.Start(ctx)
|
||||
initReq := mcp.InitializeRequest{}
|
||||
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
|
||||
initReq.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.0"}
|
||||
_, _ = client.Initialize(ctx, initReq)
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
m := NewMCPToolManager()
|
||||
conn := &MCPConnection{
|
||||
client: client,
|
||||
serverName: "no-prompts",
|
||||
isHealthy: true,
|
||||
}
|
||||
|
||||
// Should not panic or error — just silently skip.
|
||||
m.loadServerPrompts(ctx, "no-prompts", conn)
|
||||
|
||||
if got := len(m.GetPrompts()); got != 0 {
|
||||
t.Fatalf("expected 0 prompts from server without prompt capability, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractPromptContent(t *testing.T) {
|
||||
t.Run("TextContent", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.TextContent{Type: "text", Text: "hello world"})
|
||||
if text != "hello world" {
|
||||
t.Errorf("text = %q, want %q", text, "hello world")
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImageContent", func(t *testing.T) {
|
||||
// base64 of "fake image"
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("fake image"))
|
||||
text, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: encoded,
|
||||
MIMEType: "image/png",
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "image/png" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "image/png")
|
||||
}
|
||||
if parts[0].Filename != "image.png" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "image.png")
|
||||
}
|
||||
if string(parts[0].Data) != "fake image" {
|
||||
t.Errorf("data = %q, want %q", string(parts[0].Data), "fake image")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ImageContent_DefaultMIME", func(t *testing.T) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("img"))
|
||||
_, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: encoded,
|
||||
// no MIMEType → should default to image/png
|
||||
})
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "image/png" {
|
||||
t.Errorf("default MIME = %q, want %q", parts[0].MediaType, "image/png")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("AudioContent", func(t *testing.T) {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte("fake audio"))
|
||||
text, parts := extractPromptContent(mcp.AudioContent{
|
||||
Type: "audio",
|
||||
Data: encoded,
|
||||
MIMEType: "audio/mp3",
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part, got %d", len(parts))
|
||||
}
|
||||
if parts[0].MediaType != "audio/mp3" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "audio/mp3")
|
||||
}
|
||||
if parts[0].Filename != "audio.wav" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "audio.wav")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmbeddedResource_Text", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.TextResourceContents{
|
||||
URI: "file:///project/main.go",
|
||||
MIMEType: "text/x-go",
|
||||
Text: "package main",
|
||||
},
|
||||
})
|
||||
if text == "" {
|
||||
t.Fatal("expected non-empty text for text resource")
|
||||
}
|
||||
if !strings.Contains(text, "package main") {
|
||||
t.Errorf("text should contain resource content, got %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "file:///project/main.go") {
|
||||
t.Errorf("text should contain URI, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for text resource, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EmbeddedResource_Blob", func(t *testing.T) {
|
||||
blobData := []byte("binary content")
|
||||
encoded := base64.StdEncoding.EncodeToString(blobData)
|
||||
text, parts := extractPromptContent(mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.BlobResourceContents{
|
||||
URI: "file:///project/data.bin",
|
||||
MIMEType: "application/octet-stream",
|
||||
Blob: encoded,
|
||||
},
|
||||
})
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text for blob resource, got %q", text)
|
||||
}
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 file part for blob resource, got %d", len(parts))
|
||||
}
|
||||
if parts[0].Filename != "data.bin" {
|
||||
t.Errorf("filename = %q, want %q", parts[0].Filename, "data.bin")
|
||||
}
|
||||
if parts[0].MediaType != "application/octet-stream" {
|
||||
t.Errorf("media type = %q, want %q", parts[0].MediaType, "application/octet-stream")
|
||||
}
|
||||
if string(parts[0].Data) != "binary content" {
|
||||
t.Errorf("data = %q, want %q", string(parts[0].Data), "binary content")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ResourceLink", func(t *testing.T) {
|
||||
text, parts := extractPromptContent(mcp.ResourceLink{
|
||||
Type: "resource_link",
|
||||
URI: "file:///docs/readme.md",
|
||||
Name: "readme.md",
|
||||
})
|
||||
if text == "" {
|
||||
t.Fatal("expected non-empty text for resource link")
|
||||
}
|
||||
if !strings.Contains(text, "file:///docs/readme.md") {
|
||||
t.Errorf("text should contain URI, got %q", text)
|
||||
}
|
||||
if !strings.Contains(text, "readme.md") {
|
||||
t.Errorf("text should contain name, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for resource link, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InvalidBase64", func(t *testing.T) {
|
||||
_, parts := extractPromptContent(mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: "not-valid-base64!!!",
|
||||
MIMEType: "image/png",
|
||||
})
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 file parts for invalid base64, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("NilContent", func(t *testing.T) {
|
||||
text, parts := extractPromptContent((*mcp.TextContent)(nil))
|
||||
if text != "" {
|
||||
t.Errorf("expected empty text for nil, got %q", text)
|
||||
}
|
||||
if len(parts) != 0 {
|
||||
t.Errorf("expected 0 parts for nil, got %d", len(parts))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilenameFromURI(t *testing.T) {
|
||||
tests := []struct {
|
||||
uri string
|
||||
want string
|
||||
}{
|
||||
{"file:///path/to/image.png", "image.png"},
|
||||
{"file:///single.txt", "single.txt"},
|
||||
{"resource://server/data.json", "data.json"},
|
||||
{"nopath", "nopath"},
|
||||
{"", "resource"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.uri, func(t *testing.T) {
|
||||
got := filenameFromURI(tt.uri)
|
||||
if got != tt.want {
|
||||
t.Errorf("filenameFromURI(%q) = %q, want %q", tt.uri, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrompt_EmbeddedResources(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
imgData := base64.StdEncoding.EncodeToString([]byte("fake-png"))
|
||||
blobData := base64.StdEncoding.EncodeToString([]byte("binary-blob"))
|
||||
|
||||
client := newTestPromptServer(t,
|
||||
server.ServerPrompt{
|
||||
Prompt: mcp.NewPrompt("review-with-files",
|
||||
mcp.WithPromptDescription("Review with embedded resources"),
|
||||
),
|
||||
Handler: func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Review prompt with embedded files",
|
||||
Messages: []mcp.PromptMessage{
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.TextContent{Type: "text", Text: "Please review these files:"},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.TextResourceContents{
|
||||
URI: "file:///src/main.go",
|
||||
MIMEType: "text/x-go",
|
||||
Text: "package main\n\nfunc main() {}",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.ImageContent{
|
||||
Type: "image",
|
||||
Data: imgData,
|
||||
MIMEType: "image/png",
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: mcp.RoleUser,
|
||||
Content: mcp.EmbeddedResource{
|
||||
Type: "resource",
|
||||
Resource: mcp.BlobResourceContents{
|
||||
URI: "file:///data/model.bin",
|
||||
MIMEType: "application/octet-stream",
|
||||
Blob: blobData,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
m := injectClientIntoManager(t, "test", client)
|
||||
|
||||
result, err := m.GetPrompt(ctx, "test", "review-with-files", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPrompt error: %v", err)
|
||||
}
|
||||
if result.Description != "Review prompt with embedded files" {
|
||||
t.Errorf("unexpected description: %q", result.Description)
|
||||
}
|
||||
|
||||
// Should have 4 messages: text, embedded text resource, image, embedded blob
|
||||
if len(result.Messages) != 4 {
|
||||
t.Fatalf("expected 4 messages, got %d", len(result.Messages))
|
||||
}
|
||||
|
||||
// Message 0: plain text
|
||||
msg0 := result.Messages[0]
|
||||
if msg0.Content != "Please review these files:" {
|
||||
t.Errorf("msg[0] content = %q", msg0.Content)
|
||||
}
|
||||
if len(msg0.FileParts) != 0 {
|
||||
t.Errorf("msg[0] expected 0 file parts, got %d", len(msg0.FileParts))
|
||||
}
|
||||
|
||||
// Message 1: embedded text resource → inlined as text
|
||||
msg1 := result.Messages[1]
|
||||
if !strings.Contains(msg1.Content, "package main") {
|
||||
t.Errorf("msg[1] should contain resource text, got %q", msg1.Content)
|
||||
}
|
||||
if len(msg1.FileParts) != 0 {
|
||||
t.Errorf("msg[1] expected 0 file parts (text resource), got %d", len(msg1.FileParts))
|
||||
}
|
||||
|
||||
// Message 2: image → file part
|
||||
msg2 := result.Messages[2]
|
||||
if msg2.Content != "" {
|
||||
t.Errorf("msg[2] expected empty text for image, got %q", msg2.Content)
|
||||
}
|
||||
if len(msg2.FileParts) != 1 {
|
||||
t.Fatalf("msg[2] expected 1 file part, got %d", len(msg2.FileParts))
|
||||
}
|
||||
if msg2.FileParts[0].MediaType != "image/png" {
|
||||
t.Errorf("msg[2] file part MIME = %q", msg2.FileParts[0].MediaType)
|
||||
}
|
||||
if string(msg2.FileParts[0].Data) != "fake-png" {
|
||||
t.Errorf("msg[2] file part data = %q", string(msg2.FileParts[0].Data))
|
||||
}
|
||||
|
||||
// Message 3: embedded blob resource → file part
|
||||
msg3 := result.Messages[3]
|
||||
if msg3.Content != "" {
|
||||
t.Errorf("msg[3] expected empty text for blob resource, got %q", msg3.Content)
|
||||
}
|
||||
if len(msg3.FileParts) != 1 {
|
||||
t.Fatalf("msg[3] expected 1 file part, got %d", len(msg3.FileParts))
|
||||
}
|
||||
if msg3.FileParts[0].Filename != "model.bin" {
|
||||
t.Errorf("msg[3] filename = %q, want %q", msg3.FileParts[0].Filename, "model.bin")
|
||||
}
|
||||
if string(msg3.FileParts[0].Data) != "binary-blob" {
|
||||
t.Errorf("msg[3] file part data = %q", string(msg3.FileParts[0].Data))
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user