mirror of
https://github.com/mark3labs/kit.git
synced 2026-06-14 03:30:26 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 68d798d2f4 | |||
| eefd5565f8 | |||
| 9d1b8a102e | |||
| f57e045c69 | |||
| eb5da28a15 | |||
| cd8e2a7654 | |||
| 64da1caf41 | |||
| 7eaeafff8c | |||
| 8ed8d23c73 | |||
| 2de98d32be | |||
| 83127467c5 | |||
| e07c94f49d | |||
| b87146a284 | |||
| 186d9f7f44 | |||
| 3a8ffc2104 | |||
| e54570162e |
@@ -2,7 +2,7 @@
|
||||
description: Create a feature request using the GitHub template
|
||||
---
|
||||
|
||||
Create a feature request for the Kit repository. The user wants to request: $@
|
||||
Create a feature request for the Kit repository. The user wants to request: $+
|
||||
|
||||
## Feature Request Template
|
||||
|
||||
@@ -16,7 +16,7 @@ This prompt uses the `feature_request` GitHub template which requires:
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the request** from `$@`
|
||||
1. **Understand the request** from `$+`
|
||||
- What capability is missing?
|
||||
- What would the ideal behavior look like?
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description: File a GitHub issue using the appropriate template
|
||||
---
|
||||
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $@
|
||||
File a GitHub issue for the Kit repository. The user wants to create an issue about: $+
|
||||
|
||||
## Issue Templates Available
|
||||
|
||||
@@ -16,7 +16,7 @@ This repository has structured issue templates. You MUST use the appropriate tem
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Determine the issue type** from `$@`:
|
||||
1. **Determine the issue type** from `$+`:
|
||||
- Bug → use `--template bug_report`
|
||||
- Feature → use `--template feature_request`
|
||||
- Documentation → use `--template documentation`
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
description: Scaffold a new prompt template in .kit/prompts/
|
||||
---
|
||||
|
||||
Create a new kit prompt template. The user wants a prompt that does: $@
|
||||
Create a new kit prompt template. The user wants a prompt that does: $+
|
||||
|
||||
## What a prompt template is
|
||||
|
||||
@@ -23,19 +23,21 @@ $1 $2 etc. for positional arguments.
|
||||
- **Filename** → slug: `commit-push.md` becomes `/commit-push`
|
||||
- **Frontmatter**: only `description` is recognised; keep it under ~80 chars
|
||||
- **Body**: plain markdown; the full text is submitted as the user's message when the template fires
|
||||
- **Arguments**: `$@` expands to everything the user typed after the slash command name;
|
||||
- **Arguments**: `$+` expands to everything the user typed after the slash command name
|
||||
(requires at least one argument); `$@` is the same but allows zero arguments;
|
||||
`$1`, `$2` for individual positional args; omit entirely if no arguments are needed
|
||||
|
||||
## Steps
|
||||
|
||||
1. **Understand the workflow** the user described in `$@` — ask a clarifying question if the intent is ambiguous
|
||||
1. **Understand the workflow** the user described in `$+` — ask a clarifying question if the intent is ambiguous
|
||||
2. **Choose a filename**: short, lowercase, hyphen-separated, descriptive (e.g. `code-review.md`)
|
||||
3. **Write the description**: one sentence, imperative, fits in autocomplete
|
||||
4. **Draft the body**:
|
||||
- Open with a single sentence stating the goal
|
||||
- Use `## Steps` for multi-step workflows; use plain prose for simple prompts
|
||||
- Be specific: name commands, flags, and file paths where relevant
|
||||
- End with `$@` on its own line if the user might want to pass context or a hint; omit if the prompt is self-contained
|
||||
- End with `$+` on its own line if the user must pass context; use `$@` if arguments
|
||||
are optional; omit if the prompt is self-contained
|
||||
5. **Write the file** to `.kit/prompts/<slug>.md`
|
||||
6. **Confirm** by showing the final file content and the slash command that activates it
|
||||
|
||||
|
||||
@@ -317,39 +317,39 @@ kit -e examples/extensions/minimal.go
|
||||
|
||||
See the `examples/extensions/` directory:
|
||||
|
||||
- `minimal.go` - Clean UI with custom footer
|
||||
- `auto-commit.go` - Auto-commit on shutdown
|
||||
- `bookmark.go` - Bookmark conversations
|
||||
- `branded-output.go` - Branded output rendering
|
||||
- `compact-notify.go` - Notification on compaction
|
||||
- `confirm-destructive.go` - Confirm destructive operations
|
||||
- `context-inject.go` - Inject context into conversations
|
||||
- `conversation-manager.go` - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- `custom-editor-demo.go` - Vim-like modal editor
|
||||
- `dev-reload.go` - Development live-reload
|
||||
- `header-footer-demo.go` - Custom headers and footers
|
||||
- `inline-bash.go` - Inline bash execution
|
||||
- `interactive-shell.go` - Interactive shell integration
|
||||
- `kit-kit.go` - Kit-in-Kit (sub-agent spawning)
|
||||
- `lsp-diagnostics.go` - LSP diagnostic integration
|
||||
- `notify.go` - Desktop notifications
|
||||
- `overlay-demo.go` - Modal dialogs
|
||||
- `permission-gate.go` - Permission gating for tools
|
||||
- `pirate.go` - Pirate-themed personality
|
||||
- `plan-mode.go` - Read-only planning mode
|
||||
- `project-rules.go` - Project-specific rules
|
||||
- `prompt-demo.go` - Interactive prompts (select/confirm/input)
|
||||
- `prompt-templates.go` - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- `protected-paths.go` - Path protection for sensitive files
|
||||
- `subagent-widget.go` - Multi-agent orchestration with status widget
|
||||
- `subagent-test.go` - Subagent testing utilities
|
||||
- `summarize.go` - Conversation summarization
|
||||
- `tool-logger.go` - Log all tool calls
|
||||
- `neon-theme.go` - Custom theme registration and switching
|
||||
- `tool-renderer-demo.go` - Custom tool call rendering
|
||||
- `widget-status.go` - Persistent status widgets
|
||||
- [`minimal.go`](examples/extensions/minimal.go) - Clean UI with custom footer
|
||||
- [`auto-commit.go`](examples/extensions/auto-commit.go) - Auto-commit on shutdown
|
||||
- [`bookmark.go`](examples/extensions/bookmark.go) - Bookmark conversations
|
||||
- [`branded-output.go`](examples/extensions/branded-output.go) - Branded output rendering
|
||||
- [`compact-notify.go`](examples/extensions/compact-notify.go) - Notification on compaction
|
||||
- [`confirm-destructive.go`](examples/extensions/confirm-destructive.go) - Confirm destructive operations
|
||||
- [`context-inject.go`](examples/extensions/context-inject.go) - Inject context into conversations
|
||||
- [`conversation-manager.go`](examples/extensions/conversation-manager.go) - **NEW** Tree navigation, branch summarization, and fresh context loops
|
||||
- [`custom-editor-demo.go`](examples/extensions/custom-editor-demo.go) - Vim-like modal editor
|
||||
- [`dev-reload.go`](examples/extensions/dev-reload.go) - Development live-reload
|
||||
- [`header-footer-demo.go`](examples/extensions/header-footer-demo.go) - Custom headers and footers
|
||||
- [`inline-bash.go`](examples/extensions/inline-bash.go) - Inline bash execution
|
||||
- [`interactive-shell.go`](examples/extensions/interactive-shell.go) - Interactive shell integration
|
||||
- [`kit-kit.go`](examples/extensions/kit-kit.go) - Kit-in-Kit (sub-agent spawning)
|
||||
- [`lsp-diagnostics.go`](examples/extensions/lsp-diagnostics.go) - LSP diagnostic integration
|
||||
- [`notify.go`](examples/extensions/notify.go) - Desktop notifications
|
||||
- [`overlay-demo.go`](examples/extensions/overlay-demo.go) - Modal dialogs
|
||||
- [`permission-gate.go`](examples/extensions/permission-gate.go) - Permission gating for tools
|
||||
- [`pirate.go`](examples/extensions/pirate.go) - Pirate-themed personality
|
||||
- [`plan-mode.go`](examples/extensions/plan-mode.go) - Read-only planning mode
|
||||
- [`project-rules.go`](examples/extensions/project-rules.go) - Project-specific rules
|
||||
- [`prompt-demo.go`](examples/extensions/prompt-demo.go) - Interactive prompts (select/confirm/input)
|
||||
- [`prompt-templates.go`](examples/extensions/prompt-templates.go) - **NEW** Frontmatter-driven templates with model switching and skill injection
|
||||
- [`protected-paths.go`](examples/extensions/protected-paths.go) - Path protection for sensitive files
|
||||
- [`subagent-widget.go`](examples/extensions/subagent-widget.go) - Multi-agent orchestration with status widget
|
||||
- [`subagent-test.go`](examples/extensions/subagent-test.go) - Subagent testing utilities
|
||||
- [`summarize.go`](examples/extensions/summarize.go) - Conversation summarization
|
||||
- [`tool-logger.go`](examples/extensions/tool-logger.go) - Log all tool calls
|
||||
- [`neon-theme.go`](examples/extensions/neon-theme.go) - Custom theme registration and switching
|
||||
- [`tool-renderer-demo.go`](examples/extensions/tool-renderer-demo.go) - Custom tool call rendering
|
||||
- [`widget-status.go`](examples/extensions/widget-status.go) - Persistent status widgets
|
||||
|
||||
Also see `.kit/extensions/go-edit-lint.go` (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
Also see [`.kit/extensions/go-edit-lint.go`](.kit/extensions/go-edit-lint.go) (in this repo) for a project-local extension example that runs gopls and golangci-lint on Go file edits.
|
||||
|
||||
### Loading Extensions
|
||||
|
||||
@@ -406,7 +406,7 @@ func TestMyExtension(t *testing.T) {
|
||||
- `AssertPrinted()`, `AssertPrintedContains()` — Verify output
|
||||
- `AssertToolRegistered()`, `AssertCommandRegistered()` — Verify registration
|
||||
|
||||
See `examples/extensions/tool-logger_test.go` for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
See [`examples/extensions/tool-logger_test.go`](examples/extensions/tool-logger_test.go) for a complete example with 14 test cases covering tool calls, input handling, and session lifecycle.
|
||||
|
||||
### Prompt Templates
|
||||
|
||||
|
||||
@@ -10,13 +10,21 @@ import (
|
||||
"kit/ext"
|
||||
)
|
||||
|
||||
// re matches !{...} with non-greedy content.
|
||||
var re = regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
|
||||
// Init expands inline bash expressions in user prompts before they reach the
|
||||
// LLM. Text like !{git branch --show-current} is replaced with the command's
|
||||
// stdout.
|
||||
// LLM. Text like !{git rev-parse --abbrev-ref HEAD} is replaced with the
|
||||
// command's stdout.
|
||||
//
|
||||
// In interactive mode the expansion happens at submit time via an editor
|
||||
// interceptor, so the expanded text is also visible in the user message
|
||||
// block on screen. In non-interactive mode (CLI, script, queue) the
|
||||
// expansion happens via OnInput transform.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// "Fix the tests on !{git branch --show-current}"
|
||||
// "Fix the tests on !{git rev-parse --abbrev-ref HEAD}"
|
||||
// → "Fix the tests on main"
|
||||
//
|
||||
// "The current directory is !{pwd}"
|
||||
@@ -24,29 +32,59 @@ import (
|
||||
//
|
||||
// Usage: kit -e examples/extensions/inline-bash.go
|
||||
func Init(api ext.API) {
|
||||
// Matches !{...} with non-greedy content.
|
||||
re := regexp.MustCompile(`!\{([^}]+)\}`)
|
||||
// ── Interactive mode: editor interceptor ──────────────────────────
|
||||
// Intercept Enter / Ctrl+D so we can expand !{...} BEFORE the
|
||||
// SubmitMsg is created. This ensures the expanded text appears in
|
||||
// the user message block on screen as well as in the LLM prompt.
|
||||
api.OnSessionStart(func(_ ext.SessionStartEvent, ctx ext.Context) {
|
||||
if !ctx.Interactive {
|
||||
return
|
||||
}
|
||||
ctx.SetEditor(ext.EditorConfig{
|
||||
HandleKey: func(key string, currentText string) ext.EditorKeyAction {
|
||||
if (key == "enter" || key == "ctrl+d") && re.MatchString(currentText) {
|
||||
expanded := expand(currentText)
|
||||
// Clear the textarea asynchronously — calling
|
||||
// SetEditorText synchronously from inside Update()
|
||||
// would deadlock the BubbleTea event loop.
|
||||
go ctx.SetEditorText("")
|
||||
return ext.EditorKeyAction{
|
||||
Type: ext.EditorKeySubmit,
|
||||
SubmitText: expanded,
|
||||
}
|
||||
}
|
||||
return ext.EditorKeyAction{Type: ext.EditorKeyPassthrough}
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
// ── Non-interactive fallback: OnInput transform ──────────────────
|
||||
// For CLI, script, and queue sources the editor interceptor is not
|
||||
// active, so we fall back to OnInput which still rewrites the
|
||||
// prompt text sent to the LLM.
|
||||
api.OnInput(func(ev ext.InputEvent, ctx ext.Context) *ext.InputResult {
|
||||
if !re.MatchString(ev.Text) {
|
||||
if ev.Source == "interactive" || !re.MatchString(ev.Text) {
|
||||
return nil
|
||||
}
|
||||
|
||||
expanded := re.ReplaceAllStringFunc(ev.Text, func(match string) string {
|
||||
// Extract the command between !{ and }.
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
|
||||
return &ext.InputResult{
|
||||
Action: "transform",
|
||||
Text: expanded,
|
||||
Text: expand(ev.Text),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// expand replaces every !{cmd} in text with the command's stdout.
|
||||
// On error the original !{cmd} token is preserved.
|
||||
func expand(text string) string {
|
||||
return re.ReplaceAllStringFunc(text, func(match string) string {
|
||||
cmd := re.FindStringSubmatch(match)[1]
|
||||
cmd = strings.TrimSpace(cmd)
|
||||
|
||||
out, err := exec.Command("bash", "-c", cmd).Output()
|
||||
if err != nil {
|
||||
return match // keep original on error
|
||||
}
|
||||
return strings.TrimSpace(string(out))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,31 +1,32 @@
|
||||
module github.com/mark3labs/kit
|
||||
|
||||
go 1.26.1
|
||||
go 1.26.2
|
||||
|
||||
require (
|
||||
charm.land/bubbles/v2 v2.1.0
|
||||
charm.land/bubbletea/v2 v2.0.2
|
||||
charm.land/fantasy v0.17.1
|
||||
charm.land/bubbletea/v2 v2.0.5
|
||||
charm.land/fantasy v0.17.2
|
||||
charm.land/huh/v2 v2.0.3
|
||||
charm.land/lipgloss/v2 v2.0.2
|
||||
charm.land/lipgloss/v2 v2.0.3
|
||||
github.com/alecthomas/chroma/v2 v2.23.1
|
||||
github.com/atotto/clipboard v0.1.4
|
||||
github.com/aymanbagabas/go-udiff v0.4.1
|
||||
github.com/charmbracelet/fang v1.0.0
|
||||
github.com/charmbracelet/log v1.0.0
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e
|
||||
github.com/charmbracelet/x/editor v0.2.0
|
||||
github.com/clipperhouse/displaywidth v0.11.0
|
||||
github.com/clipperhouse/uax29/v2 v2.7.0
|
||||
github.com/coder/acp-go-sdk v0.6.3
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/indaco/herald v0.13.0
|
||||
github.com/indaco/herald-md v0.3.0
|
||||
github.com/mark3labs/mcp-go v0.47.1
|
||||
github.com/mark3labs/mcp-go v0.48.0
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/traefik/yaegi v0.16.1
|
||||
golang.org/x/term v0.41.0
|
||||
golang.org/x/term v0.42.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
@@ -58,9 +59,9 @@ require (
|
||||
github.com/charmbracelet/harmonica v0.2.0 // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 // indirect
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 // indirect
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 // indirect
|
||||
github.com/charmbracelet/x/json v0.2.0 // indirect
|
||||
github.com/charmbracelet/x/termios v0.1.1 // indirect
|
||||
@@ -81,10 +82,10 @@ require (
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.3.1 // indirect
|
||||
github.com/kaptinlin/go-i18n v0.4.0 // indirect
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 // indirect
|
||||
github.com/kaptinlin/jsonschema v0.7.7 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 // indirect
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 // indirect
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect
|
||||
github.com/muesli/mango v0.2.0 // indirect
|
||||
github.com/muesli/mango-cobra v1.3.0 // indirect
|
||||
@@ -109,14 +110,14 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.43.0 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.49.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/crypto v0.50.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
|
||||
golang.org/x/net v0.53.0 // indirect
|
||||
golang.org/x/oauth2 v0.36.0 // indirect
|
||||
golang.org/x/time v0.15.0 // indirect
|
||||
google.golang.org/api v0.275.0 // indirect
|
||||
google.golang.org/genai v1.52.1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d // indirect
|
||||
google.golang.org/genai v1.54.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 // indirect
|
||||
google.golang.org/grpc v1.80.0 // indirect
|
||||
google.golang.org/protobuf v1.36.11 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
@@ -124,7 +125,7 @@ require (
|
||||
|
||||
require (
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.6
|
||||
github.com/charmbracelet/x/ansi v0.11.7
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 // indirect
|
||||
@@ -136,5 +137,5 @@ require (
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/text v0.35.0
|
||||
golang.org/x/text v0.36.0
|
||||
)
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g=
|
||||
charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY=
|
||||
charm.land/bubbletea/v2 v2.0.2 h1:4CRtRnuZOdFDTWSff9r8QFt/9+z6Emubz3aDMnf/dx0=
|
||||
charm.land/bubbletea/v2 v2.0.2/go.mod h1:3LRff2U4WIYXy7MTxfbAQ+AdfM3D8Xuvz2wbsOD9OHQ=
|
||||
charm.land/fantasy v0.17.1 h1:SQzfnyJPDuQWt6e//KKmQmEEXdqHMC0IZz10XwkLcEM=
|
||||
charm.land/fantasy v0.17.1/go.mod h1:FF5ALCCHETacHJPBqU42CtwMInYQ0ul52fdzIHQMbQk=
|
||||
charm.land/bubbletea/v2 v2.0.5 h1:TQlLFqxo39AAHSVuOhJ5D3nH7O9Nk8JGinsfWQ4y1U4=
|
||||
charm.land/bubbletea/v2 v2.0.5/go.mod h1:dvbsYZD+MHkdIZl+Z67D212hEvB+GII2tfH8f9SnoDw=
|
||||
charm.land/fantasy v0.17.2 h1:ojTMufMxY/PVH7TzYUxht2SVkvD90iCTJfmPR6c8BR8=
|
||||
charm.land/fantasy v0.17.2/go.mod h1:V9cCIUMZB9g3Bq40aKEY8xBNzDd48EdfHp2OMS0uzWs=
|
||||
charm.land/huh/v2 v2.0.3 h1:2cJsMqEPwSywGHvdlKsJyQKPtSJLVnFKyFbsYZTlLkU=
|
||||
charm.land/huh/v2 v2.0.3/go.mod h1:93eEveeeqn47MwiC3tf+2atZ2l7Is88rAtmZNZ8x9Wc=
|
||||
charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs=
|
||||
charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM=
|
||||
charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU=
|
||||
charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA=
|
||||
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
|
||||
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
|
||||
cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA=
|
||||
@@ -86,24 +86,26 @@ github.com/charmbracelet/log v1.0.0 h1:HVVVMmfOorfj3BA9i8X8UL69Hoz9lI0PYwXfJvOdR
|
||||
github.com/charmbracelet/log v1.0.0/go.mod h1:uYgY3SmLpwJWxmlrPwXvzVYujxis1vAKRV/0VQB7yWA=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266 h1:BW/sZtyd1JyYy0h5adMm3tzpNyL857LWjuTRET6OhpY=
|
||||
github.com/charmbracelet/openai-go v0.0.0-20260319145158-d0740cc34266/go.mod h1:1DahUaExbUZx/jD+FNT2PKP4L9rLE5+ZBRuI8mZjd/E=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b h1:ASDO9RT6SNKTQN87jO2bRfxHFJq8cgeYdFzivY2gCeM=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260330092749-0f94982c930b/go.mod h1:Vo8TffMf0q7Uho/n8e6XpBZvOWtd3g39yX+9P5rRutA=
|
||||
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
|
||||
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e h1:O5hZFj55wZQWxMiRtQLa3uLKhZGZGS/j8M3OXinQlrw=
|
||||
github.com/charmbracelet/ultraviolet v0.0.0-20260414011438-8c69ec811b1e/go.mod h1:bAAz7dh/FTYfC+oiHavL4mX1tOIBZ0ZwYjSi3qE6ivM=
|
||||
github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dAYC84jI=
|
||||
github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/conpty v0.1.1 h1:s1bUxjoi7EpqiXysVtC+a8RrvPPNcNvAjfi4jxsAuEs=
|
||||
github.com/charmbracelet/x/conpty v0.1.1/go.mod h1:OmtR77VODEFbiTzGE9G1XiRJAga6011PIm4u5fTNZpk=
|
||||
github.com/charmbracelet/x/editor v0.2.0 h1:7XLUKtaRaB8jN7bWU2p2UChiySyaAuIfYiIRg8gGWwk=
|
||||
github.com/charmbracelet/x/editor v0.2.0/go.mod h1:p3oQ28TSL3YPd+GKJ1fHWcp+7bVGpedHpXmo0D6t1dY=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86 h1:JSt3B+U9iqk37QUU2Rvb6DSBYRLtWqFqfxf8l5hOZUA=
|
||||
github.com/charmbracelet/x/errors v0.0.0-20240508181413-e8d8b6e2de86/go.mod h1:2P0UgXMEa6TsToMSuFqKFQR+fZTO9CNGUNokkPatT/0=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143 h1:zmBor0ftFNqVFp9U59ZoEDRUCIYSGOGSIfGGkNZRufs=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260406091427-a791e22d5143/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913 h1:6F/6bu5nBLjodsvaU5xAszTaxtHrDU5UiJarpMPZj48=
|
||||
github.com/charmbracelet/x/exp/charmtone v0.0.0-20260413165052-6921c759c913/go.mod h1:nsExn0DGyX0lh9LwLHTn2Gg+hafdzfSXnC+QmEJTZFY=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0 h1:55/qLwjIh0gL0Vni+QAWk7T/qRVP6sBf+2agPBgnOFE=
|
||||
github.com/charmbracelet/x/exp/ordered v0.1.0/go.mod h1:5UHwmG+is5THxMyCJHNPCn2/ecI07aKNrW+LcResjJ8=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143 h1:aEppolah2k9c0LzKX2fk5ryuyQ0Lq8kCOjkvMw1b8o4=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260406091427-a791e22d5143/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913 h1:RiZFY92Ug9iz1CenzxSSQla2Z3WflsR7bIuXq40JlpU=
|
||||
github.com/charmbracelet/x/exp/slice v0.0.0-20260413165052-6921c759c913/go.mod h1:vqEfX6xzqW1pKKZUUiFOKg0OQ7bCh54Q2vR/tserrRA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0 h1:i69S2XI7uG1u4NLGeJPSYU++Nmjvpo9nwd6aoEm7gkA=
|
||||
github.com/charmbracelet/x/exp/strings v0.1.0/go.mod h1:/ehtMPNh9K4odGFkqYJKpIYyePhdp1hLBRvyY4bWkH8=
|
||||
github.com/charmbracelet/x/json v0.2.0 h1:DqB+ZGx2h+Z+1s98HOuOyli+i97wsFQIxP2ZQANTPrQ=
|
||||
@@ -185,14 +187,14 @@ github.com/indaco/herald v0.13.0 h1:+xVG9Fx5NpuWhwku/9IlRL6I009NnX4VUGKvlZHTRxU=
|
||||
github.com/indaco/herald v0.13.0/go.mod h1:T5g1+XLYvpjouhzAGHnAHDCKizhESkoV6+QPZ3DhgWA=
|
||||
github.com/indaco/herald-md v0.3.0 h1:hN1cKyrexPPM9PeHBsKuaWvIizSi/iYvM9yzRgtdb8M=
|
||||
github.com/indaco/herald-md v0.3.0/go.mod h1:RUHVaDSG45ymJjKyxpDwBocLXrZo93FB4OeYMsw9B9s=
|
||||
github.com/kaptinlin/go-i18n v0.3.1 h1:plXi3XQE1aYamFi8TU0K6actODmw2+5FSobmhTkfQ/0=
|
||||
github.com/kaptinlin/go-i18n v0.3.1/go.mod h1:ZRoAHj7elWYamfbv7wev7Ajch6LOzjtBaq8nWe8HIVk=
|
||||
github.com/kaptinlin/go-i18n v0.4.0 h1:i7L3U2yurg+xhokITtJ0k+mjHnXqkoyz8ju5Wb7W8Oc=
|
||||
github.com/kaptinlin/go-i18n v0.4.0/go.mod h1:njA6x0+4MWGcLWT0KLrwekhRPmze1Hnstf2+VJFzwpM=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17 h1:mY9k8ciWncxbsECyaxKnR0MdmxamNdp2tLQkAKVrtSk=
|
||||
github.com/kaptinlin/jsonpointer v0.4.17/go.mod h1:SsfsjqnHG5zuKo1DTBzk1VknaHlL4osHw+X9kZKukpU=
|
||||
github.com/kaptinlin/jsonschema v0.7.7 h1:41BlQJ9dskH0oE5DSzBUrl/w4JQYIr6N6L0B5GNyDoM=
|
||||
github.com/kaptinlin/jsonschema v0.7.7/go.mod h1:rKjWfyySHSxAD7Li2ctYkPlOu960igoKBvZ2ADRtd5Q=
|
||||
github.com/kaptinlin/messageformat-go v0.4.19 h1:A5kuuZ1ybXDQ7kD1aoEWGAOemX7hLsMY0yolgSbgpRI=
|
||||
github.com/kaptinlin/messageformat-go v0.4.19/go.mod h1:utSDTfiXTxl66OC5RIEuObLH7Ue3YjbA2X86SYMBYWg=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20 h1:a0ufTd5liiUubIGeGxpSTnNS8ZSrN4DV01/wGFmfzMs=
|
||||
github.com/kaptinlin/messageformat-go v0.4.20/go.mod h1:FqdEPfQLkqVBX7OBRMPgYwUPvKYJohFD9Ok1BMzCfIo=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
@@ -201,8 +203,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0 h1:UtrWVfLdarDgc44HcS7pYloGHJUjHV/4FwW4TvVgFr4=
|
||||
github.com/lucasb-eyer/go-colorful v1.4.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mark3labs/mcp-go v0.47.1 h1:A9sJJ20mscl/ssLYHjodfaoBmq6uuhMG7pAPNYaQymQ=
|
||||
github.com/mark3labs/mcp-go v0.47.1/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mark3labs/mcp-go v0.48.0 h1:o+MXuGW/HCeR2ny5LcAcZQn2bo6I2xaZMEHnpRG+dtw=
|
||||
github.com/mark3labs/mcp-go v0.48.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.21 h1:xYae+lCNBP7QuW4PUnNG61ffM4hVIfm+zUzDuSzYLGs=
|
||||
github.com/mattn/go-isatty v0.0.21/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||
github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw=
|
||||
@@ -288,36 +290,36 @@ go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09
|
||||
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
|
||||
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90 h1:jiDhWWeC7jfWqR9c/uplMOqJ0sbNlNWv0UkzE0vX1MA=
|
||||
golang.org/x/exp v0.0.0-20260312153236-7ab1446f8b90/go.mod h1:xE1HEv6b+1SCZ5/uscMRjUBKtIxworgEcEi+/n9NQDQ=
|
||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
|
||||
golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI=
|
||||
golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM=
|
||||
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80=
|
||||
golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA=
|
||||
golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||
golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI=
|
||||
golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||
golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU=
|
||||
golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A=
|
||||
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||
golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY=
|
||||
golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY=
|
||||
golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg=
|
||||
golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164=
|
||||
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||
gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E=
|
||||
google.golang.org/api v0.275.0 h1:vfY5d9vFVJeWEZT65QDd9hbndr7FyZ2+6mIzGAh71NI=
|
||||
google.golang.org/api v0.275.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw=
|
||||
google.golang.org/genai v1.52.1 h1:dYoljKtLDXMiBdVaClSJ/ZPwZ7j1N0lGjMhwOKOQUlk=
|
||||
google.golang.org/genai v1.52.1/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0=
|
||||
google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d h1:wT2n40TBqFY6wiwazVK9/iTWbsQrgk5ZfCSVFLO9LQA=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/genai v1.54.0 h1:ZQCa70WMTJDI11FdqWCzGvZ5PanpcpfoO6jl/lrSnGU=
|
||||
google.golang.org/genai v1.54.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d h1:N1Ec54vZnIPd7MnxRiYLW+oY4fDR4BOS/LrssdD9+ek=
|
||||
google.golang.org/genproto v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:c2hJ1grtnH0xUiEKGDGkjGNTJ1Hy2LrblyKOHF0sqRM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d h1:/aDRtSZJjyLQzm75d+a1wOJaqyKBMvIAfeQmoa3ORiI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20260406210006-6f92a3bedf2d/go.mod h1:etfGUgejTiadZAUaEP14NP97xi1RGeawqkjDARA/UOs=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478 h1:RmoJA1ujG+/lRGNfUnOMfhCy5EipVMyvUE+KNbPbTlw=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20260414002931-afd174a4e478/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||
google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM=
|
||||
google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
|
||||
+87
-3
@@ -30,6 +30,11 @@ type AgentConfig struct {
|
||||
// If nil, remote MCP servers that require OAuth will fail to connect.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used. This allows SDK users to provide a custom tool set (e.g.
|
||||
// CodingTools or tools with a custom WorkDir).
|
||||
@@ -130,6 +135,12 @@ type Agent struct {
|
||||
skipMaxOutputTokens bool
|
||||
modelConfig *models.ProviderConfig
|
||||
|
||||
// authHandler and tokenStoreFactory are stored from AgentConfig so that
|
||||
// AddMCPServer() can propagate them when creating a new MCPToolManager
|
||||
// at runtime (i.e. when no MCP servers were configured at init time).
|
||||
authHandler tools.MCPAuthHandler
|
||||
tokenStoreFactory tools.TokenStoreFactory
|
||||
|
||||
// mcpReady is closed when background MCP tool loading completes (success
|
||||
// or failure). nil when no MCP servers are configured.
|
||||
mcpReady chan struct{}
|
||||
@@ -226,6 +237,8 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
providerOptions: providerResult.ProviderOptions,
|
||||
skipMaxOutputTokens: providerResult.SkipMaxOutputTokens,
|
||||
modelConfig: agentConfig.ModelConfig,
|
||||
authHandler: agentConfig.AuthHandler,
|
||||
tokenStoreFactory: agentConfig.TokenStoreFactory,
|
||||
}
|
||||
|
||||
// Start MCP tool loading in the background if servers are configured.
|
||||
@@ -236,6 +249,9 @@ func NewAgent(ctx context.Context, agentConfig *AgentConfig) (*Agent, error) {
|
||||
if agentConfig.AuthHandler != nil {
|
||||
toolManager.SetAuthHandler(agentConfig.AuthHandler)
|
||||
}
|
||||
if agentConfig.TokenStoreFactory != nil {
|
||||
toolManager.SetTokenStoreFactory(agentConfig.TokenStoreFactory)
|
||||
}
|
||||
if agentConfig.DebugLogger != nil {
|
||||
toolManager.SetDebugLogger(agentConfig.DebugLogger)
|
||||
}
|
||||
@@ -826,6 +842,65 @@ func (a *Agent) SetExtraTools(extraTools []fantasy.AgentTool) {
|
||||
a.rebuildFantasyAgent()
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent. Returns the number of tools loaded.
|
||||
// If the agent has no tool manager (no MCP servers were configured at init),
|
||||
// one is created automatically.
|
||||
func (a *Agent) AddMCPServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
if a.toolManager == nil {
|
||||
a.toolManager = tools.NewMCPToolManager()
|
||||
a.toolManager.SetModel(a.model)
|
||||
if a.authHandler != nil {
|
||||
a.toolManager.SetAuthHandler(a.authHandler)
|
||||
}
|
||||
if a.tokenStoreFactory != nil {
|
||||
a.toolManager.SetTokenStoreFactory(a.tokenStoreFactory)
|
||||
}
|
||||
a.toolManager.SetOnToolsChanged(func() {
|
||||
a.rebuildFantasyAgent()
|
||||
})
|
||||
}
|
||||
|
||||
count, err := a.toolManager.AddServer(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// AddServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but only if it was wired. Ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes its tools from the agent.
|
||||
func (a *Agent) RemoveMCPServer(name string) error {
|
||||
if a.toolManager == nil {
|
||||
return fmt.Errorf("no MCP servers loaded")
|
||||
}
|
||||
|
||||
// Ensure MCP tools from initial load are settled first.
|
||||
a.ensureMCPTools()
|
||||
|
||||
err := a.toolManager.RemoveServer(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveServer's onToolsChanged callback triggers rebuildFantasyAgent,
|
||||
// but ensure rebuild happens regardless.
|
||||
a.rebuildFantasyAgent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMCPToolManager returns the underlying MCP tool manager.
|
||||
// Returns nil if no MCP servers have been configured.
|
||||
func (a *Agent) GetMCPToolManager() *tools.MCPToolManager {
|
||||
return a.toolManager
|
||||
}
|
||||
|
||||
// GetLoadingMessage returns the loading message from provider creation.
|
||||
func (a *Agent) GetLoadingMessage() string {
|
||||
return a.loadingMessage
|
||||
@@ -839,9 +914,11 @@ func (a *Agent) GetLoadedServerNames() []string {
|
||||
return a.toolManager.GetLoadedServerNames()
|
||||
}
|
||||
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools,
|
||||
// system prompt, and configuration are preserved. The old provider is closed
|
||||
// if it has a closer. Returns the previous model string for notification.
|
||||
// SetModel swaps the agent's LLM provider to a new model. The existing tools
|
||||
// and configuration are preserved. When the new model's ProviderConfig carries
|
||||
// a system prompt (from per-model settings), it replaces the agent's stored
|
||||
// prompt so the rebuilt fantasy agent uses it. The old provider is closed if
|
||||
// it has a closer.
|
||||
func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) error {
|
||||
// Ensure MCP tools are loaded before rebuilding (SetModel may be called
|
||||
// before the first LLM call).
|
||||
@@ -868,6 +945,13 @@ func (a *Agent) SetModel(ctx context.Context, config *models.ProviderConfig) err
|
||||
a.skipMaxOutputTokens = providerResult.SkipMaxOutputTokens
|
||||
a.modelConfig = config
|
||||
|
||||
// Update system prompt when the config carries one (from per-model
|
||||
// settings or the global config). This allows model-specific system
|
||||
// prompts to take effect on model switch.
|
||||
if config.SystemPrompt != "" {
|
||||
a.systemPrompt = config.SystemPrompt
|
||||
}
|
||||
|
||||
// Update provider type.
|
||||
if config.ModelString != "" {
|
||||
if p, _, err := models.ParseModelString(config.ModelString); err == nil {
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// mockModel is a minimal LanguageModel that satisfies the interface
|
||||
// without making real API calls. Used to test tool management wiring.
|
||||
type mockModel struct{}
|
||||
|
||||
func (m *mockModel) Generate(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{}, nil
|
||||
}
|
||||
func (m *mockModel) Stream(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) GenerateObject(_ context.Context, _ fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
return &fantasy.ObjectResponse{}, nil
|
||||
}
|
||||
func (m *mockModel) StreamObject(_ context.Context, _ fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockModel) Provider() string { return "mock" }
|
||||
func (m *mockModel) Model() string { return "mock-model" }
|
||||
|
||||
// testdataDir returns the absolute path to the tools testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "..", "tools", "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// mockAuthHandler is a minimal MCPAuthHandler for testing that auth handler
|
||||
// propagation works without requiring a real OAuth server.
|
||||
type mockAuthHandler struct {
|
||||
redirectURI string
|
||||
}
|
||||
|
||||
func (h *mockAuthHandler) RedirectURI() string { return h.redirectURI }
|
||||
func (h *mockAuthHandler) HandleAuth(_ context.Context, _ string, _ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// newTestAgent creates a minimal Agent with a mock model and no core tools,
|
||||
// suitable for testing MCP server management without an API key.
|
||||
func newTestAgent() *Agent {
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Initially no MCP tools.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Fatalf("Expected 0 MCP tools initially, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Add a server.
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
// Verify tools are in the agent's tool list.
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
allTools := a.GetTools()
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range allTools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo' in agent tools")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet' in agent tools")
|
||||
}
|
||||
|
||||
// Verify loaded server names.
|
||||
names := a.GetLoadedServerNames()
|
||||
found := false
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected 'echo' in loaded server names: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add then remove.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools removed.
|
||||
if a.GetMCPToolCount() != 0 {
|
||||
t.Errorf("Expected 0 MCP tools after removal, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
|
||||
// Verify agent's tool list has no MCP tools.
|
||||
for _, tool := range a.GetTools() {
|
||||
if strings.Contains(tool.Info().Name, "echo__") {
|
||||
t.Errorf("Found leftover tool after removal: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_RemoveMCPServer_NoToolManager(t *testing.T) {
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
err := a.RemoveMCPServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when no tool manager exists")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "no MCP servers loaded") {
|
||||
t.Errorf("Expected 'no MCP servers loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddMCPServer_CreatesToolManager(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist.
|
||||
if a.GetMCPToolManager() == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgent_AddRemoveAdd_MCP(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
a := newTestAgent()
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add → Remove → Add cycle.
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First add failed: %v", err)
|
||||
}
|
||||
|
||||
err = a.RemoveMCPServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("Remove failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-add failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
if a.GetMCPToolCount() != 2 {
|
||||
t.Errorf("Expected 2 MCP tools after re-add, got %d", a.GetMCPToolCount())
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgent_AddMCPServer_InheritsAuthHandler verifies that AddMCPServer()
|
||||
// propagates the agent's authHandler and tokenStoreFactory to a newly created
|
||||
// MCPToolManager (fix for issue #3).
|
||||
func TestAgent_AddMCPServer_InheritsAuthHandler(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
handler := &mockAuthHandler{redirectURI: "http://localhost:9999/oauth/callback"}
|
||||
|
||||
model := &mockModel{}
|
||||
a := &Agent{
|
||||
model: model,
|
||||
coreTools: nil,
|
||||
extraTools: nil,
|
||||
maxSteps: 10,
|
||||
systemPrompt: "test",
|
||||
fantasyAgent: fantasy.NewAgent(model),
|
||||
authHandler: handler,
|
||||
tokenStoreFactory: nil, // nil is fine; we just test authHandler propagation
|
||||
}
|
||||
defer func() { _ = a.Close() }()
|
||||
|
||||
// Initially no tool manager.
|
||||
if a.GetMCPToolManager() != nil {
|
||||
t.Fatal("Expected nil tool manager initially")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
_, err := a.AddMCPServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddMCPServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Tool manager should now exist and have the auth handler set.
|
||||
tm := a.GetMCPToolManager()
|
||||
if tm == nil {
|
||||
t.Fatal("Expected tool manager to be created by AddMCPServer")
|
||||
}
|
||||
|
||||
// Verify the auth handler was propagated by checking the field directly.
|
||||
if tm.GetAuthHandler() == nil {
|
||||
t.Fatal("Expected auth handler to be propagated to tool manager")
|
||||
}
|
||||
}
|
||||
@@ -38,6 +38,10 @@ type AgentCreationOptions struct {
|
||||
DebugLogger tools.DebugLogger // Optional debug logger
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// CoreTools overrides the default core tool set. If empty, core.AllTools()
|
||||
// is used.
|
||||
CoreTools []fantasy.AgentTool
|
||||
@@ -66,6 +70,7 @@ func CreateAgent(ctx context.Context, opts *AgentCreationOptions) (*Agent, error
|
||||
StreamingEnabled: opts.StreamingEnabled,
|
||||
DebugLogger: opts.DebugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: opts.ToolWrapper,
|
||||
|
||||
+44
-18
@@ -930,7 +930,8 @@ func (a *App) QuitFromExtension() {
|
||||
// controls styling: "" for plain text, "info" for a system message block,
|
||||
// "error" for an error block. In interactive mode it sends an
|
||||
// ExtensionPrintEvent through the program so the TUI can render it with the
|
||||
// appropriate renderer. In non-interactive mode it falls back to stdout.
|
||||
// appropriate renderer. In non-interactive mode it falls back to stderr with
|
||||
// a level prefix so errors are distinguishable from plain output.
|
||||
func (a *App) PrintFromExtension(level, text string) {
|
||||
a.mu.Lock()
|
||||
prog := a.program
|
||||
@@ -939,8 +940,16 @@ func (a *App) PrintFromExtension(level, text string) {
|
||||
prog.Send(ExtensionPrintEvent{Text: text, Level: level})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback: write directly to stdout.
|
||||
fmt.Println(text)
|
||||
// Non-interactive fallback: write to stderr with a level prefix so that
|
||||
// errors and info messages are distinguishable from plain output.
|
||||
switch level {
|
||||
case "error":
|
||||
fmt.Fprintf(os.Stderr, "[ERROR] %s\n", text)
|
||||
case "info":
|
||||
fmt.Fprintf(os.Stderr, "[INFO] %s\n", text)
|
||||
default:
|
||||
fmt.Println(text)
|
||||
}
|
||||
}
|
||||
|
||||
// SetEditorTextFromExtension sends an EditorTextSetEvent to the TUI to
|
||||
@@ -1122,11 +1131,12 @@ func (a *App) PrintBlockFromExtension(opts extensions.PrintBlockOpts) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Non-interactive fallback.
|
||||
// Non-interactive fallback: render a simple framed block to stderr so
|
||||
// it is visually distinct from plain stdout output.
|
||||
if opts.Subtitle != "" {
|
||||
fmt.Printf("%s\n — %s\n", opts.Text, opts.Subtitle)
|
||||
fmt.Fprintf(os.Stderr, "--- %s ---\n%s\n", opts.Subtitle, opts.Text)
|
||||
} else {
|
||||
fmt.Println(opts.Text)
|
||||
fmt.Fprintf(os.Stderr, "---\n%s\n---\n", opts.Text)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1155,9 +1165,10 @@ func (a *App) recordStepUsage(ev kit.StepUsageEvent, stepUsageSeen *atomic.Bool)
|
||||
int(ev.CacheWriteTokens),
|
||||
)
|
||||
// NOTE: We do NOT call SetContextTokens here. Context fill is set once
|
||||
// at turn completion via updateUsageFromTurnResult using FinalUsage.InputTokens,
|
||||
// which reflects the full accumulated context. Per-step context tokens would
|
||||
// cause the display to jump around during multi-step tool calls.
|
||||
// at turn completion via updateUsageFromTurnResult, which sums all token
|
||||
// categories (Input + CacheRead + CacheCreate + Output) from FinalUsage.
|
||||
// Per-step context tokens would cause the display to jump around during
|
||||
// multi-step tool calls.
|
||||
}
|
||||
|
||||
// updateUsageFromTurnResult records token usage from an SDK TurnResult into the
|
||||
@@ -1221,15 +1232,30 @@ func (a *App) updateUsageFromTurnResult(result *kit.TurnResult, userPrompt strin
|
||||
}
|
||||
|
||||
// --- Context window fill (drives the % bar) ---
|
||||
// Use FinalUsage.InputTokens as the context window fill. The API's InputTokens
|
||||
// already includes the full conversation history (system prompt + all previous
|
||||
// messages + current user message). Adding OutputTokens would double-count since
|
||||
// the output becomes part of the input for the next turn.
|
||||
if result.FinalUsage != nil && result.FinalUsage.InputTokens > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: calling SetContextTokens=%d (FinalUsage.InputTokens)",
|
||||
result.FinalUsage.InputTokens)
|
||||
// Calculate context fill from the LAST API call's usage. The context
|
||||
// window is filled by everything sent to and received from the model:
|
||||
//
|
||||
// InputTokens — non-cached input (may be small with prompt caching)
|
||||
// CacheReadTokens — input tokens served from cache
|
||||
// CacheCreationTokens — input tokens written to cache this call
|
||||
// OutputTokens — assistant output (becomes input next turn)
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can drop to near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. We must sum all four to
|
||||
// get the true context window utilization.
|
||||
//
|
||||
// We use FinalUsage (last step only), NOT TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps — and each step re-sends the full
|
||||
// conversation, so TotalUsage massively overstates the actual window fill.
|
||||
if result.FinalUsage != nil {
|
||||
u := result.FinalUsage
|
||||
contextFill := int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
if contextFill > 0 {
|
||||
if a.opts.Debug {
|
||||
log.Printf("[DEBUG] updateUsageFromTurnResult: SetContextTokens=%d (Input=%d + CacheRead=%d + CacheCreate=%d + Output=%d)",
|
||||
contextFill, u.InputTokens, u.CacheReadTokens, u.CacheCreationTokens, u.OutputTokens)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(contextFill)
|
||||
}
|
||||
a.opts.UsageTracker.SetContextTokens(int(result.FinalUsage.InputTokens))
|
||||
}
|
||||
}
|
||||
|
||||
+19
-13
@@ -630,10 +630,12 @@ func TestUpdateUsageFromTurnResult_recordsWhenInputTokensZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly verifies that context
|
||||
// window fill uses InputTokens only (not input+output). The API's InputTokens
|
||||
// already includes the full conversation history; adding output would double-count.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
// TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories verifies that
|
||||
// context window fill uses all token categories from the final API call:
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context.
|
||||
func TestUpdateUsageFromTurnResult_contextTokensUsesAllCategories(t *testing.T) {
|
||||
usage := &usageUpdaterStub{}
|
||||
app := New(Options{UsageTracker: usage}, nil)
|
||||
defer app.Close()
|
||||
@@ -641,22 +643,26 @@ func TestUpdateUsageFromTurnResult_contextTokensUsesInputOnly(t *testing.T) {
|
||||
app.updateUsageFromTurnResult(&kit.TurnResult{
|
||||
Response: "ok",
|
||||
TotalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 200,
|
||||
InputTokens: 3,
|
||||
OutputTokens: 5,
|
||||
CacheReadTokens: 0,
|
||||
CacheCreationTokens: 4317,
|
||||
},
|
||||
FinalUsage: &kit.LLMUsage{
|
||||
InputTokens: 1000, // Full context including history
|
||||
OutputTokens: 200,
|
||||
InputTokens: 3, // Non-cached input (small with caching)
|
||||
OutputTokens: 5, // Assistant output
|
||||
CacheReadTokens: 0, // No cache reads on first call
|
||||
CacheCreationTokens: 4317, // System prompt + tools written to cache
|
||||
},
|
||||
}, "prompt", false)
|
||||
|
||||
usage.mu.Lock()
|
||||
defer usage.mu.Unlock()
|
||||
|
||||
// Context tokens should be InputTokens only (1000), not input+output (1200)
|
||||
// because InputTokens already includes the full conversation history
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != 1000 {
|
||||
t.Fatalf("expected context tokens=1000 (InputTokens only), got calls=%d tokens=%d",
|
||||
usage.contextCalls, usage.lastContextTokens)
|
||||
// Context tokens should be Input + CacheRead + CacheCreate + Output = 4325
|
||||
expected := 3 + 0 + 4317 + 5
|
||||
if usage.contextCalls != 1 || usage.lastContextTokens != expected {
|
||||
t.Fatalf("expected context tokens=%d (all categories), got calls=%d tokens=%d",
|
||||
expected, usage.contextCalls, usage.lastContextTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,8 +21,10 @@ type UsageUpdater interface {
|
||||
// the provider does not return exact counts.
|
||||
EstimateAndUpdateUsage(inputText, outputText string)
|
||||
// SetContextTokens records the approximate current context window fill
|
||||
// level. This should be the final API call's input+output tokens (from
|
||||
// FinalResponse.Usage), NOT the aggregate TotalUsage.
|
||||
// level. This should be the sum of ALL token categories from the last
|
||||
// API call: InputTokens + CacheReadTokens + CacheCreationTokens +
|
||||
// OutputTokens. With Anthropic prompt caching, InputTokens can be
|
||||
// near-zero while CacheReadTokens holds the bulk of the context.
|
||||
SetContextTokens(tokens int)
|
||||
}
|
||||
|
||||
|
||||
@@ -22,6 +22,14 @@ type MCPServerConfig struct {
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
|
||||
// OAuth configuration for remote servers that don't support dynamic
|
||||
// client registration (e.g. GitHub). When OAuthClientID is set, it is
|
||||
// passed directly to the transport's OAuthConfig instead of relying on
|
||||
// dynamic registration.
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
|
||||
// Legacy fields for backward compatibility
|
||||
Transport string `json:"transport,omitempty"`
|
||||
Args []string `json:"args,omitempty"`
|
||||
@@ -35,13 +43,16 @@ type MCPServerConfig struct {
|
||||
func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
// First try to unmarshal as the new format
|
||||
type newFormat struct {
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Command []string `json:"command,omitempty"`
|
||||
Environment map[string]string `json:"environment,omitempty"`
|
||||
URL string `json:"url,omitempty"`
|
||||
Headers []string `json:"headers,omitempty"`
|
||||
AllowedTools []string `json:"allowedTools,omitempty" yaml:"allowedTools,omitempty"`
|
||||
ExcludedTools []string `json:"excludedTools,omitempty" yaml:"excludedTools,omitempty"`
|
||||
OAuthClientID string `json:"oauthClientId,omitempty" yaml:"oauthClientId,omitempty"`
|
||||
OAuthClientSecret string `json:"oauthClientSecret,omitempty" yaml:"oauthClientSecret,omitempty"`
|
||||
OAuthScopes []string `json:"oauthScopes,omitempty" yaml:"oauthScopes,omitempty"`
|
||||
}
|
||||
|
||||
// Also try legacy format
|
||||
@@ -66,6 +77,9 @@ func (s *MCPServerConfig) UnmarshalJSON(data []byte) error {
|
||||
s.Headers = newConfig.Headers
|
||||
s.AllowedTools = newConfig.AllowedTools
|
||||
s.ExcludedTools = newConfig.ExcludedTools
|
||||
s.OAuthClientID = newConfig.OAuthClientID
|
||||
s.OAuthClientSecret = newConfig.OAuthClientSecret
|
||||
s.OAuthScopes = newConfig.OAuthScopes
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -157,6 +171,21 @@ type Theme struct {
|
||||
Markdown MarkdownThemeConfig `json:"markdown,omitzero" yaml:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParams defines generation parameter defaults that can be attached
|
||||
// to individual models. These act as model-level defaults — CLI flags and
|
||||
// global config values take precedence when explicitly set.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model that can be used with custom/custom
|
||||
// or other custom/ prefixed models. These models are loaded from the config file
|
||||
// and merged into the custom provider in the model registry.
|
||||
@@ -171,6 +200,11 @@ type CustomModelConfig struct {
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
|
||||
// Generation parameter defaults for this model.
|
||||
// These are applied when the user hasn't explicitly set the corresponding
|
||||
// CLI flag or global config value.
|
||||
Params GenerationParams `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
@@ -219,6 +253,12 @@ type Config struct {
|
||||
|
||||
// Custom model definitions (under custom/ provider)
|
||||
CustomModels map[string]CustomModelConfig `json:"customModels,omitempty" yaml:"customModels,omitempty"`
|
||||
|
||||
// Per-model generation parameter overrides. Keys are "provider/model" strings
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929", "openai/gpt-4o"). These
|
||||
// settings act as model-level defaults — CLI flags and global config values
|
||||
// take precedence when explicitly set.
|
||||
ModelSettings map[string]GenerationParams `json:"modelSettings,omitempty" yaml:"modelSettings,omitempty"`
|
||||
}
|
||||
|
||||
// GetTransportType returns the transport type for the server config, mapping
|
||||
@@ -367,7 +407,7 @@ mcpServers:
|
||||
# debug: false # Enable debug logging
|
||||
# system-prompt: "/path/to/system-prompt.txt" # System prompt text file
|
||||
|
||||
# Model generation parameters (all optional)
|
||||
# Model generation parameters (all optional, apply globally to all models)
|
||||
# max-tokens: 4096 # Maximum tokens in response
|
||||
# temperature: 0.7 # Randomness (0.0-1.0)
|
||||
# top-p: 0.95 # Nucleus sampling (0.0-1.0)
|
||||
@@ -376,9 +416,46 @@ mcpServers:
|
||||
# presence-penalty: 0.0 # Penalize present tokens (0.0-2.0)
|
||||
# stop-sequences: ["Human:", "Assistant:"] # Custom stop sequences
|
||||
|
||||
# Per-model generation parameter overrides (apply to specific models)
|
||||
# These act as model-level defaults — CLI flags and global settings above take precedence.
|
||||
# Keys are "provider/model" strings matching the model you use.
|
||||
# modelSettings:
|
||||
# anthropic/claude-sonnet-4-5-20250929:
|
||||
# temperature: 0.3
|
||||
# maxTokens: 8192
|
||||
# openai/gpt-4o:
|
||||
# temperature: 0.7
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# frequencyPenalty: 0.1
|
||||
# presencePenalty: 0.1
|
||||
# anthropic/claude-opus-4-6:
|
||||
# thinkingLevel: "high"
|
||||
# maxTokens: 16384
|
||||
# systemPrompt: "You are a deep reasoning assistant." # or a file path
|
||||
|
||||
# API Configuration (can also use environment variables)
|
||||
# provider-api-key: "your-api-key" # API key for OpenAI, Anthropic, or Google
|
||||
# provider-url: "https://api.openai.com/v1" # Base URL for OpenAI, Anthropic, or Ollama
|
||||
|
||||
# Custom model definitions (under custom/ provider)
|
||||
# customModels:
|
||||
# my-local-llama:
|
||||
# name: "Local Llama 3"
|
||||
# baseUrl: "http://localhost:8080/v1"
|
||||
# family: "llama"
|
||||
# temperature: true
|
||||
# cost:
|
||||
# input: 0.0
|
||||
# output: 0.0
|
||||
# limit:
|
||||
# context: 131072
|
||||
# output: 8192
|
||||
# params: # Generation parameter defaults for this model
|
||||
# temperature: 0.8
|
||||
# topP: 0.95
|
||||
# topK: 40
|
||||
# systemPrompt: "You are a helpful local assistant."
|
||||
`
|
||||
|
||||
_, err = file.WriteString(content)
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestMCPServerConfig_NewFormat(t *testing.T) {
|
||||
@@ -542,3 +544,86 @@ func TestEnsureConfigExistsWhenFileExists(t *testing.T) {
|
||||
t.Error("Existing config file was modified when it shouldn't have been")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_JSON(t *testing.T) {
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://api.githubcopilot.com/mcp/",
|
||||
"oauthClientId": "Ov23liXXXXXXXXXXXXXX",
|
||||
"oauthClientSecret": "secret123",
|
||||
"oauthScopes": ["read:user", "repo"]
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.URL != "https://api.githubcopilot.com/mcp/" {
|
||||
t.Errorf("Expected URL, got %q", cfg.URL)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "secret123" {
|
||||
t.Errorf("Expected OAuthClientSecret 'secret123', got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_YAML(t *testing.T) {
|
||||
yamlData := `
|
||||
type: remote
|
||||
url: https://api.githubcopilot.com/mcp/
|
||||
oauthClientId: "Ov23liXXXXXXXXXXXXXX"
|
||||
oauthScopes:
|
||||
- read:user
|
||||
- repo
|
||||
`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := yaml.Unmarshal([]byte(yamlData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal YAML: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Type != "remote" {
|
||||
t.Errorf("Expected type 'remote', got %q", cfg.Type)
|
||||
}
|
||||
if cfg.OAuthClientID != "Ov23liXXXXXXXXXXXXXX" {
|
||||
t.Errorf("Expected OAuthClientID 'Ov23liXXXXXXXXXXXXXX', got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 2 || cfg.OAuthScopes[0] != "read:user" || cfg.OAuthScopes[1] != "repo" {
|
||||
t.Errorf("Expected OAuthScopes [read:user, repo], got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMCPServerConfig_OAuthFields_Omitted(t *testing.T) {
|
||||
// Verify that omitting OAuth fields still works (backward compat).
|
||||
jsonData := `{
|
||||
"type": "remote",
|
||||
"url": "https://example.com/mcp"
|
||||
}`
|
||||
|
||||
var cfg MCPServerConfig
|
||||
err := json.Unmarshal([]byte(jsonData), &cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
if cfg.OAuthClientID != "" {
|
||||
t.Errorf("Expected empty OAuthClientID, got %q", cfg.OAuthClientID)
|
||||
}
|
||||
if cfg.OAuthClientSecret != "" {
|
||||
t.Errorf("Expected empty OAuthClientSecret, got %q", cfg.OAuthClientSecret)
|
||||
}
|
||||
if len(cfg.OAuthScopes) != 0 {
|
||||
t.Errorf("Expected empty OAuthScopes, got %v", cfg.OAuthScopes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
// Package fences provides utilities for detecting markdown code regions
|
||||
// (fenced code blocks and inline code spans) and applying transformations
|
||||
// only to text outside those regions.
|
||||
//
|
||||
// This prevents special tokens like $1, $@, or @file from being interpreted
|
||||
// when they appear inside ``` fences, ~~~ fences, or `inline` code spans.
|
||||
package fences
|
||||
|
||||
import "strings"
|
||||
|
||||
// Ranges returns byte ranges [start, end) of fenced code blocks in content.
|
||||
// Recognises both backtick (```) and tilde (~~~) fences, with optional
|
||||
// leading indentation (up to 3 spaces) and optional info strings.
|
||||
// An unclosed fence extends to the end of content.
|
||||
func Ranges(content string) [][2]int {
|
||||
var result [][2]int
|
||||
var inFence bool
|
||||
var fenceChar byte
|
||||
var fenceCount int
|
||||
var fenceStart int
|
||||
|
||||
pos := 0
|
||||
for pos < len(content) {
|
||||
// Find the end of the current line.
|
||||
lineEnd := strings.IndexByte(content[pos:], '\n')
|
||||
var line string
|
||||
var nextPos int
|
||||
if lineEnd < 0 {
|
||||
line = content[pos:]
|
||||
nextPos = len(content)
|
||||
} else {
|
||||
line = content[pos : pos+lineEnd]
|
||||
nextPos = pos + lineEnd + 1
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(line, " ")
|
||||
indent := len(line) - len(trimmed)
|
||||
|
||||
if !inFence {
|
||||
if indent <= 3 {
|
||||
if ch, n := parseFenceOpen(trimmed); n > 0 {
|
||||
inFence = true
|
||||
fenceChar = ch
|
||||
fenceCount = n
|
||||
fenceStart = pos
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if indent <= 3 && isFenceClose(trimmed, fenceChar, fenceCount) {
|
||||
result = append(result, [2]int{fenceStart, nextPos})
|
||||
inFence = false
|
||||
}
|
||||
}
|
||||
|
||||
pos = nextPos
|
||||
}
|
||||
|
||||
// Unclosed fence extends to end of content.
|
||||
if inFence {
|
||||
result = append(result, [2]int{fenceStart, len(content)})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// ReplaceOutside applies fn to each text segment that is outside fenced code
|
||||
// blocks and inline code spans, leaving code content unchanged. This is the
|
||||
// primary entry point for callers that need to do regex replacement only on
|
||||
// non-code text.
|
||||
func ReplaceOutside(content string, fn func(string) string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return replaceOutsideInline(content, fn)
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
// Within non-fenced segments, also skip inline code spans.
|
||||
b.WriteString(replaceOutsideInline(content[pos:r[0]], fn))
|
||||
}
|
||||
// Preserve fenced content verbatim.
|
||||
b.WriteString(content[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(content) {
|
||||
b.WriteString(replaceOutsideInline(content[pos:], fn))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// StripCode returns content with fenced code blocks and inline code spans
|
||||
// removed. Useful for detection/matching where only non-code text matters.
|
||||
func StripCode(content string) string {
|
||||
// First strip fenced blocks.
|
||||
stripped := StripFenced(content)
|
||||
// Then strip inline code spans from what remains.
|
||||
return stripInlineCode(stripped)
|
||||
}
|
||||
|
||||
// StripFenced returns content with fenced code block regions removed.
|
||||
// Useful for detection/matching where only non-fenced text matters.
|
||||
// NOTE: this does NOT strip inline code spans; use StripCode for both.
|
||||
func StripFenced(content string) string {
|
||||
ranges := Ranges(content)
|
||||
if len(ranges) == 0 {
|
||||
return content
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(content))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(content[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(content[pos:])
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// parseFenceOpen checks whether trimmed (leading spaces already removed)
|
||||
// starts a fenced code block. Returns the fence character and count, or
|
||||
// (0, 0) if it is not a fence opener.
|
||||
func parseFenceOpen(trimmed string) (byte, int) {
|
||||
if len(trimmed) == 0 {
|
||||
return 0, 0
|
||||
}
|
||||
ch := trimmed[0]
|
||||
if ch != '`' && ch != '~' {
|
||||
return 0, 0
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == ch {
|
||||
count++
|
||||
}
|
||||
if count < 3 {
|
||||
return 0, 0
|
||||
}
|
||||
// Per CommonMark: backtick fences cannot have backticks in the info string.
|
||||
if ch == '`' && strings.ContainsRune(trimmed[count:], '`') {
|
||||
return 0, 0
|
||||
}
|
||||
return ch, count
|
||||
}
|
||||
|
||||
// isFenceClose checks whether trimmed is a closing fence matching fenceChar
|
||||
// with at least minCount characters. A closing fence line contains only the
|
||||
// fence characters and optional trailing spaces.
|
||||
func isFenceClose(trimmed string, fenceChar byte, minCount int) bool {
|
||||
if len(trimmed) == 0 || trimmed[0] != fenceChar {
|
||||
return false
|
||||
}
|
||||
count := 0
|
||||
for count < len(trimmed) && trimmed[count] == fenceChar {
|
||||
count++
|
||||
}
|
||||
if count < minCount {
|
||||
return false
|
||||
}
|
||||
// Closing fence must contain only fence chars (and optional trailing spaces).
|
||||
return strings.TrimRight(trimmed[count:], " ") == ""
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Inline code span handling
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// inlineCodeRanges returns byte ranges [start, end) of inline code spans
|
||||
// in segment. Per CommonMark, a code span opens with N backticks and closes
|
||||
// with exactly N backticks.
|
||||
func inlineCodeRanges(s string) [][2]int {
|
||||
var result [][2]int
|
||||
i := 0
|
||||
for i < len(s) {
|
||||
if s[i] != '`' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
// Count opening backticks.
|
||||
start := i
|
||||
n := 0
|
||||
for i < len(s) && s[i] == '`' {
|
||||
n++
|
||||
i++
|
||||
}
|
||||
// Scan for a closing run of exactly n backticks.
|
||||
for j := i; j < len(s); {
|
||||
if s[j] != '`' {
|
||||
j++
|
||||
continue
|
||||
}
|
||||
m := 0
|
||||
for j < len(s) && s[j] == '`' {
|
||||
m++
|
||||
j++
|
||||
}
|
||||
if m == n {
|
||||
result = append(result, [2]int{start, j})
|
||||
i = j
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no closing run was found, i is already past the opening
|
||||
// backticks so the outer loop advances naturally.
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// replaceOutsideInline applies fn only to text outside inline code spans.
|
||||
func replaceOutsideInline(segment string, fn func(string) string) string {
|
||||
ranges := inlineCodeRanges(segment)
|
||||
if len(ranges) == 0 {
|
||||
return fn(segment)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(segment))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
if pos < r[0] {
|
||||
b.WriteString(fn(segment[pos:r[0]]))
|
||||
}
|
||||
b.WriteString(segment[r[0]:r[1]])
|
||||
pos = r[1]
|
||||
}
|
||||
if pos < len(segment) {
|
||||
b.WriteString(fn(segment[pos:]))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// stripInlineCode removes inline code spans from s.
|
||||
func stripInlineCode(s string) string {
|
||||
ranges := inlineCodeRanges(s)
|
||||
if len(ranges) == 0 {
|
||||
return s
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
pos := 0
|
||||
for _, r := range ranges {
|
||||
b.WriteString(s[pos:r[0]])
|
||||
pos = r[1]
|
||||
}
|
||||
b.WriteString(s[pos:])
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package fences
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want [][2]int
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world\nno code here",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single backtick fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "single tilde fence",
|
||||
content: "before\n~~~\ncode\n~~~\nafter",
|
||||
want: [][2]int{{7, 20}},
|
||||
},
|
||||
{
|
||||
name: "fence with info string",
|
||||
content: "before\n```go\ncode\n```\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: [][2]int{{2, 12}, {14, 24}},
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\ncode\nmore code",
|
||||
want: [][2]int{{7, 25}},
|
||||
},
|
||||
{
|
||||
name: "longer closing fence",
|
||||
content: "before\n```\ncode\n`````\nafter",
|
||||
want: [][2]int{{7, 22}},
|
||||
},
|
||||
{
|
||||
name: "shorter closing fence ignored",
|
||||
content: "before\n`````\ncode\n```\nmore\n`````\nafter",
|
||||
want: [][2]int{{7, 33}},
|
||||
},
|
||||
{
|
||||
name: "indented fence up to 3 spaces",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: [][2]int{{7, 26}},
|
||||
},
|
||||
{
|
||||
name: "4 space indent is not a fence",
|
||||
content: "before\n ```\ncode\n ```\nafter",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "backtick in info string rejects open",
|
||||
// The ```foo`bar line is not a valid opener (backtick in info).
|
||||
// The standalone ``` becomes an opener with no close.
|
||||
content: "before\n```foo`bar\ncode\n```\nafter",
|
||||
want: [][2]int{{23, 32}},
|
||||
},
|
||||
{
|
||||
name: "empty content",
|
||||
content: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "fence only",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "fence at end without trailing newline",
|
||||
content: "```\ncode\n```",
|
||||
want: [][2]int{{0, 12}},
|
||||
},
|
||||
{
|
||||
name: "tilde fence does not close with backticks",
|
||||
content: "~~~\ncode\n```\nmore\n~~~\nafter",
|
||||
want: [][2]int{{0, 22}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := Ranges(tt.content)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("Ranges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("Ranges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello world",
|
||||
want: "HELLO WORLD",
|
||||
},
|
||||
{
|
||||
name: "text around fence",
|
||||
content: "before\n```\ncode\n```\nafter",
|
||||
want: "BEFORE\n```\ncode\n```\nAFTER",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "aaa\n```\nxxx\n```\nbbb\n~~~\nyyy\n~~~\nccc",
|
||||
want: "AAA\n```\nxxx\n```\nBBB\n~~~\nyyy\n~~~\nCCC",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence preserves code",
|
||||
content: "before\n```\ncode",
|
||||
want: "BEFORE\n```\ncode",
|
||||
},
|
||||
{
|
||||
name: "only fenced content",
|
||||
content: "```\ncode\n```",
|
||||
want: "```\ncode\n```",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripFenced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no fences",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips fenced code",
|
||||
content: "before $1\n```\n$2 inside\n```\nafter $3",
|
||||
want: "before $1\nafter $3",
|
||||
},
|
||||
{
|
||||
name: "multiple fences",
|
||||
content: "a\n```\nx\n```\nb\n~~~\ny\n~~~\nc",
|
||||
want: "a\nb\nc",
|
||||
},
|
||||
{
|
||||
name: "unclosed fence",
|
||||
content: "before\n```\n$1 inside",
|
||||
want: "before\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripFenced(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripFenced() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineCodeRanges(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
want [][2]int
|
||||
}{
|
||||
{"no backticks", "hello world", nil},
|
||||
{"single backtick span", "use `$1` here", [][2]int{{4, 8}}},
|
||||
{"double backtick span", "use ``$1`` here", [][2]int{{4, 10}}},
|
||||
{"multiple spans", "`$1` and `$2`", [][2]int{{0, 4}, {9, 13}}},
|
||||
{"unmatched backtick", "use `$1 here", nil},
|
||||
{"mismatched backtick counts", "use ``$1` here", nil},
|
||||
{"empty inline content", "use `` `` here", [][2]int{{4, 9}}},
|
||||
{"backticks inside double", "use ``foo`bar`` here", [][2]int{{4, 15}}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := inlineCodeRanges(tt.s)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("inlineCodeRanges() = %v, want %v", got, tt.want)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("inlineCodeRanges()[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceOutside_InlineCode(t *testing.T) {
|
||||
upper := func(s string) string {
|
||||
b := []byte(s)
|
||||
for i, c := range b {
|
||||
if c >= 'a' && c <= 'z' {
|
||||
b[i] = c - 32
|
||||
}
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "inline code preserved",
|
||||
content: "use `code` here",
|
||||
want: "USE `code` HERE",
|
||||
},
|
||||
{
|
||||
name: "double backtick inline code",
|
||||
content: "use ``co`de`` here",
|
||||
want: "USE ``co`de`` HERE",
|
||||
},
|
||||
{
|
||||
name: "mixed fenced and inline",
|
||||
content: "before `x` mid\n```\nfenced\n```\nafter `y` end",
|
||||
want: "BEFORE `x` MID\n```\nfenced\n```\nAFTER `y` END",
|
||||
},
|
||||
{
|
||||
name: "only inline code",
|
||||
content: "`code`",
|
||||
want: "`code`",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ReplaceOutside(tt.content, upper)
|
||||
if got != tt.want {
|
||||
t.Errorf("ReplaceOutside() =\n%s\nwant:\n%s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no code",
|
||||
content: "hello $1 world",
|
||||
want: "hello $1 world",
|
||||
},
|
||||
{
|
||||
name: "strips inline code",
|
||||
content: "use `$1` and `$2` for positional args",
|
||||
want: "use and for positional args",
|
||||
},
|
||||
{
|
||||
name: "strips fenced and inline",
|
||||
content: "before `$1`\n```\n$2 inside\n```\nafter",
|
||||
want: "before \nafter",
|
||||
},
|
||||
{
|
||||
name: "real world prompt template",
|
||||
content: "Use $@ for all args.\n`$1`, `$2` for positional.\n```bash\necho $1\n```\n",
|
||||
want: "Use $@ for all args.\n, for positional.\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := StripCode(tt.content)
|
||||
if got != tt.want {
|
||||
t.Errorf("StripCode() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+44
-20
@@ -65,6 +65,10 @@ type AgentSetupOptions struct {
|
||||
// AuthHandler handles OAuth authorization for remote MCP servers.
|
||||
// When set, remote transports are configured with OAuth support.
|
||||
AuthHandler tools.MCPAuthHandler
|
||||
// TokenStoreFactory, if non-nil, creates a custom token store for each
|
||||
// remote MCP server's OAuth tokens. When nil, the default file-based
|
||||
// token store is used.
|
||||
TokenStoreFactory tools.TokenStoreFactory
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading (successfully or with error). Called from the background goroutine.
|
||||
OnMCPServerLoaded func(serverName string, toolCount int, err error)
|
||||
@@ -82,36 +86,55 @@ type AgentSetupResult struct {
|
||||
|
||||
// BuildProviderConfig creates a *models.ProviderConfig from the current viper
|
||||
// state. All entry points (root, script, SDK) converge through this function.
|
||||
//
|
||||
// Generation parameter pointers (Temperature, TopP, etc.) are only set when
|
||||
// the user has explicitly configured them via CLI flag, environment variable,
|
||||
// or global config file. This allows per-model defaults from modelSettings
|
||||
// and customModels to fill in unset parameters downstream.
|
||||
func BuildProviderConfig() (*models.ProviderConfig, string, error) {
|
||||
systemPrompt, err := config.LoadSystemPrompt(viper.GetString("system-prompt"))
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("failed to load system prompt: %w", err)
|
||||
}
|
||||
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
numGPU := int32(viper.GetInt("num-gpu-layers"))
|
||||
mainGPU := int32(viper.GetInt("main-gpu"))
|
||||
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &frequencyPenalty,
|
||||
PresencePenalty: &presencePenalty,
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
ModelString: viper.GetString("model"),
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
ProviderURL: viper.GetString("provider-url"),
|
||||
MaxTokens: viper.GetInt("max-tokens"),
|
||||
StopSequences: viper.GetStringSlice("stop-sequences"),
|
||||
NumGPU: &numGPU,
|
||||
MainGPU: &mainGPU,
|
||||
TLSSkipVerify: viper.GetBool("tls-skip-verify"),
|
||||
ThinkingLevel: models.ParseThinkingLevel(viper.GetString("thinking-level")),
|
||||
}
|
||||
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
return cfg, systemPrompt, nil
|
||||
@@ -200,6 +223,7 @@ func SetupAgent(ctx context.Context, opts AgentSetupOptions) (*AgentSetupResult,
|
||||
SpinnerFunc: opts.SpinnerFunc,
|
||||
DebugLogger: debugLogger,
|
||||
AuthHandler: opts.AuthHandler,
|
||||
TokenStoreFactory: opts.TokenStoreFactory,
|
||||
CoreTools: opts.CoreTools,
|
||||
DisableCoreTools: opts.DisableCoreTools,
|
||||
ToolWrapper: toolWrapper,
|
||||
|
||||
+234
-11
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -31,7 +33,7 @@ func loadCustomModelsFromConfig() map[string]ModelInfo {
|
||||
|
||||
// modelConfigToModelInfo converts a CustomModelConfig to a ModelInfo.
|
||||
func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
return ModelInfo{
|
||||
info := ModelInfo{
|
||||
ID: modelID,
|
||||
Name: cfg.Name,
|
||||
Attachment: cfg.Attachment,
|
||||
@@ -48,21 +50,242 @@ func modelConfigToModelInfo(modelID string, cfg CustomModelConfig) ModelInfo {
|
||||
Output: cfg.Limit.Output,
|
||||
},
|
||||
}
|
||||
|
||||
// Convert custom model generation params if any are set.
|
||||
if p := convertGenerationParams(cfg.Params); p != nil {
|
||||
info.Params = p
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// LoadModelSettingsFromConfig loads per-model generation parameter overrides
|
||||
// from the config file. Keys are "provider/model" strings. Returns nil if
|
||||
// no model settings are configured.
|
||||
func LoadModelSettingsFromConfig() map[string]*GenerationParams {
|
||||
if !viper.IsSet("modelSettings") {
|
||||
return nil
|
||||
}
|
||||
|
||||
var settings map[string]GenerationParamsConfig
|
||||
if err := viper.UnmarshalKey("modelSettings", &settings); err != nil {
|
||||
log.Printf("Warning: Failed to parse modelSettings: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := make(map[string]*GenerationParams, len(settings))
|
||||
for modelKey, cfg := range settings {
|
||||
if p := convertGenerationParams(cfg); p != nil {
|
||||
result[modelKey] = p
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// convertGenerationParams converts a GenerationParamsConfig to a GenerationParams.
|
||||
// Returns nil if no parameters are set.
|
||||
func convertGenerationParams(cfg GenerationParamsConfig) *GenerationParams {
|
||||
p := &GenerationParams{}
|
||||
any := false
|
||||
|
||||
if cfg.MaxTokens != nil {
|
||||
p.MaxTokens = cfg.MaxTokens
|
||||
any = true
|
||||
}
|
||||
if cfg.Temperature != nil {
|
||||
p.Temperature = cfg.Temperature
|
||||
any = true
|
||||
}
|
||||
if cfg.TopP != nil {
|
||||
p.TopP = cfg.TopP
|
||||
any = true
|
||||
}
|
||||
if cfg.TopK != nil {
|
||||
p.TopK = cfg.TopK
|
||||
any = true
|
||||
}
|
||||
if cfg.FrequencyPenalty != nil {
|
||||
p.FrequencyPenalty = cfg.FrequencyPenalty
|
||||
any = true
|
||||
}
|
||||
if cfg.PresencePenalty != nil {
|
||||
p.PresencePenalty = cfg.PresencePenalty
|
||||
any = true
|
||||
}
|
||||
if len(cfg.StopSequences) > 0 {
|
||||
p.StopSequences = cfg.StopSequences
|
||||
any = true
|
||||
}
|
||||
if cfg.ThinkingLevel != "" {
|
||||
p.ThinkingLevel = ParseThinkingLevel(cfg.ThinkingLevel)
|
||||
any = true
|
||||
}
|
||||
if cfg.SystemPrompt != "" {
|
||||
p.SystemPrompt = cfg.SystemPrompt
|
||||
any = true
|
||||
}
|
||||
|
||||
if !any {
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ApplyModelSettings merges per-model generation parameter defaults from the
|
||||
// registry into a ProviderConfig. Model-level params are only applied for
|
||||
// fields where the user has not explicitly set a value (i.e., the
|
||||
// corresponding viper key is not set via CLI flag or global config).
|
||||
//
|
||||
// The lookup order is:
|
||||
// 1. modelSettings["provider/model"] from config (highest model-level priority)
|
||||
// 2. ModelInfo.Params from custom model definitions
|
||||
//
|
||||
// Both are overridden by explicit CLI flags / global config values.
|
||||
func ApplyModelSettings(config *ProviderConfig, modelInfo *ModelInfo) {
|
||||
provider, modelName, err := ParseModelString(config.ModelString)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect model-level params: modelSettings override > custom model params.
|
||||
// modelSettings takes priority because it's the more specific/intentional config.
|
||||
var params *GenerationParams
|
||||
|
||||
// First check modelSettings from config.
|
||||
if settings := LoadModelSettingsFromConfig(); settings != nil {
|
||||
modelKey := provider + "/" + modelName
|
||||
if p, ok := settings[modelKey]; ok {
|
||||
params = p
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ModelInfo.Params (from custom model definitions).
|
||||
if params == nil && modelInfo != nil && modelInfo.Params != nil {
|
||||
params = modelInfo.Params
|
||||
}
|
||||
|
||||
if params == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Apply each parameter only when the user hasn't explicitly set it.
|
||||
// We check viper.IsSet() which returns true only when the key was
|
||||
// set via CLI flag, environment variable, or config file global section.
|
||||
|
||||
if params.MaxTokens != nil && !isExplicitlySet("max-tokens") {
|
||||
config.MaxTokens = *params.MaxTokens
|
||||
}
|
||||
if params.Temperature != nil && !isExplicitlySet("temperature") {
|
||||
config.Temperature = params.Temperature
|
||||
}
|
||||
if params.TopP != nil && !isExplicitlySet("top-p") {
|
||||
config.TopP = params.TopP
|
||||
}
|
||||
if params.TopK != nil && !isExplicitlySet("top-k") {
|
||||
config.TopK = params.TopK
|
||||
}
|
||||
if params.FrequencyPenalty != nil && !isExplicitlySet("frequency-penalty") {
|
||||
config.FrequencyPenalty = params.FrequencyPenalty
|
||||
}
|
||||
if params.PresencePenalty != nil && !isExplicitlySet("presence-penalty") {
|
||||
config.PresencePenalty = params.PresencePenalty
|
||||
}
|
||||
if len(params.StopSequences) > 0 && !isExplicitlySet("stop-sequences") {
|
||||
config.StopSequences = params.StopSequences
|
||||
}
|
||||
if params.ThinkingLevel != "" && !isExplicitlySet("thinking-level") {
|
||||
config.ThinkingLevel = params.ThinkingLevel
|
||||
}
|
||||
if params.SystemPrompt != "" && config.SystemPrompt == "" {
|
||||
// Resolve file paths: if the value points to an existing file, read it.
|
||||
// We check config.SystemPrompt == "" rather than isExplicitlySet because
|
||||
// viper.BindPFlag causes IsSet to return true even for unset flags.
|
||||
config.SystemPrompt = LoadSystemPromptValue(params.SystemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
// LoadSystemPromptValue resolves a system prompt value that may be either
|
||||
// inline text or a file path. If the value is a path to an existing file,
|
||||
// its contents are read and returned. Otherwise the string is returned as-is.
|
||||
// This mirrors config.LoadSystemPrompt but lives in the models package to
|
||||
// avoid circular dependencies.
|
||||
func LoadSystemPromptValue(input string) string {
|
||||
if input == "" {
|
||||
return ""
|
||||
}
|
||||
if info, err := os.Stat(input); err == nil && !info.IsDir() {
|
||||
content, err := os.ReadFile(input)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to read system prompt file %q: %v", input, err)
|
||||
return input
|
||||
}
|
||||
return strings.TrimSpace(string(content))
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
// isExplicitlySet returns true when the user has explicitly set a config key
|
||||
// via CLI flag, environment variable, or the global section of the config file.
|
||||
// Model-level defaults should not override explicitly set values.
|
||||
func isExplicitlySet(key string) bool {
|
||||
// viper.IsSet returns true if the key has been set in any of the
|
||||
// data stores (flag, env, config file, default). We need to check
|
||||
// whether the value was set at the global config level (not just
|
||||
// as a default). For generation params, the global config keys use
|
||||
// hyphenated names (e.g. "max-tokens", "top-p").
|
||||
//
|
||||
// Since viper merges all sources, IsSet returns true even for config
|
||||
// file values. This means global config file values (e.g.
|
||||
// temperature: 0.7 at the top level) will correctly take precedence
|
||||
// over model-level defaults, which is the desired behavior.
|
||||
return viper.IsSet(key)
|
||||
}
|
||||
|
||||
// GenerationParams holds per-model generation parameter defaults.
|
||||
// These are stored on ModelInfo and applied during provider creation.
|
||||
// Nil pointer fields mean "no model-level default" — the global config
|
||||
// or CLI flag value (if any) will be used instead.
|
||||
type GenerationParams struct {
|
||||
MaxTokens *int
|
||||
Temperature *float32
|
||||
TopP *float32
|
||||
TopK *int32
|
||||
FrequencyPenalty *float32
|
||||
PresencePenalty *float32
|
||||
StopSequences []string
|
||||
ThinkingLevel ThinkingLevel
|
||||
SystemPrompt string // Per-model system prompt (inline text or file path)
|
||||
}
|
||||
|
||||
// CustomModelConfig defines a custom model configuration loaded from the config file.
|
||||
// This is a duplicate here to avoid circular dependencies with internal/config.
|
||||
type CustomModelConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Name string `json:"name" yaml:"name"`
|
||||
BaseURL string `json:"baseUrl,omitempty" yaml:"baseUrl,omitempty"`
|
||||
APIKey string `json:"apiKey,omitempty" yaml:"apiKey,omitempty"`
|
||||
Family string `json:"family,omitempty" yaml:"family,omitempty"`
|
||||
Attachment bool `json:"attachment,omitempty" yaml:"attachment,omitempty"`
|
||||
Reasoning bool `json:"reasoning,omitempty" yaml:"reasoning,omitempty"`
|
||||
Temperature bool `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
Knowledge string `json:"knowledge,omitempty" yaml:"knowledge,omitempty"`
|
||||
Cost CostConfig `json:"cost" yaml:"cost"`
|
||||
Limit LimitConfig `json:"limit" yaml:"limit"`
|
||||
Params GenerationParamsConfig `json:"params,omitzero" yaml:"params,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationParamsConfig is the JSON/YAML-serializable form of generation
|
||||
// parameter defaults. Used in both customModels[].params and modelSettings[].
|
||||
type GenerationParamsConfig struct {
|
||||
MaxTokens *int `json:"maxTokens,omitempty" yaml:"maxTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty" yaml:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty" yaml:"topP,omitempty"`
|
||||
TopK *int32 `json:"topK,omitempty" yaml:"topK,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty" yaml:"frequencyPenalty,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty" yaml:"presencePenalty,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty" yaml:"stopSequences,omitempty"`
|
||||
ThinkingLevel string `json:"thinkingLevel,omitempty" yaml:"thinkingLevel,omitempty"`
|
||||
SystemPrompt string `json:"systemPrompt,omitempty" yaml:"systemPrompt,omitempty"`
|
||||
}
|
||||
|
||||
// CostConfig defines the pricing for a custom model.
|
||||
|
||||
@@ -0,0 +1,422 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestConvertGenerationParams(t *testing.T) {
|
||||
t.Run("empty config returns nil", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p != nil {
|
||||
t.Errorf("expected nil, got %+v", p)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("temperature only", func(t *testing.T) {
|
||||
temp := float32(0.7)
|
||||
cfg := GenerationParamsConfig{Temperature: &temp}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.7 {
|
||||
t.Errorf("expected temperature 0.7, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP != nil {
|
||||
t.Errorf("expected nil TopP, got %v", p.TopP)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("all params set", func(t *testing.T) {
|
||||
maxTokens := 8192
|
||||
temp := float32(0.5)
|
||||
topP := float32(0.9)
|
||||
topK := int32(50)
|
||||
freqPenalty := float32(0.1)
|
||||
presPenalty := float32(0.2)
|
||||
cfg := GenerationParamsConfig{
|
||||
MaxTokens: &maxTokens,
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
TopK: &topK,
|
||||
FrequencyPenalty: &freqPenalty,
|
||||
PresencePenalty: &presPenalty,
|
||||
StopSequences: []string{"STOP"},
|
||||
ThinkingLevel: "high",
|
||||
}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.MaxTokens == nil || *p.MaxTokens != 8192 {
|
||||
t.Errorf("expected maxTokens 8192, got %v", p.MaxTokens)
|
||||
}
|
||||
if p.Temperature == nil || *p.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5, got %v", p.Temperature)
|
||||
}
|
||||
if p.TopP == nil || *p.TopP != 0.9 {
|
||||
t.Errorf("expected topP 0.9, got %v", p.TopP)
|
||||
}
|
||||
if p.TopK == nil || *p.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", p.TopK)
|
||||
}
|
||||
if p.FrequencyPenalty == nil || *p.FrequencyPenalty != 0.1 {
|
||||
t.Errorf("expected frequencyPenalty 0.1, got %v", p.FrequencyPenalty)
|
||||
}
|
||||
if p.PresencePenalty == nil || *p.PresencePenalty != 0.2 {
|
||||
t.Errorf("expected presencePenalty 0.2, got %v", p.PresencePenalty)
|
||||
}
|
||||
if len(p.StopSequences) != 1 || p.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP], got %v", p.StopSequences)
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level parsing", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{ThinkingLevel: "medium"}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.ThinkingLevel != ThinkingMedium {
|
||||
t.Errorf("expected thinking level medium, got %v", p.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
t.Run("system prompt only", func(t *testing.T) {
|
||||
cfg := GenerationParamsConfig{SystemPrompt: "You are helpful."}
|
||||
p := convertGenerationParams(cfg)
|
||||
if p == nil {
|
||||
t.Fatal("expected non-nil")
|
||||
}
|
||||
if p.SystemPrompt != "You are helpful." {
|
||||
t.Errorf("expected system prompt, got %q", p.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithParams(t *testing.T) {
|
||||
temp := float32(0.8)
|
||||
topP := float32(0.95)
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
Temperature: true,
|
||||
Params: GenerationParamsConfig{
|
||||
Temperature: &temp,
|
||||
TopP: &topP,
|
||||
},
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params == nil {
|
||||
t.Fatal("expected non-nil Params")
|
||||
}
|
||||
if info.Params.Temperature == nil || *info.Params.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", info.Params.Temperature)
|
||||
}
|
||||
if info.Params.TopP == nil || *info.Params.TopP != 0.95 {
|
||||
t.Errorf("expected topP 0.95, got %v", info.Params.TopP)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelConfigToModelInfoWithoutParams(t *testing.T) {
|
||||
cfg := CustomModelConfig{
|
||||
Name: "Test Model",
|
||||
BaseURL: "http://localhost:8080/v1",
|
||||
}
|
||||
|
||||
info := modelConfigToModelInfo("test-model", cfg)
|
||||
|
||||
if info.Params != nil {
|
||||
t.Errorf("expected nil Params, got %+v", info.Params)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyModelSettings(t *testing.T) {
|
||||
// Save and restore viper state.
|
||||
originalViper := viper.AllSettings()
|
||||
defer func() {
|
||||
viper.Reset()
|
||||
for k, v := range originalViper {
|
||||
viper.Set(k, v)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Run("applies model params when not explicitly set", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
maxTokens := 4096
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
MaxTokens: &maxTokens,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature == nil || *config.Temperature != 0.8 {
|
||||
t.Errorf("expected temperature 0.8, got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 50 {
|
||||
t.Errorf("expected topK 50, got %v", config.TopK)
|
||||
}
|
||||
if config.MaxTokens != 4096 {
|
||||
t.Errorf("expected maxTokens 4096, got %d", config.MaxTokens)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit viper values take precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
viper.Set("temperature", 0.3)
|
||||
|
||||
temp := float32(0.8)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
},
|
||||
}
|
||||
|
||||
explicitTemp := float32(0.3)
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
Temperature: &explicitTemp,
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Temperature should NOT be overridden because it's explicitly set in viper
|
||||
if config.Temperature == nil || *config.Temperature != 0.3 {
|
||||
t.Errorf("expected temperature 0.3 (explicit), got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil model info is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
ApplyModelSettings(config, nil)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model info without params is safe", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{ID: "test-model"}
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.Temperature != nil {
|
||||
t.Errorf("expected nil temperature, got %v", config.Temperature)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings from viper takes priority over ModelInfo.Params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Set up modelSettings in viper (simulating config file)
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"temperature": 0.5,
|
||||
"topK": 30,
|
||||
},
|
||||
})
|
||||
|
||||
// ModelInfo has different params
|
||||
temp := float32(0.8)
|
||||
topK := int32(50)
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
Temperature: &temp,
|
||||
TopK: &topK,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// modelSettings should win over ModelInfo.Params
|
||||
if config.Temperature == nil || *config.Temperature != 0.5 {
|
||||
t.Errorf("expected temperature 0.5 (from modelSettings), got %v", config.Temperature)
|
||||
}
|
||||
if config.TopK == nil || *config.TopK != 30 {
|
||||
t.Errorf("expected topK 30 (from modelSettings), got %v", config.TopK)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop sequences applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
StopSequences: []string{"STOP", "END"},
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if len(config.StopSequences) != 2 || config.StopSequences[0] != "STOP" {
|
||||
t.Errorf("expected stop sequences [STOP END], got %v", config.StopSequences)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking level applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
ThinkingLevel: ThinkingHigh,
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.ThinkingLevel != ThinkingHigh {
|
||||
t.Errorf("expected thinking level high, got %v", config.ThinkingLevel)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt applied from model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "You are a coding assistant.",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "You are a coding assistant." {
|
||||
t.Errorf("expected system prompt to be set, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit system prompt takes precedence", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "Model-specific prompt",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
SystemPrompt: "Global prompt",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Global system prompt should NOT be overridden because config
|
||||
// already has a non-empty SystemPrompt.
|
||||
if config.SystemPrompt != "Global prompt" {
|
||||
t.Errorf("expected global prompt preserved, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("system prompt from file path", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
// Create a temp file with a system prompt
|
||||
tmpFile, err := os.CreateTemp("", "kit-test-prompt-*.txt")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() { _ = os.Remove(tmpFile.Name()) }()
|
||||
if _, err := tmpFile.WriteString(" Prompt from file "); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: tmpFile.Name(),
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "Prompt from file" {
|
||||
t.Errorf("expected trimmed file content, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("modelSettings system prompt overrides custom model params", func(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
viper.Set("modelSettings", map[string]any{
|
||||
"custom/test-model": map[string]any{
|
||||
"systemPrompt": "From modelSettings",
|
||||
},
|
||||
})
|
||||
|
||||
modelInfo := &ModelInfo{
|
||||
ID: "test-model",
|
||||
Params: &GenerationParams{
|
||||
SystemPrompt: "From custom model",
|
||||
},
|
||||
}
|
||||
|
||||
config := &ProviderConfig{
|
||||
ModelString: "custom/test-model",
|
||||
}
|
||||
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
if config.SystemPrompt != "From modelSettings" {
|
||||
t.Errorf("expected modelSettings prompt, got %q", config.SystemPrompt)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -241,6 +241,11 @@ func CreateProvider(ctx context.Context, config *ProviderConfig) (*ProviderResul
|
||||
validateModelConfig(config, modelInfo)
|
||||
}
|
||||
|
||||
// Apply per-model generation parameter defaults. Model-level params are
|
||||
// only applied for fields where the user hasn't explicitly set a value
|
||||
// via CLI flag or global config.
|
||||
ApplyModelSettings(config, modelInfo)
|
||||
|
||||
// Create the base provider
|
||||
var result *ProviderResult
|
||||
var createErr error
|
||||
|
||||
@@ -26,6 +26,11 @@ type ModelInfo struct {
|
||||
ProviderNPM string // Model-specific provider npm override (e.g. "@ai-sdk/anthropic")
|
||||
BaseURL string // Per-model base URL override (custom models only)
|
||||
APIKey string // Per-model API key override (custom models only)
|
||||
|
||||
// Params holds per-model generation parameter defaults. These are applied
|
||||
// when the user hasn't explicitly set the corresponding CLI flag or global
|
||||
// config value. Nil pointer fields mean "no model-level default".
|
||||
Params *GenerationParams
|
||||
}
|
||||
|
||||
// SupportsCaching returns true if this model family supports prompt caching.
|
||||
@@ -236,6 +241,18 @@ func (r *ModelsRegistry) LookupModel(provider, modelID string) *ModelInfo {
|
||||
return &modelInfo
|
||||
}
|
||||
|
||||
// LookupModelForSettings is a convenience function that parses a
|
||||
// "provider/model" string and looks up the ModelInfo in the global registry.
|
||||
// Returns nil when the model string is invalid or the model is unknown.
|
||||
// Used by Kit.SetModel to pre-apply per-model settings before CreateProvider.
|
||||
func LookupModelForSettings(modelString string) *ModelInfo {
|
||||
provider, modelName, err := ParseModelString(modelString)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return GetGlobalRegistry().LookupModel(provider, modelName)
|
||||
}
|
||||
|
||||
// getRequiredEnvVars returns the required environment variables for a provider.
|
||||
func (r *ModelsRegistry) getRequiredEnvVars(provider string) ([]string, error) {
|
||||
providerInfo, exists := r.providers[provider]
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/fences"
|
||||
)
|
||||
|
||||
// PromptTemplate is a named prompt template with shell-style argument placeholders.
|
||||
// It supports Pi-style $1, $2, $@, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
// It supports Pi-style $1, $2, $@, $+, $ARGUMENTS, ${@:N}, ${@:N:L} syntax.
|
||||
type PromptTemplate struct {
|
||||
// Name is the human-readable identifier for this template.
|
||||
Name string
|
||||
@@ -120,19 +122,28 @@ func ParseCommandArgs(input string) []string {
|
||||
|
||||
// argPlaceholder matches shell-style argument placeholders:
|
||||
// - $1, $2, etc. - positional arguments
|
||||
// - $@ - all arguments
|
||||
// - $@ - all arguments (zero or more)
|
||||
// - $+ - all arguments (one or more required)
|
||||
// - $ARGUMENTS - all arguments (alias for $@)
|
||||
// - ${@:N} - arguments from N onwards
|
||||
// - ${@:N:L} - L arguments starting from N
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$ARGUMENTS`)
|
||||
var argPlaceholder = regexp.MustCompile(`\$\{(\d+)\}|\$\{(\d+):(\d+)\}|\$\{ARGUMENTS\}|\$\{@(:\d+)?(:\d+)?\}|\$(\d+)|\$@|\$\+|\$ARGUMENTS`)
|
||||
|
||||
// SubstituteArgs replaces argument placeholders in content with values from args.
|
||||
// Supported placeholders:
|
||||
// - $N, ${N} - the Nth argument (1-indexed)
|
||||
// - $@, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - $@, $+, $ARGUMENTS, ${ARGUMENTS} - all arguments joined with spaces
|
||||
// - ${@:N} - arguments from index N onwards (0-indexed)
|
||||
// - ${@:N:L} - L arguments starting from index N (0-indexed)
|
||||
func SubstituteArgs(content string, args []string) string {
|
||||
return fences.ReplaceOutside(content, func(segment string) string {
|
||||
return substituteArgsInSegment(segment, args)
|
||||
})
|
||||
}
|
||||
|
||||
// substituteArgsInSegment performs argument substitution on a single text
|
||||
// segment that is known to be outside fenced code blocks.
|
||||
func substituteArgsInSegment(content string, args []string) string {
|
||||
return argPlaceholder.ReplaceAllStringFunc(content, func(match string) string {
|
||||
// Check for ${N} or ${N:M} format
|
||||
if strings.HasPrefix(match, "${") && strings.Contains(match, "}") {
|
||||
@@ -191,8 +202,8 @@ func SubstituteArgs(content string, args []string) string {
|
||||
if strings.HasPrefix(match, "$") && !strings.HasPrefix(match, "${") {
|
||||
suffix := match[1:]
|
||||
|
||||
// $@ or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "ARGUMENTS" {
|
||||
// $@, $+, or $ARGUMENTS
|
||||
if suffix == "@" || suffix == "+" || suffix == "ARGUMENTS" {
|
||||
return strings.Join(args, " ")
|
||||
}
|
||||
|
||||
@@ -266,6 +277,48 @@ func joinArgsRange(args []string, start, length int) string {
|
||||
return strings.Join(args[start:end], " ")
|
||||
}
|
||||
|
||||
// HasArgPlaceholders reports whether the template content contains any
|
||||
// argument placeholders ($1, $@, $ARGUMENTS, ${@:...}, etc.).
|
||||
// Placeholders inside fenced code blocks and inline code spans are ignored.
|
||||
func (t *PromptTemplate) HasArgPlaceholders() bool {
|
||||
return argPlaceholder.MatchString(fences.StripCode(t.Content))
|
||||
}
|
||||
|
||||
// RequiredArgs returns the number of positional arguments the template
|
||||
// expects. This is determined by the highest $N or ${N} placeholder found
|
||||
// in the content (1-indexed, so $2 means 2 args required). The $+
|
||||
// placeholder (required variadic) ensures at least 1. Optional wildcards
|
||||
// ($@, $ARGUMENTS) do not contribute to the count.
|
||||
func (t *PromptTemplate) RequiredArgs() int {
|
||||
content := fences.StripCode(t.Content)
|
||||
maxN := 0
|
||||
hasRequiredVariadic := strings.Contains(content, "$+")
|
||||
for _, match := range argPlaceholder.FindAllStringSubmatch(content, -1) {
|
||||
// Group 1: ${N} format — the N value.
|
||||
if match[1] != "" {
|
||||
if n, err := strconv.Atoi(match[1]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 2: ${N:M} format — the N value (start index).
|
||||
if match[2] != "" {
|
||||
if n, err := strconv.Atoi(match[2]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
// Group 6: $N format (no braces) — the N value.
|
||||
if match[6] != "" {
|
||||
if n, err := strconv.Atoi(match[6]); err == nil && n > maxN {
|
||||
maxN = n
|
||||
}
|
||||
}
|
||||
}
|
||||
if hasRequiredVariadic && maxN < 1 {
|
||||
maxN = 1
|
||||
}
|
||||
return maxN
|
||||
}
|
||||
|
||||
// Expand substitutes arguments into the template content and returns the result.
|
||||
// It first parses args from the input string, then substitutes them into the template.
|
||||
func (t *PromptTemplate) Expand(argsInput string) string {
|
||||
|
||||
@@ -129,6 +129,48 @@ func TestSubstituteArgs(t *testing.T) {
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "$1 inside code block preserved",
|
||||
content: "Use $1 here\n```bash\necho $1\n```\ndone",
|
||||
args: []string{"foo"},
|
||||
expected: "Use foo here\n```bash\necho $1\n```\ndone",
|
||||
},
|
||||
{
|
||||
name: "$@ inside code block preserved",
|
||||
content: "Run $@\n```\necho $@\n```\n",
|
||||
args: []string{"a", "b"},
|
||||
expected: "Run a b\n```\necho $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "all placeholders inside code block",
|
||||
content: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
args: []string{"x"},
|
||||
expected: "Prompt\n```\n$1 $2 $@\n```\n",
|
||||
},
|
||||
{
|
||||
name: "$1 inside inline code preserved",
|
||||
content: "Use `$1` here and $1 outside",
|
||||
args: []string{"foo"},
|
||||
expected: "Use `$1` here and foo outside",
|
||||
},
|
||||
{
|
||||
name: "$+ required variadic",
|
||||
content: "Args: $+",
|
||||
args: []string{"a", "b", "c"},
|
||||
expected: "Args: a b c",
|
||||
},
|
||||
{
|
||||
name: "$+ with empty args",
|
||||
content: "Args: $+",
|
||||
args: []string{},
|
||||
expected: "Args: ",
|
||||
},
|
||||
{
|
||||
name: "all placeholders in inline code",
|
||||
content: "Use `$1` and `$@` for args",
|
||||
args: []string{"x"},
|
||||
expected: "Use `$1` and `$@` for args",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -213,3 +255,78 @@ func TestPromptTemplateExpand(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasArgPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want bool
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt with no args", false},
|
||||
{"$1 placeholder", "Create a $1 component", true},
|
||||
{"$@ placeholder", "Run with args: $@", true},
|
||||
{"$ARGUMENTS placeholder", "Features: $ARGUMENTS", true},
|
||||
{"${1} placeholder", "Name: ${1}", true},
|
||||
{"${ARGUMENTS} placeholder", "All: ${ARGUMENTS}", true},
|
||||
{"${@:1} placeholder", "Rest: ${@:1}", true},
|
||||
{"${@:1:2} placeholder", "Slice: ${@:1:2}", true},
|
||||
{"dollar in text", "Cost is one hundred dollars", false},
|
||||
{"empty content", "", false},
|
||||
{"$1 inside code block only", "Prompt\n```\necho $1\n```\n", false},
|
||||
{"$1 outside and inside code block", "Use $1 here\n```\necho $1\n```\n", true},
|
||||
{"$@ inside code block only", "Prompt\n```bash\necho $@\n```\n", false},
|
||||
{"$+ placeholder", "Run with args: $+", true},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", false},
|
||||
{"$1 inside inline code only", "Use `$1` for positional args", false},
|
||||
{"$1 outside and in inline code", "Create $1 (see `$1` syntax)", true},
|
||||
{"$@ outside $1 in inline code", "Run $@ with `$1` syntax", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.HasArgPlaceholders(); got != tt.want {
|
||||
t.Errorf("HasArgPlaceholders() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequiredArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
want int
|
||||
}{
|
||||
{"no placeholders", "Just a plain prompt", 0},
|
||||
{"$1 only", "Create a $1 component", 1},
|
||||
{"$1 and $2", "Create $1 with $2", 2},
|
||||
{"$3 skipping $2", "Use $1 and $3", 3},
|
||||
{"${1} braced", "Name: ${1}", 1},
|
||||
{"${2} braced", "Name: ${1} Desc: ${2}", 2},
|
||||
{"$@ only", "Run with: $@", 0},
|
||||
{"$ARGUMENTS only", "Features: $ARGUMENTS", 0},
|
||||
{"${ARGUMENTS} only", "All: ${ARGUMENTS}", 0},
|
||||
{"$1 and $@", "Create $1 with extras: $@", 1},
|
||||
{"${@:1} slice only", "Rest: ${@:1}", 0},
|
||||
{"${@:1:2} slice only", "Slice: ${@:1:2}", 0},
|
||||
{"mixed $1 $2 and $@", "Create $1 named $2: $@", 2},
|
||||
{"empty content", "", 0},
|
||||
{"$2 inside code block only", "Prompt\n```\n$1 $2\n```\n", 0},
|
||||
{"$1 outside $2 inside code block", "Use $1\n```\n$2 inside\n```\n", 1},
|
||||
{"$+ only", "Run with: $+", 1},
|
||||
{"$+ and $2", "Create $2 with: $+", 2},
|
||||
{"$+ inside inline code only", "Use `$+` for required args", 0},
|
||||
{"$1 and $2 in inline code only", "Use `$1` and `$2` for args", 0},
|
||||
{"$1 outside $2 in inline code", "Create $1 (see `$2`)", 1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tpl := &PromptTemplate{Content: tt.content}
|
||||
if got := tpl.RequiredArgs(); got != tt.want {
|
||||
t.Errorf("RequiredArgs() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,15 +60,16 @@ type MCPConnection struct {
|
||||
// creation, health monitoring, and cleanup. The pool runs background health checks
|
||||
// to proactively identify and remove unhealthy connections.
|
||||
type MCPConnectionPool struct {
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
connections map[string]*MCPConnection
|
||||
config *ConnectionPoolConfig
|
||||
mu sync.RWMutex
|
||||
model fantasy.LanguageModel
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
oauthFlow *OAuthFlowRunner
|
||||
tokenStoreFactory TokenStoreFactory // custom factory for per-server token stores (nil = default FileTokenStore)
|
||||
}
|
||||
|
||||
// NewMCPConnectionPool creates a new MCP connection pool with the specified configuration.
|
||||
@@ -76,19 +77,20 @@ type MCPConnectionPool struct {
|
||||
// goroutine for periodic health checks that runs until Close is called.
|
||||
// The model parameter is used for MCP servers that require sampling support.
|
||||
// Thread-safe for concurrent use immediately after creation.
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler) *MCPConnectionPool {
|
||||
func NewMCPConnectionPool(config *ConnectionPoolConfig, model fantasy.LanguageModel, debug bool, authHandler MCPAuthHandler, tokenStoreFactory TokenStoreFactory) *MCPConnectionPool {
|
||||
if config == nil {
|
||||
config = DefaultConnectionPoolConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &MCPConnectionPool{
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
connections: make(map[string]*MCPConnection),
|
||||
config: config,
|
||||
model: model,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
debug: debug,
|
||||
tokenStoreFactory: tokenStoreFactory,
|
||||
}
|
||||
|
||||
if authHandler != nil {
|
||||
@@ -363,19 +365,29 @@ func (p *MCPConnectionPool) createSSEClient(ctx context.Context, serverConfig co
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithOAuth(transport.OAuthConfig{
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
|
||||
@@ -410,19 +422,29 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
}
|
||||
|
||||
// Enable OAuth for remote transports when an auth handler is configured.
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. Client ID and
|
||||
// scopes are discovered automatically via dynamic client registration and
|
||||
// server metadata (RFC 9728).
|
||||
// The OAuthConfig uses PKCE and the handler's redirect URI. If the server
|
||||
// config provides a pre-registered ClientID (for servers that don't support
|
||||
// dynamic client registration, e.g. GitHub), it is passed through directly.
|
||||
if p.oauthFlow != nil {
|
||||
tokenStore, tsErr := NewFileTokenStore(serverConfig.URL)
|
||||
tokenStore, tsErr := p.createTokenStore(serverConfig.URL)
|
||||
if tsErr != nil {
|
||||
return nil, fmt.Errorf("failed to create token store: %w", tsErr)
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(transport.OAuthConfig{
|
||||
oauthCfg := transport.OAuthConfig{
|
||||
RedirectURI: p.oauthFlow.handler.RedirectURI(),
|
||||
PKCEEnabled: true,
|
||||
TokenStore: tokenStore,
|
||||
}))
|
||||
}
|
||||
if serverConfig.OAuthClientID != "" {
|
||||
oauthCfg.ClientID = serverConfig.OAuthClientID
|
||||
}
|
||||
if serverConfig.OAuthClientSecret != "" {
|
||||
oauthCfg.ClientSecret = serverConfig.OAuthClientSecret
|
||||
}
|
||||
if len(serverConfig.OAuthScopes) > 0 {
|
||||
oauthCfg.Scopes = serverConfig.OAuthScopes
|
||||
}
|
||||
options = append(options, transport.WithHTTPOAuth(oauthCfg))
|
||||
}
|
||||
|
||||
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
|
||||
@@ -437,6 +459,16 @@ func (p *MCPConnectionPool) createStreamableClient(ctx context.Context, serverCo
|
||||
return streamableClient, nil
|
||||
}
|
||||
|
||||
// createTokenStore creates a token store for the given server URL.
|
||||
// If a custom TokenStoreFactory is configured, it is used; otherwise the
|
||||
// default file-backed token store is created.
|
||||
func (p *MCPConnectionPool) createTokenStore(serverURL string) (transport.TokenStore, error) {
|
||||
if p.tokenStoreFactory != nil {
|
||||
return p.tokenStoreFactory(serverURL)
|
||||
}
|
||||
return NewFileTokenStore(serverURL)
|
||||
}
|
||||
|
||||
// initializeClient initializes the client
|
||||
func (p *MCPConnectionPool) initializeClient(ctx context.Context, client client.MCPClient) error {
|
||||
initCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
@@ -583,6 +615,27 @@ func (p *MCPConnectionPool) GetClients() map[string]client.MCPClient {
|
||||
return clients
|
||||
}
|
||||
|
||||
// RemoveConnection closes and removes a single connection from the pool.
|
||||
// Returns an error if the connection does not exist or if closing fails.
|
||||
// Thread-safe for concurrent use.
|
||||
func (p *MCPConnectionPool) RemoveConnection(serverName string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conn, exists := p.connections[serverName]
|
||||
if !exists {
|
||||
return fmt.Errorf("connection %q not found in pool", serverName)
|
||||
}
|
||||
|
||||
err := conn.client.Close()
|
||||
delete(p.connections, serverName)
|
||||
|
||||
if p.debugLogger != nil && p.debugLogger.IsDebugEnabled() {
|
||||
p.debugLogger.LogDebug(fmt.Sprintf("[POOL] Removed connection %s", serverName))
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close gracefully shuts down the connection pool, closing all client connections
|
||||
// and stopping the background health check goroutine. It attempts to close all
|
||||
// connections even if some fail, logging any errors encountered.
|
||||
|
||||
+153
-10
@@ -20,19 +20,25 @@ import (
|
||||
// pooling, health checks, tool name prefixing to avoid conflicts, and sampling support for LLM interactions.
|
||||
// Thread-safe for concurrent tool invocations.
|
||||
type MCPToolManager struct {
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
mu sync.Mutex // protects tools and toolMap during parallel loading
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
connectionPool *MCPConnectionPool
|
||||
tools []fantasy.AgentTool
|
||||
toolMap map[string]*toolMapping // maps prefixed tool names to their server and original name
|
||||
mu sync.Mutex // protects tools and toolMap during parallel loading
|
||||
model fantasy.LanguageModel // LLM model for sampling
|
||||
authHandler MCPAuthHandler // OAuth handler for remote servers (nil = no OAuth)
|
||||
tokenStoreFactory TokenStoreFactory // factory for creating per-server token stores (nil = default FileTokenStore)
|
||||
config *config.Config
|
||||
debug bool
|
||||
debugLogger DebugLogger
|
||||
|
||||
// onServerLoaded, if non-nil, is called when each server finishes loading.
|
||||
// Called with server name, tool count, and error (nil on success).
|
||||
onServerLoaded func(serverName string, toolCount int, err error)
|
||||
|
||||
// onToolsChanged, if non-nil, is called after AddServer or RemoveServer
|
||||
// mutates the tool list. The agent layer uses this to trigger a
|
||||
// rebuildFantasyAgent so the LLM sees the updated tools.
|
||||
onToolsChanged func()
|
||||
}
|
||||
|
||||
// toolMapping stores the mapping between prefixed tool names and their original details
|
||||
@@ -69,6 +75,20 @@ func (m *MCPToolManager) SetAuthHandler(handler MCPAuthHandler) {
|
||||
m.authHandler = handler
|
||||
}
|
||||
|
||||
// GetAuthHandler returns the OAuth handler for remote MCP server authentication.
|
||||
// Returns nil if no handler is configured.
|
||||
func (m *MCPToolManager) GetAuthHandler() MCPAuthHandler {
|
||||
return m.authHandler
|
||||
}
|
||||
|
||||
// SetTokenStoreFactory sets a custom factory for creating per-server OAuth token
|
||||
// stores. When set, the factory is called for each remote MCP server instead of
|
||||
// using the default file-based token store. This method should be called before
|
||||
// LoadTools.
|
||||
func (m *MCPToolManager) SetTokenStoreFactory(factory TokenStoreFactory) {
|
||||
m.tokenStoreFactory = factory
|
||||
}
|
||||
|
||||
// SetDebugLogger sets the debug logger for the tool manager.
|
||||
// The logger will be used to output detailed debugging information about MCP connections,
|
||||
// tool loading, and execution. If a connection pool exists, it will also be configured
|
||||
@@ -87,6 +107,126 @@ func (m *MCPToolManager) SetOnServerLoaded(cb func(serverName string, toolCount
|
||||
m.onServerLoaded = cb
|
||||
}
|
||||
|
||||
// SetOnToolsChanged sets the callback that's invoked after AddServer or
|
||||
// RemoveServer mutates the tool list. The agent layer uses this to trigger
|
||||
// a rebuild of the fantasy agent so the LLM sees the updated tool set.
|
||||
func (m *MCPToolManager) SetOnToolsChanged(cb func()) {
|
||||
m.onToolsChanged = cb
|
||||
}
|
||||
|
||||
// AddServer connects to a new MCP server at runtime and loads its tools.
|
||||
// The server's tools are immediately available to the agent after this call.
|
||||
// Returns the number of tools loaded from the server.
|
||||
//
|
||||
// If the connection pool has not been initialised yet (i.e. LoadTools was never
|
||||
// called), AddServer creates one automatically using the manager's current
|
||||
// configuration.
|
||||
//
|
||||
// Returns an error if a server with the same name is already loaded, or if
|
||||
// the connection or tool loading fails.
|
||||
func (m *MCPToolManager) AddServer(ctx context.Context, name string, cfg config.MCPServerConfig) (int, error) {
|
||||
m.mu.Lock()
|
||||
// Check for duplicate.
|
||||
if _, exists := m.toolMap[name+"__"]; exists {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
// More thorough duplicate check: scan toolMap for any key with the server prefix.
|
||||
prefix := name + "__"
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
m.mu.Unlock()
|
||||
return 0, fmt.Errorf("MCP server %q is already loaded", name)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Lazily create the connection pool if LoadTools was never called.
|
||||
m.ensureConnectionPool()
|
||||
|
||||
count, err := m.loadServerTools(ctx, name, cfg)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to add MCP server %q: %w", name, err)
|
||||
}
|
||||
|
||||
// Notify listeners.
|
||||
if m.onServerLoaded != nil {
|
||||
m.onServerLoaded(name, count, nil)
|
||||
}
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// RemoveServer disconnects an MCP server and removes all its tools.
|
||||
// After this call the agent will no longer see or be able to call tools from
|
||||
// the named server. Returns an error if the server is not loaded.
|
||||
func (m *MCPToolManager) RemoveServer(name string) error {
|
||||
prefix := name + "__"
|
||||
|
||||
m.mu.Lock()
|
||||
|
||||
// Check the server actually has tools loaded.
|
||||
found := false
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
m.mu.Unlock()
|
||||
return fmt.Errorf("MCP server %q is not loaded", name)
|
||||
}
|
||||
|
||||
// Remove tools belonging to this server.
|
||||
newTools := make([]fantasy.AgentTool, 0, len(m.tools))
|
||||
for _, t := range m.tools {
|
||||
if len(t.Info().Name) < len(prefix) || t.Info().Name[:len(prefix)] != prefix {
|
||||
newTools = append(newTools, t)
|
||||
}
|
||||
}
|
||||
m.tools = newTools
|
||||
|
||||
// Remove tool mappings.
|
||||
for k := range m.toolMap {
|
||||
if len(k) >= len(prefix) && k[:len(prefix)] == prefix {
|
||||
delete(m.toolMap, k)
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
// Close the connection in the pool (best-effort).
|
||||
if m.connectionPool != nil {
|
||||
_ = m.connectionPool.RemoveConnection(name)
|
||||
}
|
||||
|
||||
if m.onToolsChanged != nil {
|
||||
m.onToolsChanged()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureConnectionPool lazily creates a connection pool if one does not exist.
|
||||
// This allows AddServer to work even if LoadTools was never called.
|
||||
func (m *MCPToolManager) ensureConnectionPool() {
|
||||
if m.connectionPool != nil {
|
||||
return
|
||||
}
|
||||
debug := false
|
||||
if m.config != nil {
|
||||
debug = m.config.Debug
|
||||
}
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
}
|
||||
|
||||
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
|
||||
// It initializes the connection pool, connects to each configured server, and loads their tools.
|
||||
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
|
||||
@@ -99,7 +239,7 @@ func (m *MCPToolManager) LoadTools(ctx context.Context, cfg *config.Config) erro
|
||||
if m.debugLogger == nil {
|
||||
m.debugLogger = NewSimpleDebugLogger(cfg.Debug)
|
||||
}
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler)
|
||||
m.connectionPool = NewMCPConnectionPool(DefaultConnectionPoolConfig(), m.model, cfg.Debug, m.authHandler, m.tokenStoreFactory)
|
||||
m.connectionPool.SetDebugLogger(m.debugLogger)
|
||||
|
||||
// Load all servers in parallel. Each server connection (subprocess
|
||||
@@ -290,6 +430,9 @@ func (m *MCPToolManager) GetLoadedServerNames() []string {
|
||||
// proper cleanup of stdio processes, network connections, and other resources.
|
||||
// It is safe to call Close multiple times.
|
||||
func (m *MCPToolManager) Close() error {
|
||||
if m.connectionPool == nil {
|
||||
return nil
|
||||
}
|
||||
return m.connectionPool.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// testdataDir returns the absolute path to the testdata directory.
|
||||
func testdataDir(t *testing.T) string {
|
||||
t.Helper()
|
||||
_, file, _, ok := runtime.Caller(0)
|
||||
if !ok {
|
||||
t.Fatal("cannot determine test file path")
|
||||
}
|
||||
return filepath.Join(filepath.Dir(file), "testdata")
|
||||
}
|
||||
|
||||
// echoServerConfig returns an MCPServerConfig for the test echo MCP server.
|
||||
func echoServerConfig(t *testing.T) config.MCPServerConfig {
|
||||
t.Helper()
|
||||
script := filepath.Join(testdataDir(t), "echo_server.py")
|
||||
if _, err := os.Stat(script); err != nil {
|
||||
t.Skipf("echo_server.py not found: %v", err)
|
||||
}
|
||||
return config.MCPServerConfig{
|
||||
Command: []string{"python3", script},
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_Integration tests adding a real MCP server
|
||||
// at runtime and verifying tools are loaded.
|
||||
func TestMCPToolManager_AddServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Track callbacks.
|
||||
var mu sync.Mutex
|
||||
var loadedServer string
|
||||
var loadedCount int
|
||||
toolsChangedCount := 0
|
||||
|
||||
manager.SetOnServerLoaded(func(name string, count int, err error) {
|
||||
mu.Lock()
|
||||
loadedServer = name
|
||||
loadedCount = count
|
||||
mu.Unlock()
|
||||
})
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Add the server.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools from echo server, got %d", count)
|
||||
}
|
||||
|
||||
// Verify callbacks fired.
|
||||
mu.Lock()
|
||||
if loadedServer != "echo" {
|
||||
t.Errorf("Expected onServerLoaded for 'echo', got %q", loadedServer)
|
||||
}
|
||||
if loadedCount != 2 {
|
||||
t.Errorf("Expected onServerLoaded count=2, got %d", loadedCount)
|
||||
}
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify tools are accessible.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify tool names are prefixed.
|
||||
toolNames := make(map[string]bool)
|
||||
for _, tool := range tools {
|
||||
toolNames[tool.Info().Name] = true
|
||||
}
|
||||
if !toolNames["echo__echo"] {
|
||||
t.Error("Expected tool 'echo__echo'")
|
||||
}
|
||||
if !toolNames["echo__greet"] {
|
||||
t.Error("Expected tool 'echo__greet'")
|
||||
}
|
||||
|
||||
// Verify server appears in loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
if !slices.Contains(names, "echo") {
|
||||
t.Errorf("Expected 'echo' in loaded server names, got: %v", names)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_Integration tests removing a real MCP server
|
||||
// and verifying tools are cleaned up.
|
||||
func TestMCPToolManager_RemoveServer_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server first.
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 tools, got %d", count)
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
toolsChangedCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
toolsChangedCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// Remove the server.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify tools are gone.
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removal, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Verify callback fired.
|
||||
mu.Lock()
|
||||
if toolsChangedCount != 1 {
|
||||
t.Errorf("Expected onToolsChanged called once, got %d", toolsChangedCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
|
||||
// Verify server is gone from loaded names.
|
||||
names := manager.GetLoadedServerNames()
|
||||
for _, n := range names {
|
||||
if n == "echo" {
|
||||
t.Error("Server 'echo' should not appear in loaded names after removal")
|
||||
}
|
||||
}
|
||||
|
||||
// Removing again should error.
|
||||
err = manager.RemoveServer("echo")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error removing already-removed server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddRemoveMultiple_Integration tests adding and removing
|
||||
// multiple servers, verifying tool isolation.
|
||||
func TestMCPToolManager_AddRemoveMultiple_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add two servers with the same binary but different names.
|
||||
count1, err := manager.AddServer(ctx, "server-a", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-a failed: %v", err)
|
||||
}
|
||||
count2, err := manager.AddServer(ctx, "server-b", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("AddServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
totalTools := count1 + count2
|
||||
if totalTools != 4 {
|
||||
t.Fatalf("Expected 4 total tools (2+2), got %d", totalTools)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 4 {
|
||||
t.Fatalf("Expected 4 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remove server-a, verify server-b tools remain.
|
||||
err = manager.RemoveServer("server-a")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-a failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Fatalf("Expected 2 tools after removing server-a, got %d", len(tools))
|
||||
}
|
||||
|
||||
// Remaining tools should all be from server-b.
|
||||
for _, tool := range tools {
|
||||
if !strings.HasPrefix(tool.Info().Name, "server-b__") {
|
||||
t.Errorf("Expected tool from server-b, got: %s", tool.Info().Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove server-b.
|
||||
err = manager.RemoveServer("server-b")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer server-b failed: %v", err)
|
||||
}
|
||||
|
||||
tools = manager.GetTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("Expected 0 tools after removing all servers, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateDetection_Integration tests that
|
||||
// adding a server with the same name as an already loaded server errors.
|
||||
func TestMCPToolManager_AddServer_DuplicateDetection_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add the server.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
// Try to add again with the same name.
|
||||
_, err = manager.AddServer(ctx, "echo", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error adding duplicate server")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "already loaded") {
|
||||
t.Errorf("Expected 'already loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddAfterRemove_Integration tests that a server can be
|
||||
// re-added after being removed.
|
||||
func TestMCPToolManager_AddAfterRemove_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
manager := NewMCPToolManager()
|
||||
defer func() { _ = manager.Close() }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cfg := echoServerConfig(t)
|
||||
|
||||
// Add, remove, re-add.
|
||||
_, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("First AddServer failed: %v", err)
|
||||
}
|
||||
|
||||
err = manager.RemoveServer("echo")
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveServer failed: %v", err)
|
||||
}
|
||||
|
||||
count, err := manager.AddServer(ctx, "echo", cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Re-AddServer failed: %v", err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tools on re-add, got %d", count)
|
||||
}
|
||||
|
||||
tools := manager.GetTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("Expected 2 tools after re-add, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mark3labs/kit/internal/config"
|
||||
)
|
||||
|
||||
// TestMCPToolManager_AddServer_DuplicateName verifies that adding a server
|
||||
// with a name that already exists returns an error.
|
||||
func TestMCPToolManager_AddServer_DuplicateName(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
cfg := config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// First add will fail (bad command), but let's test the duplicate detection
|
||||
// by simulating a loaded server via LoadTools first.
|
||||
loadCfg := &config.Config{
|
||||
MCPServers: map[string]config.MCPServerConfig{
|
||||
"test-server": cfg,
|
||||
},
|
||||
}
|
||||
// This will fail to load but creates the connection pool.
|
||||
_ = manager.LoadTools(ctx, loadCfg)
|
||||
|
||||
// Now try to add the same server name — the tools didn't load (bad command),
|
||||
// so AddServer should not find a duplicate and should fail with connection error.
|
||||
_, err := manager.AddServer(ctx, "test-server", cfg)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when adding server with bad command, got nil")
|
||||
}
|
||||
// It should be a connection error, not a duplicate error.
|
||||
if strings.Contains(err.Error(), "already loaded") {
|
||||
t.Fatalf("Should not report duplicate since server failed to load initially: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_RemoveServer_NotLoaded verifies that removing a server
|
||||
// that doesn't exist returns an appropriate error.
|
||||
func TestMCPToolManager_RemoveServer_NotLoaded(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
err := manager.RemoveServer("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error when removing non-existent server, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not loaded") {
|
||||
t.Errorf("Expected 'not loaded' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_AddServer_CreatesConnectionPool verifies that AddServer
|
||||
// lazily creates a connection pool when LoadTools was never called.
|
||||
func TestMCPToolManager_AddServer_CreatesConnectionPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// Connection pool should be nil initially.
|
||||
if manager.connectionPool != nil {
|
||||
t.Fatal("Expected nil connection pool before any operation")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// AddServer with a bad command — should fail, but the pool should be created.
|
||||
_, err := manager.AddServer(ctx, "lazy-server", config.MCPServerConfig{
|
||||
Command: []string{"non-existent-command"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for bad command")
|
||||
}
|
||||
|
||||
// Connection pool should have been created.
|
||||
if manager.connectionPool == nil {
|
||||
t.Fatal("Expected connection pool to be created lazily by AddServer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_OnToolsChanged_Callback verifies that the onToolsChanged
|
||||
// callback fires on RemoveServer (we can't easily test AddServer with a real
|
||||
// MCP server, but we can test the callback wiring).
|
||||
func TestMCPToolManager_OnToolsChanged_Callback(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
var mu sync.Mutex
|
||||
callCount := 0
|
||||
manager.SetOnToolsChanged(func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
})
|
||||
|
||||
// RemoveServer on non-existent should NOT fire callback.
|
||||
_ = manager.RemoveServer("nonexistent")
|
||||
|
||||
mu.Lock()
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected 0 callback calls for failed remove, got %d", callCount)
|
||||
}
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// TestMCPToolManager_Close_NilPool verifies Close is safe when the connection
|
||||
// pool was never initialized.
|
||||
func TestMCPToolManager_Close_NilPool(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
err := manager.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Expected nil error from Close with nil pool, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPConnectionPool_RemoveConnection_NotFound verifies that removing a
|
||||
// non-existent connection returns an error.
|
||||
func TestMCPConnectionPool_RemoveConnection_NotFound(t *testing.T) {
|
||||
pool := NewMCPConnectionPool(DefaultConnectionPoolConfig(), nil, false, nil, nil)
|
||||
defer func() { _ = pool.Close() }()
|
||||
|
||||
err := pool.RemoveConnection("nonexistent")
|
||||
if err == nil {
|
||||
t.Fatal("Expected error for non-existent connection")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not found") {
|
||||
t.Errorf("Expected 'not found' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPToolManager_EnsureConnectionPool_Idempotent verifies that
|
||||
// ensureConnectionPool doesn't recreate an existing pool.
|
||||
func TestMCPToolManager_EnsureConnectionPool_Idempotent(t *testing.T) {
|
||||
manager := NewMCPToolManager()
|
||||
|
||||
// First call creates the pool.
|
||||
manager.ensureConnectionPool()
|
||||
pool1 := manager.connectionPool
|
||||
if pool1 == nil {
|
||||
t.Fatal("Expected pool to be created")
|
||||
}
|
||||
|
||||
// Second call should be a no-op.
|
||||
manager.ensureConnectionPool()
|
||||
pool2 := manager.connectionPool
|
||||
if pool1 != pool2 {
|
||||
t.Fatal("Expected ensureConnectionPool to be idempotent")
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/mark3labs/mcp-go/client"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// MCPAuthHandler is the internal interface for handling MCP OAuth flows.
|
||||
@@ -21,6 +22,12 @@ type MCPAuthHandler interface {
|
||||
HandleAuth(ctx context.Context, serverName string, authURL string) (callbackURL string, err error)
|
||||
}
|
||||
|
||||
// TokenStoreFactory creates a transport.TokenStore for a given MCP server URL.
|
||||
// When provided to the connection pool, it is called once per remote MCP server
|
||||
// instead of using the default file-based token store. Implementations can
|
||||
// return any transport.TokenStore — in-memory, database-backed, encrypted, etc.
|
||||
type TokenStoreFactory func(serverURL string) (transport.TokenStore, error)
|
||||
|
||||
// OAuthFlowRunner handles the OAuth authorization flow when an MCP server
|
||||
// returns an OAuthAuthorizationRequiredError. It coordinates dynamic client
|
||||
// registration, PKCE generation, user authorization (via MCPAuthHandler),
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Minimal MCP server over stdio for testing. Exposes one tool: echo."""
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
def read_message():
|
||||
"""Read a JSON-RPC message from stdin."""
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
return None
|
||||
return json.loads(line.strip())
|
||||
|
||||
|
||||
def write_message(msg):
|
||||
"""Write a JSON-RPC message to stdout."""
|
||||
sys.stdout.write(json.dumps(msg) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def handle(msg):
|
||||
method = msg.get("method", "")
|
||||
mid = msg.get("id")
|
||||
|
||||
if method == "initialize":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-echo", "version": "1.0.0"},
|
||||
},
|
||||
})
|
||||
elif method == "notifications/initialized":
|
||||
pass # no response needed
|
||||
elif method == "tools/list":
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echoes the input text back.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {"type": "string", "description": "Text to echo"}
|
||||
},
|
||||
"required": ["text"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "greet",
|
||||
"description": "Returns a greeting.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Name to greet"}
|
||||
},
|
||||
"required": ["name"],
|
||||
},
|
||||
},
|
||||
]
|
||||
},
|
||||
})
|
||||
elif method == "tools/call":
|
||||
tool_name = msg["params"]["name"]
|
||||
args = msg["params"].get("arguments", {})
|
||||
if tool_name == "echo":
|
||||
text = args.get("text", "")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": text}]
|
||||
},
|
||||
})
|
||||
elif tool_name == "greet":
|
||||
name = args.get("name", "World")
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": f"Hello, {name}!"}]
|
||||
},
|
||||
})
|
||||
else:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown tool: {tool_name}"},
|
||||
})
|
||||
elif method == "ping":
|
||||
write_message({"jsonrpc": "2.0", "id": mid, "result": {}})
|
||||
else:
|
||||
if mid is not None:
|
||||
write_message({
|
||||
"jsonrpc": "2.0",
|
||||
"id": mid,
|
||||
"error": {"code": -32601, "message": f"Unknown method: {method}"},
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
while True:
|
||||
msg = read_message()
|
||||
if msg is None:
|
||||
break
|
||||
handle(msg)
|
||||
@@ -20,6 +20,7 @@ type SlashCommand struct {
|
||||
Aliases []string
|
||||
Category string // e.g., "Navigation", "System", "Info"
|
||||
Complete func(prefix string) []string // optional argument tab-completion
|
||||
HasArgs bool // true when the command expects arguments (e.g. prompt templates with placeholders)
|
||||
}
|
||||
|
||||
// SlashCommands provides the global registry of all available slash commands
|
||||
|
||||
@@ -139,7 +139,9 @@ func (h *CLIEventHandler) Handle(msg tea.Msg) {
|
||||
case "block":
|
||||
h.cli.DisplayExtensionBlock(e.Text, e.BorderColor, e.Subtitle)
|
||||
default:
|
||||
fmt.Println(e.Text)
|
||||
// Route unstyled extension prints through the system message
|
||||
// renderer so they get consistent formatting and timestamps.
|
||||
h.cli.DisplayInfo(e.Text)
|
||||
}
|
||||
|
||||
case app.StepCompleteEvent:
|
||||
|
||||
@@ -109,9 +109,7 @@ func SetupCLI(opts *CLISetupOptions) (*CLI, error) {
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("")
|
||||
|
||||
// Display model info
|
||||
// Display model info (the system message block provides its own spacing).
|
||||
if provider != "unknown" && model != "unknown" {
|
||||
cli.DisplayInfo(fmt.Sprintf("Model loaded: %s (%s)", provider, model))
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/kit/internal/fences"
|
||||
)
|
||||
|
||||
// fileTokenPattern matches @file references in user text. Supports:
|
||||
@@ -20,6 +22,14 @@ var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
//
|
||||
// Returns the original text unchanged if no valid @file references are found.
|
||||
func ProcessFileAttachments(text string, cwd string) string {
|
||||
return fences.ReplaceOutside(text, func(segment string) string {
|
||||
return processFileTokens(segment, cwd)
|
||||
})
|
||||
}
|
||||
|
||||
// processFileTokens handles @file replacement in a single text segment
|
||||
// that is known to be outside fenced code blocks.
|
||||
func processFileTokens(text string, cwd string) string {
|
||||
tokens := fileTokenPattern.FindAllString(text, -1)
|
||||
if len(tokens) == 0 {
|
||||
return text
|
||||
|
||||
+17
-6
@@ -285,16 +285,25 @@ func (s *InputComponent) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
s.textarea.CursorEnd()
|
||||
return s, nil
|
||||
}
|
||||
selectedCmd := s.filtered[s.selected].Command
|
||||
// Populate textarea with selected item and submit on next tick.
|
||||
if s.argMode {
|
||||
s.textarea.SetValue(s.argCommand + " " + s.filtered[s.selected].Command.Name)
|
||||
s.textarea.SetValue(s.argCommand + " " + selectedCmd.Name)
|
||||
} else {
|
||||
s.textarea.SetValue(s.filtered[s.selected].Command.Name)
|
||||
s.textarea.SetValue(selectedCmd.Name)
|
||||
}
|
||||
s.textarea.CursorEnd()
|
||||
s.showPopup = false
|
||||
s.selected = 0
|
||||
s.submitNext = true
|
||||
// If the selected command expects arguments, populate
|
||||
// the input with the command + trailing space so the
|
||||
// user can type args, instead of auto-submitting.
|
||||
if !s.argMode && selectedCmd.HasArgs {
|
||||
s.textarea.SetValue(selectedCmd.Name + " ")
|
||||
s.textarea.CursorEnd()
|
||||
} else {
|
||||
s.submitNext = true
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
return s, nil
|
||||
@@ -521,12 +530,14 @@ func (s *InputComponent) View() tea.View {
|
||||
} else {
|
||||
hint = "^X s steer"
|
||||
}
|
||||
} else if availableHintWidth >= 80 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+x e editor • ctrl+v paste image"
|
||||
} else if availableHintWidth >= 67 {
|
||||
hint = "enter submit • ctrl+j / shift+enter new line • ctrl+v paste image"
|
||||
hint = "enter submit • ctrl+j new line • ctrl+x e editor • ctrl+v image"
|
||||
} else if availableHintWidth >= 40 {
|
||||
hint = "↵ submit • ctrl+j newline • ctrl+v image"
|
||||
hint = "↵ submit • ctrl+j newline • ^X e editor"
|
||||
} else if availableHintWidth >= 20 {
|
||||
hint = "↵ submit • ctrl+j"
|
||||
hint = "↵ submit • ^X e editor"
|
||||
} else {
|
||||
hint = "↵ submit"
|
||||
}
|
||||
|
||||
+102
-7
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
tea "charm.land/bubbletea/v2"
|
||||
"charm.land/lipgloss/v2"
|
||||
"github.com/charmbracelet/x/editor"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/mark3labs/kit/internal/app"
|
||||
@@ -826,6 +827,7 @@ func NewAppModel(appCtrl AppController, opts AppModelOptions) *AppModel {
|
||||
Name: "/" + tpl.Name,
|
||||
Description: tpl.Description,
|
||||
Category: "Prompts",
|
||||
HasArgs: tpl.HasArgPlaceholders(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1333,6 +1335,45 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
}
|
||||
}
|
||||
case "e":
|
||||
// Ctrl+X e → open $EDITOR to compose/edit the prompt.
|
||||
editorApp := os.Getenv("VISUAL")
|
||||
if editorApp == "" {
|
||||
editorApp = os.Getenv("EDITOR")
|
||||
}
|
||||
if editorApp == "" {
|
||||
m.printSystemMessage("Set `$EDITOR` or `$VISUAL` to use external editor")
|
||||
} else {
|
||||
var currentText string
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
currentText = ic.textarea.Value()
|
||||
}
|
||||
tmpFile, err := os.CreateTemp("", "kit_prompt_*.md")
|
||||
if err == nil {
|
||||
if currentText != "" {
|
||||
_, _ = tmpFile.WriteString(currentText)
|
||||
}
|
||||
_ = tmpFile.Close()
|
||||
editorCmd, cmdErr := editor.Command(editorApp, tmpFile.Name())
|
||||
if cmdErr != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
m.printSystemMessage(fmt.Sprintf("Failed to open editor: %v", cmdErr))
|
||||
} else {
|
||||
cmds = append(cmds, tea.ExecProcess(editorCmd, func(err error) tea.Msg {
|
||||
if err != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return externalEditorMsg{err: err}
|
||||
}
|
||||
content, readErr := os.ReadFile(tmpFile.Name())
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
if readErr != nil {
|
||||
return externalEditorMsg{err: readErr}
|
||||
}
|
||||
return externalEditorMsg{text: string(content)}
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Chord consumed — don't propagate to children.
|
||||
return m, tea.Batch(cmds...)
|
||||
@@ -1444,7 +1485,15 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
// Expand prompt templates. If the input matches a template name,
|
||||
// substitute arguments and use the expanded content as the prompt.
|
||||
if expanded, ok := m.expandPromptTemplate(msg.Text); ok {
|
||||
if expanded, ok, validationErr := m.expandPromptTemplate(msg.Text); validationErr != "" {
|
||||
// Validation failed — re-populate the input so the user can
|
||||
// append the missing arguments without retyping.
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.textarea.SetValue(msg.Text + " ")
|
||||
ic.textarea.CursorEnd()
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
} else if ok {
|
||||
msg.Text = expanded
|
||||
}
|
||||
|
||||
@@ -1820,6 +1869,12 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
// Refresh content to show the finalized message.
|
||||
m.refreshContent()
|
||||
|
||||
// Reset context token display — the pre-compaction count is stale.
|
||||
// The next API call will set the accurate post-compaction value.
|
||||
if m.usageTracker != nil {
|
||||
m.usageTracker.SetContextTokens(0)
|
||||
}
|
||||
|
||||
// Print stats as a separate system message.
|
||||
saved := msg.OriginalTokens - msg.CompactedTokens
|
||||
statsMsg := fmt.Sprintf(
|
||||
@@ -1967,6 +2022,19 @@ func (m *AppModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
m.printSystemMessage(msg.output)
|
||||
}
|
||||
|
||||
case externalEditorMsg:
|
||||
// User returned from $EDITOR. Replace input textarea content with
|
||||
// whatever they saved in the temp file. On error (e.g. :cq in vim)
|
||||
// the original input is silently preserved.
|
||||
if msg.err == nil {
|
||||
if ic, ok := m.input.(*InputComponent); ok {
|
||||
ic.textarea.SetValue(msg.text)
|
||||
// Move cursor to the end of the inserted text.
|
||||
ic.textarea.CursorEnd()
|
||||
}
|
||||
m.layoutDirty = true
|
||||
}
|
||||
|
||||
case extReloadResultMsg:
|
||||
if msg.err != nil {
|
||||
m.printSystemMessage(fmt.Sprintf("Extension reload failed: %v", msg.err))
|
||||
@@ -2826,15 +2894,20 @@ func (m *AppModel) handleExtensionCommand(text string) tea.Cmd {
|
||||
|
||||
// expandPromptTemplate checks if the submitted text matches a prompt template
|
||||
// and returns the expanded content with arguments substituted.
|
||||
// Returns (expanded, true) if a template was found and expanded, (text, false) otherwise.
|
||||
func (m *AppModel) expandPromptTemplate(text string) (string, bool) {
|
||||
//
|
||||
// Return values:
|
||||
// - (expanded, true, "") — template matched and expanded successfully
|
||||
// - (text, false, "") — no template matched; caller should treat text as-is
|
||||
// - ("", false, reason) — template matched but validation failed; reason
|
||||
// contains a user-facing error message (already printed to ScrollList)
|
||||
func (m *AppModel) expandPromptTemplate(text string) (string, bool, string) {
|
||||
if len(m.promptTemplates) == 0 {
|
||||
return text, false
|
||||
return text, false, ""
|
||||
}
|
||||
|
||||
// Only consider inputs that look like slash commands.
|
||||
if !strings.HasPrefix(text, "/") {
|
||||
return text, false
|
||||
return text, false, ""
|
||||
}
|
||||
|
||||
// Split: "/templatename arg1 arg2" → name="/templatename", args="arg1 arg2"
|
||||
@@ -2844,11 +2917,24 @@ func (m *AppModel) expandPromptTemplate(text string) (string, bool) {
|
||||
// Find matching template
|
||||
for _, tpl := range m.promptTemplates {
|
||||
if tpl.Name == name {
|
||||
return tpl.Expand(args), true
|
||||
// Validate that enough positional arguments were provided.
|
||||
required := tpl.RequiredArgs()
|
||||
if required > 0 {
|
||||
provided := len(prompts.ParseCommandArgs(args))
|
||||
if provided < required {
|
||||
reason := fmt.Sprintf(
|
||||
"/%s requires %d argument(s), got %d",
|
||||
name, required, provided,
|
||||
)
|
||||
m.printSystemMessage(reason)
|
||||
return "", false, reason
|
||||
}
|
||||
}
|
||||
return tpl.Expand(args), true, ""
|
||||
}
|
||||
}
|
||||
|
||||
return text, false
|
||||
return text, false, ""
|
||||
}
|
||||
|
||||
// refreshPromptTemplates reloads prompt templates from the provider callback
|
||||
@@ -2873,6 +2959,7 @@ func (m *AppModel) refreshPromptTemplates() {
|
||||
Name: "/" + tpl.Name,
|
||||
Description: tpl.Description,
|
||||
Category: "Prompts",
|
||||
HasArgs: tpl.HasArgPlaceholders(),
|
||||
})
|
||||
}
|
||||
ic.commands = kept
|
||||
@@ -2961,6 +3048,7 @@ func (m *AppModel) printHelpMessage() {
|
||||
"- `Ctrl+C`: Exit at any time\n" +
|
||||
"- `ESC` (x2): Cancel ongoing LLM generation\n" +
|
||||
"- `Ctrl+X s`: Steer — redirect the agent mid-turn (injected between tool calls)\n" +
|
||||
"- `Ctrl+X e`: Open `$EDITOR` to compose/edit your prompt\n" +
|
||||
"- `Enter` (while working): Queue message for after the agent finishes\n\n" +
|
||||
"You can also just type your message to chat with the AI assistant."
|
||||
m.printSystemMessage(help)
|
||||
@@ -4048,6 +4136,13 @@ func cancelTimerCmd() tea.Cmd {
|
||||
// Interactive prompt support
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// externalEditorMsg is sent when the user returns from $EDITOR after
|
||||
// composing a prompt via the Ctrl+X e chord.
|
||||
type externalEditorMsg struct {
|
||||
text string
|
||||
err error
|
||||
}
|
||||
|
||||
// shareResultMsg carries the result of an async gist upload.
|
||||
type shareResultMsg struct {
|
||||
err error
|
||||
|
||||
@@ -5,6 +5,7 @@ package render
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"charm.land/lipgloss/v2"
|
||||
@@ -13,8 +14,14 @@ import (
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// fileTokenPattern matches @file references in user text. Supports:
|
||||
// - @"path with spaces.txt" (quoted)
|
||||
// - @path/to/file.txt (unquoted, no spaces)
|
||||
var fileTokenPattern = regexp.MustCompile(`@"[^"]+"|@[^\s]+`)
|
||||
|
||||
// UserBlock renders a user message with herald Tip styling.
|
||||
// The width parameter controls line wrapping so long messages don't overflow.
|
||||
// Any @file tokens in the content are highlighted with the theme accent color.
|
||||
func UserBlock(content string, width int, ty *herald.Typography, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "(empty message)"
|
||||
@@ -27,10 +34,23 @@ func UserBlock(content string, width int, ty *herald.Typography, theme style.The
|
||||
content = lipgloss.Wrap(content, width-4, "")
|
||||
}
|
||||
|
||||
// Highlight @file tokens with accent color so file references are
|
||||
// visually distinct from surrounding prompt text.
|
||||
content = highlightFileTokens(content, theme)
|
||||
|
||||
rendered := ty.Tip(content)
|
||||
return styleMarginBottom(theme, rendered)
|
||||
}
|
||||
|
||||
// highlightFileTokens wraps @file tokens in the given text with the theme
|
||||
// accent color so they stand out visually in rendered user messages.
|
||||
func highlightFileTokens(text string, theme style.Theme) string {
|
||||
accentStyle := lipgloss.NewStyle().Foreground(theme.Accent).Bold(true)
|
||||
return fileTokenPattern.ReplaceAllStringFunc(text, func(token string) string {
|
||||
return accentStyle.Render(token)
|
||||
})
|
||||
}
|
||||
|
||||
// AssistantBlock renders an assistant message with markdown styling.
|
||||
func AssistantBlock(content string, width int, theme style.Theme) string {
|
||||
if strings.TrimSpace(content) == "" {
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package render
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/indaco/herald"
|
||||
|
||||
"github.com/mark3labs/kit/internal/ui/style"
|
||||
)
|
||||
|
||||
// testTypography creates a herald Typography for tests.
|
||||
func testTypography(theme style.Theme) *herald.Typography {
|
||||
return herald.New(
|
||||
herald.WithPalette(herald.ColorPalette{
|
||||
Primary: theme.Primary,
|
||||
Secondary: theme.Secondary,
|
||||
Tertiary: theme.Info,
|
||||
Accent: theme.Accent,
|
||||
Highlight: theme.Highlight,
|
||||
Muted: theme.Muted,
|
||||
Text: theme.Text,
|
||||
Surface: theme.Background,
|
||||
Base: theme.CodeBg,
|
||||
}),
|
||||
herald.WithAlertLabel(herald.AlertTip, "You"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestHighlightFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantHas []string // substrings that must be present in the output
|
||||
wantNone []string // substrings that must NOT be present as plain text
|
||||
}{
|
||||
{
|
||||
name: "no tokens",
|
||||
input: "hello world",
|
||||
wantHas: []string{"hello world"},
|
||||
},
|
||||
{
|
||||
name: "single unquoted token",
|
||||
input: "refactor @main.go please",
|
||||
wantHas: []string{"@main.go", "refactor", "please"},
|
||||
},
|
||||
{
|
||||
name: "quoted token with spaces",
|
||||
input: `check @"path with spaces/file.txt" out`,
|
||||
wantHas: []string{`@"path with spaces/file.txt"`, "check", "out"},
|
||||
},
|
||||
{
|
||||
name: "multiple tokens",
|
||||
input: "@main.go @utils.go refactor these",
|
||||
wantHas: []string{"@main.go", "@utils.go", "refactor these"},
|
||||
},
|
||||
{
|
||||
name: "path with directory",
|
||||
input: "look at @internal/ui/render/blocks.go",
|
||||
wantHas: []string{"@internal/ui/render/blocks.go", "look at"},
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantHas: []string{""},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := highlightFileTokens(tt.input, theme)
|
||||
|
||||
for _, want := range tt.wantHas {
|
||||
if !strings.Contains(result, want) {
|
||||
t.Errorf("highlightFileTokens(%q) = %q, want substring %q", tt.input, result, want)
|
||||
}
|
||||
}
|
||||
|
||||
// If there were @tokens, the result should contain ANSI escape
|
||||
// sequences (from lipgloss styling).
|
||||
if fileTokenPattern.MatchString(tt.input) && !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("highlightFileTokens(%q) should contain ANSI escapes for @tokens but got %q", tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserBlockHighlightsFileTokens(t *testing.T) {
|
||||
theme := style.DefaultTheme()
|
||||
ty := testTypography(theme)
|
||||
|
||||
// A user message with @file tokens should contain ANSI escapes around the token.
|
||||
content := "refactor @main.go and @utils.go"
|
||||
result := UserBlock(content, 80, ty, theme)
|
||||
|
||||
// The rendered output should contain both file references.
|
||||
if !strings.Contains(result, "@main.go") {
|
||||
t.Errorf("UserBlock output should contain @main.go, got:\n%s", result)
|
||||
}
|
||||
if !strings.Contains(result, "@utils.go") {
|
||||
t.Errorf("UserBlock output should contain @utils.go, got:\n%s", result)
|
||||
}
|
||||
|
||||
// Verify ANSI codes are present (the tokens are styled).
|
||||
if !strings.Contains(result, "\x1b[") {
|
||||
t.Errorf("UserBlock output should contain ANSI escape codes for styled @file tokens")
|
||||
}
|
||||
}
|
||||
@@ -134,23 +134,28 @@ func (ut *UsageTracker) EstimateAndUpdateUsage(inputText, outputText string) {
|
||||
}
|
||||
|
||||
// SetContextTokens records the approximate current context window utilization.
|
||||
// This should be set from FinalUsage.InputTokens, which already includes the
|
||||
// full conversation history (system prompt + all previous messages). Do NOT
|
||||
// add OutputTokens as that would double-count (output becomes input next turn).
|
||||
// Use FinalResponse.Usage rather than aggregate TotalUsage, because TotalUsage
|
||||
// sums across all tool-calling steps and overstates the actual window fill level.
|
||||
//
|
||||
// The value should include ALL token categories from the last API call:
|
||||
//
|
||||
// InputTokens + CacheReadTokens + CacheCreationTokens + OutputTokens
|
||||
//
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens holds the bulk of the context. All four must be summed
|
||||
// to get the true context window fill level.
|
||||
//
|
||||
// OutputTokens is included because the assistant's output becomes part of
|
||||
// the context on the next turn.
|
||||
//
|
||||
// Use FinalResponse.Usage (last step only) rather than aggregate TotalUsage,
|
||||
// because TotalUsage sums across all tool-calling steps and overstates the
|
||||
// actual window fill level.
|
||||
//
|
||||
// The value is set unconditionally (not max-only) so that context shrinks
|
||||
// correctly after compaction.
|
||||
func (ut *UsageTracker) SetContextTokens(tokens int) {
|
||||
ut.mu.Lock()
|
||||
defer ut.mu.Unlock()
|
||||
// Track the maximum context seen so far. In multi-step tool calls,
|
||||
// FinalUsage.InputTokens may reflect only the last step's input, which
|
||||
// can be smaller than previous steps. We want to show the largest context
|
||||
// the model has processed in this session.
|
||||
if tokens > ut.contextTokens {
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
// If tokens < current, we keep the larger value (no-op)
|
||||
// This prevents the display from dropping during multi-step tool calls.
|
||||
ut.contextTokens = tokens
|
||||
}
|
||||
|
||||
// RenderUsageInfo generates a formatted string displaying current usage statistics
|
||||
|
||||
@@ -31,6 +31,11 @@ func (m *Kit) EstimateContextTokens() int {
|
||||
// limit and should be compacted.
|
||||
// Formula: contextTokens > contextWindow − reserveTokens.
|
||||
// Returns false if the model's context limit is unknown.
|
||||
//
|
||||
// When API-reported token counts are available (after at least one turn),
|
||||
// the real count is used instead of the text-based heuristic. This is
|
||||
// significantly more accurate because it includes system prompts, tool
|
||||
// definitions, and other overhead that the heuristic cannot account for.
|
||||
func (m *Kit) ShouldCompact() bool {
|
||||
info := m.GetModelInfo()
|
||||
if info == nil || info.Limit.Context <= 0 {
|
||||
@@ -42,6 +47,16 @@ func (m *Kit) ShouldCompact() bool {
|
||||
reserveTokens = m.compactionOpts.ReserveTokens
|
||||
}
|
||||
|
||||
// Prefer the real API-reported token count when available.
|
||||
m.lastInputTokensMu.RLock()
|
||||
realTokens := m.lastInputTokens
|
||||
m.lastInputTokensMu.RUnlock()
|
||||
|
||||
if realTokens > 0 {
|
||||
return realTokens > info.Limit.Context-reserveTokens
|
||||
}
|
||||
|
||||
// Fall back to text-based heuristic before first turn completes.
|
||||
messages := m.session.GetMessages()
|
||||
return compaction.ShouldCompact(convertKitMessagesToFantasy(messages), info.Limit.Context, reserveTokens)
|
||||
}
|
||||
@@ -245,6 +260,14 @@ func (m *Kit) persistAndEmitCompaction(
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to persist compaction entry: %w", err)
|
||||
}
|
||||
|
||||
// Reset the API-reported token count so GetContextStats() and
|
||||
// ShouldCompact() don't use stale pre-compaction values. The next
|
||||
// API call will set the accurate post-compaction count.
|
||||
m.lastInputTokensMu.Lock()
|
||||
m.lastInputTokens = 0
|
||||
m.lastInputTokensMu.Unlock()
|
||||
|
||||
m.events.emit(CompactionEvent{
|
||||
Summary: summary,
|
||||
OriginalTokens: originalTokens,
|
||||
|
||||
+264
-55
@@ -51,6 +51,12 @@ type Kit struct {
|
||||
authHandler MCPAuthHandler // OAuth handler for remote MCP servers (may need Close)
|
||||
opts *Options // stored for reload operations (skills, etc.)
|
||||
|
||||
// hasCustomSystemPrompt is true when the user explicitly configured a
|
||||
// system prompt (via --system-prompt flag, config file, or SDK option).
|
||||
// When false, per-model system prompts from modelSettings/customModels
|
||||
// can replace the default prompt on model switch.
|
||||
hasCustomSystemPrompt bool
|
||||
|
||||
// Hook registries — interception layer (see hooks.go).
|
||||
beforeToolCall *hookRegistry[BeforeToolCallHook, BeforeToolCallResult]
|
||||
afterToolResult *hookRegistry[AfterToolResultHook, AfterToolResultResult]
|
||||
@@ -140,6 +146,79 @@ func (m *Kit) MCPToolsReady() bool {
|
||||
return m.agent.MCPToolsReady()
|
||||
}
|
||||
|
||||
// MCPServerStatus describes the runtime state of a loaded MCP server.
|
||||
type MCPServerStatus struct {
|
||||
// Name is the configured server name.
|
||||
Name string
|
||||
// ToolCount is the number of tools loaded from this server.
|
||||
ToolCount int
|
||||
}
|
||||
|
||||
// AddMCPServer connects to a new MCP server at runtime and makes its tools
|
||||
// available to the agent immediately. The server's tools are prefixed with the
|
||||
// server name (e.g. "myserver__tool_name") to avoid naming conflicts, matching
|
||||
// the behaviour of servers loaded at initialization.
|
||||
//
|
||||
// Returns the number of tools loaded from the server.
|
||||
//
|
||||
// AddMCPServer is safe to call while the agent is idle. If a turn is in
|
||||
// progress ([Kit.IsGenerating] returns true), the new tools will be visible
|
||||
// starting from the next LLM step.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// n, err := k.AddMCPServer(ctx, "github", kit.MCPServerConfig{
|
||||
// Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
|
||||
// Environment: map[string]string{"GITHUB_TOKEN": os.Getenv("GITHUB_TOKEN")},
|
||||
// })
|
||||
func (m *Kit) AddMCPServer(ctx context.Context, name string, cfg MCPServerConfig) (int, error) {
|
||||
return m.agent.AddMCPServer(ctx, name, cfg)
|
||||
}
|
||||
|
||||
// RemoveMCPServer disconnects an MCP server and removes all its tools from
|
||||
// the agent. After this call the agent will no longer see or be able to call
|
||||
// tools from the named server.
|
||||
//
|
||||
// RemoveMCPServer is safe to call while the agent is idle. If a turn is in
|
||||
// progress, the tools are removed at the next LLM step. Any in-flight tool
|
||||
// calls to the removed server will fail gracefully.
|
||||
//
|
||||
// Returns an error if the named server is not currently loaded.
|
||||
func (m *Kit) RemoveMCPServer(name string) error {
|
||||
return m.agent.RemoveMCPServer(name)
|
||||
}
|
||||
|
||||
// ListMCPServers returns the status of all currently loaded MCP servers.
|
||||
// The returned slice is a snapshot; it is safe to read concurrently.
|
||||
func (m *Kit) ListMCPServers() []MCPServerStatus {
|
||||
names := m.agent.GetLoadedServerNames()
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a tool count per server by scanning tool names for the prefix.
|
||||
toolNames := m.GetToolNames()
|
||||
countByServer := make(map[string]int, len(names))
|
||||
for _, tn := range toolNames {
|
||||
for _, sn := range names {
|
||||
prefix := sn + "__"
|
||||
if len(tn) > len(prefix) && tn[:len(prefix)] == prefix {
|
||||
countByServer[sn]++
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := make([]MCPServerStatus, 0, len(names))
|
||||
for _, n := range names {
|
||||
result = append(result, MCPServerStatus{
|
||||
Name: n,
|
||||
ToolCount: countByServer[n],
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetExtensionToolCount returns the number of tools registered by extensions.
|
||||
func (m *Kit) GetExtensionToolCount() int {
|
||||
return m.agent.GetExtensionToolCount()
|
||||
@@ -221,9 +300,12 @@ func iterBranchMessages[T any](tm *session.TreeManager, fn func(*session.Message
|
||||
return results
|
||||
}
|
||||
|
||||
// SetModel changes the active model at runtime. The existing tools, system
|
||||
// prompt, and session are preserved. The model string should be in
|
||||
// "provider/model" format (e.g. "anthropic/claude-sonnet-4-5-20250929").
|
||||
// SetModel changes the active model at runtime. The existing tools and
|
||||
// session are preserved. When the new model has a per-model system prompt
|
||||
// (from modelSettings or customModels params), it is composed with the
|
||||
// current AGENTS.md context and skills before being applied.
|
||||
// The model string should be in "provider/model" format
|
||||
// (e.g. "anthropic/claude-sonnet-4-5-20250929").
|
||||
// Returns an error if the model string is invalid or the provider cannot
|
||||
// be created.
|
||||
func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
@@ -239,7 +321,7 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
|
||||
// With message-level caching, thinking and caching can work together.
|
||||
// No need to disable caching when thinking is enabled.
|
||||
config := &models.ProviderConfig{
|
||||
cfg := &models.ProviderConfig{
|
||||
ModelString: modelString,
|
||||
SystemPrompt: systemPrompt,
|
||||
ProviderAPIKey: viper.GetString("provider-api-key"),
|
||||
@@ -249,18 +331,50 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
ThinkingLevel: thinkingLevel,
|
||||
DisableCaching: false, // Caching enabled by default, works with thinking
|
||||
}
|
||||
temperature := float32(viper.GetFloat64("temperature"))
|
||||
config.Temperature = &temperature
|
||||
topP := float32(viper.GetFloat64("top-p"))
|
||||
config.TopP = &topP
|
||||
topK := int32(viper.GetInt("top-k"))
|
||||
config.TopK = &topK
|
||||
frequencyPenalty := float32(viper.GetFloat64("frequency-penalty"))
|
||||
config.FrequencyPenalty = &frequencyPenalty
|
||||
presencePenalty := float32(viper.GetFloat64("presence-penalty"))
|
||||
config.PresencePenalty = &presencePenalty
|
||||
|
||||
if err := m.agent.SetModel(ctx, config); err != nil {
|
||||
// Only set generation parameter pointers when the user has explicitly
|
||||
// provided a value. This leaves nil pointers for unset params, allowing
|
||||
// per-model defaults (modelSettings / customModels params) to apply.
|
||||
if viper.IsSet("temperature") {
|
||||
v := float32(viper.GetFloat64("temperature"))
|
||||
cfg.Temperature = &v
|
||||
}
|
||||
if viper.IsSet("top-p") {
|
||||
v := float32(viper.GetFloat64("top-p"))
|
||||
cfg.TopP = &v
|
||||
}
|
||||
if viper.IsSet("top-k") {
|
||||
v := int32(viper.GetInt("top-k"))
|
||||
cfg.TopK = &v
|
||||
}
|
||||
if viper.IsSet("frequency-penalty") {
|
||||
v := float32(viper.GetFloat64("frequency-penalty"))
|
||||
cfg.FrequencyPenalty = &v
|
||||
}
|
||||
if viper.IsSet("presence-penalty") {
|
||||
v := float32(viper.GetFloat64("presence-penalty"))
|
||||
cfg.PresencePenalty = &v
|
||||
}
|
||||
|
||||
// When the user hasn't set a custom global system prompt, check for a
|
||||
// per-model system prompt. Pre-apply model settings to discover it,
|
||||
// then compose with AGENTS.md context and skills if found.
|
||||
if !m.hasCustomSystemPrompt {
|
||||
// Temporarily clear the system prompt so ApplyModelSettings can
|
||||
// detect that no explicit prompt is set and apply the per-model one.
|
||||
cfg.SystemPrompt = ""
|
||||
models.ApplyModelSettings(cfg, models.LookupModelForSettings(modelString))
|
||||
|
||||
if cfg.SystemPrompt != "" {
|
||||
// Per-model system prompt found — compose with runtime context.
|
||||
cfg.SystemPrompt = m.composeSystemPrompt(cfg.SystemPrompt)
|
||||
} else {
|
||||
// No per-model prompt — restore the global composed prompt.
|
||||
cfg.SystemPrompt = systemPrompt
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.agent.SetModel(ctx, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -276,6 +390,32 @@ func (m *Kit) SetModel(ctx context.Context, modelString string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// composeSystemPrompt takes a base system prompt and composes it with the
|
||||
// current runtime context: AGENTS.md content, skills metadata, and date/cwd.
|
||||
// This mirrors the composition done during Kit.New() initialization.
|
||||
func (m *Kit) composeSystemPrompt(basePrompt string) string {
|
||||
cwd, _ := os.Getwd()
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
for _, cf := range m.contextFiles {
|
||||
pb.WithSection("", fmt.Sprintf("Instructions from: %s\n\n%s", cf.Path, cf.Content))
|
||||
}
|
||||
|
||||
// Inject skills metadata.
|
||||
if len(m.skills) > 0 {
|
||||
pb.WithSkills(m.skills)
|
||||
}
|
||||
|
||||
// Append current date/time and working directory.
|
||||
pb.WithSection("", fmt.Sprintf(
|
||||
"Current date and time: %s\nCurrent working directory: %s",
|
||||
time.Now().Format("Monday, January 2, 2006, 3:04:05 PM MST"), cwd,
|
||||
))
|
||||
|
||||
return pb.Build()
|
||||
}
|
||||
|
||||
// GetAvailableModels returns a list of known models from the registry. Each
|
||||
// entry includes provider, model ID, context limit, and whether the model
|
||||
// supports reasoning. This is an advisory list — models not in the registry
|
||||
@@ -477,6 +617,14 @@ type Options struct {
|
||||
// Skills
|
||||
Skills []string // Explicit skill files/dirs to load (empty = auto-discover)
|
||||
SkillsDir string // Override default project-local skills directory
|
||||
NoSkills bool // Disable skill loading entirely (auto-discovery and explicit)
|
||||
|
||||
// NoExtensions disables Yaegi extension loading entirely.
|
||||
NoExtensions bool
|
||||
|
||||
// NoContextFiles disables automatic loading of project context files
|
||||
// (e.g. AGENTS.md) from the working directory.
|
||||
NoContextFiles bool
|
||||
|
||||
// Compaction
|
||||
AutoCompact bool // Auto-compact when near context limit
|
||||
@@ -497,6 +645,17 @@ type Options struct {
|
||||
// display a URL in a custom UI, redirect to a web app, etc.).
|
||||
MCPAuthHandler MCPAuthHandler
|
||||
|
||||
// MCPTokenStoreFactory, if non-nil, is called to create a token store for
|
||||
// each remote MCP server that requires OAuth. The factory receives the
|
||||
// server's URL and returns a [MCPTokenStore] implementation.
|
||||
//
|
||||
// When nil (default), tokens are persisted to a JSON file at
|
||||
// $XDG_CONFIG_HOME/.kit/mcp_tokens.json (or ~/.config/.kit/mcp_tokens.json).
|
||||
//
|
||||
// Use this to store tokens in a database, encrypt them, keep them
|
||||
// in-memory, or write them to a custom file path.
|
||||
MCPTokenStoreFactory MCPTokenStoreFactory
|
||||
|
||||
// OnMCPServerLoaded, if non-nil, is called when each MCP server finishes
|
||||
// loading during Kit initialization. The callback receives the server name,
|
||||
// tool count, and any error. Called from a background goroutine; safe to
|
||||
@@ -582,16 +741,17 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
// provider creation, session init) then runs outside the lock, allowing
|
||||
// parallel subagent spawns to proceed concurrently.
|
||||
var (
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
cwd string
|
||||
contextFiles []*ContextFile
|
||||
loadedSkills []*Skill
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
providerConfig *models.ProviderConfig
|
||||
modelString string
|
||||
cwd string
|
||||
contextFiles []*ContextFile
|
||||
loadedSkills []*Skill
|
||||
mcpConfig *config.Config
|
||||
debug bool
|
||||
noExtensions bool
|
||||
maxSteps int
|
||||
streaming bool
|
||||
hasCustomSystemPrompt bool
|
||||
)
|
||||
|
||||
if err := func() error {
|
||||
@@ -636,19 +796,56 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
// Load context files (AGENTS.md) from the project root.
|
||||
contextFiles = loadContextFiles(cwd)
|
||||
if !opts.NoContextFiles {
|
||||
contextFiles = loadContextFiles(cwd)
|
||||
}
|
||||
|
||||
// Load skills — either from explicit paths or via auto-discovery.
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
if !opts.NoSkills {
|
||||
var err error
|
||||
loadedSkills, err = loadSkills(opts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load skills: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Always compose the system prompt with runtime context: base prompt +
|
||||
// AGENTS.md context + skills metadata + date/cwd.
|
||||
//
|
||||
// If the configured model has a per-model system prompt (via
|
||||
// modelSettings or customModels params) and the user hasn't
|
||||
// explicitly set system-prompt, use the per-model prompt as the
|
||||
// base instead of the global default.
|
||||
{
|
||||
basePrompt := viper.GetString("system-prompt")
|
||||
|
||||
// Track whether the user explicitly configured a custom system
|
||||
// prompt. When they haven't (basePrompt is the built-in default
|
||||
// or empty), per-model system prompts can replace it on switch.
|
||||
userSetSystemPrompt := basePrompt != "" && basePrompt != defaultSystemPrompt
|
||||
hasCustomSystemPrompt = userSetSystemPrompt
|
||||
|
||||
// Check for per-model system prompt override when no explicit
|
||||
// global system-prompt was configured by the user.
|
||||
if !userSetSystemPrompt {
|
||||
modelStr := viper.GetString("model")
|
||||
if modelStr != "" {
|
||||
if mi := models.LookupModelForSettings(modelStr); mi != nil {
|
||||
var perModelParams *models.GenerationParams
|
||||
// modelSettings takes priority over custom model params.
|
||||
if ms := models.LoadModelSettingsFromConfig(); ms != nil {
|
||||
perModelParams = ms[modelStr]
|
||||
}
|
||||
if perModelParams == nil && mi.Params != nil {
|
||||
perModelParams = mi.Params
|
||||
}
|
||||
if perModelParams != nil && perModelParams.SystemPrompt != "" {
|
||||
basePrompt = models.LoadSystemPromptValue(perModelParams.SystemPrompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pb := skills.NewPromptBuilder(basePrompt)
|
||||
|
||||
// Inject AGENTS.md content as project context.
|
||||
@@ -679,7 +876,7 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
modelString = viper.GetString("model")
|
||||
debug = viper.GetBool("debug")
|
||||
noExtensions = viper.GetBool("no-extensions")
|
||||
noExtensions = opts.NoExtensions || viper.GetBool("no-extensions")
|
||||
maxSteps = viper.GetInt("max-steps")
|
||||
streaming = viper.GetBool("stream")
|
||||
|
||||
@@ -745,6 +942,13 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Set up custom token store factory for MCP OAuth tokens.
|
||||
// The SDK MCPTokenStoreFactory is structurally identical to
|
||||
// tools.TokenStoreFactory, so it can be assigned directly.
|
||||
if opts.MCPTokenStoreFactory != nil {
|
||||
setupOpts.TokenStoreFactory = tools.TokenStoreFactory(opts.MCPTokenStoreFactory)
|
||||
}
|
||||
|
||||
if opts.CLI != nil {
|
||||
setupOpts.ShowSpinner = opts.CLI.ShowSpinner
|
||||
setupOpts.SpinnerFunc = opts.CLI.SpinnerFunc
|
||||
@@ -774,24 +978,25 @@ func New(ctx context.Context, opts *Options) (*Kit, error) {
|
||||
}
|
||||
|
||||
k := &Kit{
|
||||
agent: agentResult.Agent,
|
||||
session: sessionManager,
|
||||
modelString: modelString,
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
contextFiles: contextFiles,
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
opts: opts,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
afterTurn: afterTurn,
|
||||
contextPrepare: contextPrepare,
|
||||
beforeCompact: beforeCompact,
|
||||
agent: agentResult.Agent,
|
||||
session: sessionManager,
|
||||
modelString: modelString,
|
||||
events: newEventBus(),
|
||||
autoCompact: opts.AutoCompact,
|
||||
compactionOpts: opts.CompactionOptions,
|
||||
contextFiles: contextFiles,
|
||||
skills: loadedSkills,
|
||||
extRunner: agentResult.ExtRunner,
|
||||
bufferedLogger: agentResult.BufferedLogger,
|
||||
authHandler: setupOpts.AuthHandler,
|
||||
opts: opts,
|
||||
hasCustomSystemPrompt: hasCustomSystemPrompt,
|
||||
beforeToolCall: beforeToolCall,
|
||||
afterToolResult: afterToolResult,
|
||||
beforeTurn: beforeTurn,
|
||||
afterTurn: afterTurn,
|
||||
contextPrepare: contextPrepare,
|
||||
beforeCompact: beforeCompact,
|
||||
}
|
||||
|
||||
// Bridge extension events to SDK hooks.
|
||||
@@ -978,9 +1183,11 @@ type TurnResult struct {
|
||||
// report usage.
|
||||
TotalUsage *LLMUsage
|
||||
|
||||
// FinalUsage is the token usage from the last API call only. Use this
|
||||
// for context window fill estimation (InputTokens + OutputTokens ≈
|
||||
// current context size). Nil if unavailable.
|
||||
// FinalUsage is the token usage from the last API call only. For context
|
||||
// window fill, sum all categories: InputTokens + CacheReadTokens +
|
||||
// CacheCreationTokens + OutputTokens. With prompt caching, InputTokens
|
||||
// alone understates the context (cached tokens are reported separately).
|
||||
// Nil if unavailable.
|
||||
FinalUsage *LLMUsage
|
||||
|
||||
// Messages is the full updated conversation after the turn, including
|
||||
@@ -1459,12 +1666,14 @@ func (m *Kit) runTurn(ctx context.Context, promptLabel string, prompt string, pr
|
||||
}
|
||||
|
||||
// Store the API-reported token count so GetContextStats() matches the
|
||||
// built-in status bar (which uses input + output tokens). The
|
||||
// text-based heuristic misses system prompts, tool definitions, etc.
|
||||
// built-in status bar. The context window is filled by all token
|
||||
// categories: non-cached input, cache reads, cache writes, and output.
|
||||
// With Anthropic prompt caching, InputTokens can be near-zero while
|
||||
// CacheReadTokens/CacheCreationTokens hold the bulk of the context.
|
||||
if result.FinalResponse != nil {
|
||||
u := result.FinalResponse.Usage
|
||||
m.lastInputTokensMu.Lock()
|
||||
m.lastInputTokens = int(u.InputTokens) + int(u.OutputTokens)
|
||||
m.lastInputTokens = int(u.InputTokens) + int(u.CacheReadTokens) + int(u.CacheCreationTokens) + int(u.OutputTokens)
|
||||
m.lastInputTokensMu.Unlock()
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package kit_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
kit "github.com/mark3labs/kit/pkg/kit"
|
||||
)
|
||||
|
||||
// TestMCPServerStatus_TypeSurface verifies the MCPServerStatus type is
|
||||
// accessible and has the expected fields.
|
||||
func TestMCPServerStatus_TypeSurface(t *testing.T) {
|
||||
s := kit.MCPServerStatus{
|
||||
Name: "test-server",
|
||||
ToolCount: 5,
|
||||
}
|
||||
if s.Name != "test-server" {
|
||||
t.Errorf("Expected Name 'test-server', got %q", s.Name)
|
||||
}
|
||||
if s.ToolCount != 5 {
|
||||
t.Errorf("Expected ToolCount 5, got %d", s.ToolCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMCPServerConfig_ForDynamicAdd verifies that MCPServerConfig can be
|
||||
// constructed with the expected fields for dynamic server management.
|
||||
func TestMCPServerConfig_ForDynamicAdd(t *testing.T) {
|
||||
// Stdio server config.
|
||||
stdio := kit.MCPServerConfig{
|
||||
Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
|
||||
Environment: map[string]string{"GITHUB_TOKEN": "test-token"},
|
||||
}
|
||||
if len(stdio.Command) != 3 {
|
||||
t.Errorf("Expected 3 command parts, got %d", len(stdio.Command))
|
||||
}
|
||||
if stdio.Environment["GITHUB_TOKEN"] != "test-token" {
|
||||
t.Error("Expected GITHUB_TOKEN in environment")
|
||||
}
|
||||
|
||||
// Remote server config.
|
||||
remote := kit.MCPServerConfig{
|
||||
URL: "https://mcp.example.com/sse",
|
||||
Headers: []string{"Authorization: Bearer test"},
|
||||
}
|
||||
if remote.URL != "https://mcp.example.com/sse" {
|
||||
t.Errorf("Unexpected URL: %s", remote.URL)
|
||||
}
|
||||
|
||||
// Config with tool filtering.
|
||||
filtered := kit.MCPServerConfig{
|
||||
Command: []string{"some-server"},
|
||||
AllowedTools: []string{"read", "write"},
|
||||
}
|
||||
if len(filtered.AllowedTools) != 2 {
|
||||
t.Errorf("Expected 2 allowed tools, got %d", len(filtered.AllowedTools))
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mark3labs/kit/internal/message"
|
||||
"github.com/mark3labs/kit/internal/models"
|
||||
"github.com/mark3labs/kit/internal/session"
|
||||
"github.com/mark3labs/mcp-go/client/transport"
|
||||
)
|
||||
|
||||
// ==== Message Types (internal/message/content.go) ====
|
||||
@@ -204,6 +205,29 @@ type CompactionResult = compaction.CompactionResult
|
||||
// CompactionOptions configures compaction behaviour.
|
||||
type CompactionOptions = compaction.CompactionOptions
|
||||
|
||||
// ==== MCP OAuth Types ====
|
||||
|
||||
// MCPTokenStore persists OAuth tokens for a single MCP server. Implementations
|
||||
// must be safe for concurrent use.
|
||||
//
|
||||
// This is a type alias for the mcp-go transport.TokenStore interface. SDK
|
||||
// consumers can implement this interface to provide custom storage backends
|
||||
// (database, encrypted file, in-memory, etc.).
|
||||
type MCPTokenStore = transport.TokenStore
|
||||
|
||||
// MCPToken represents an OAuth token for an MCP server, containing access
|
||||
// and refresh tokens along with expiration metadata.
|
||||
type MCPToken = transport.Token
|
||||
|
||||
// MCPTokenStoreFactory creates an [MCPTokenStore] for a given MCP server URL.
|
||||
// It is called once per remote MCP server during connection setup.
|
||||
type MCPTokenStoreFactory func(serverURL string) (MCPTokenStore, error)
|
||||
|
||||
// ErrMCPNoToken is the sentinel error that [MCPTokenStore] implementations
|
||||
// should return from GetToken when no token is stored for the server.
|
||||
// Callers can check for this with errors.Is.
|
||||
var ErrMCPNoToken = transport.ErrNoToken
|
||||
|
||||
// ==== Constructor & Helper Functions ====
|
||||
|
||||
// ParseModelString parses a model string in "provider/model" format.
|
||||
|
||||
@@ -17,7 +17,7 @@ import time
|
||||
import os
|
||||
|
||||
KIT_BIN = os.path.join(os.path.dirname(__file__), "..", "output", "kit")
|
||||
MODEL = "opencode/kimi-k2.5"
|
||||
MODEL = os.environ.get("MODEL", "opencode/kimi-k2.5")
|
||||
CWD = os.path.expanduser("~")
|
||||
TIMEOUT = 60 # seconds to wait for the prompt to complete
|
||||
|
||||
|
||||
+79
-3
@@ -98,10 +98,20 @@ host, err := kit.New(ctx, &kit.Options{
|
||||
// Skills
|
||||
Skills: []string{"/path/to/skill.md"}, // explicit skill files (empty = auto-discover)
|
||||
SkillsDir: "/path/to/skills", // override project-local skills dir
|
||||
NoSkills: true, // disable skill loading entirely
|
||||
|
||||
// Feature toggles
|
||||
NoExtensions: true, // disable Yaegi extension loading entirely
|
||||
NoContextFiles: true, // disable automatic AGENTS.md loading
|
||||
|
||||
// Compaction
|
||||
AutoCompact: true, // auto-compact near context limit
|
||||
CompactionOptions: &kit.CompactionOptions{...}, // nil = defaults
|
||||
|
||||
// MCP OAuth
|
||||
MCPTokenStoreFactory: func(serverURL string) (kit.MCPTokenStore, error) {
|
||||
return myCustomStore(serverURL), nil // custom OAuth token storage
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
@@ -125,8 +135,12 @@ result, err := host.PromptResult(ctx, "Analyze this file")
|
||||
// result.StopReason — "stop", "length", "tool-calls", "error", etc.
|
||||
// result.SessionID — session UUID
|
||||
// result.TotalUsage — aggregate tokens across all steps (*kit.LLMUsage)
|
||||
// LLMUsage{InputTokens, OutputTokens, TotalTokens, ...}
|
||||
// LLMUsage{InputTokens, OutputTokens, TotalTokens,
|
||||
// ReasoningTokens, CacheCreationTokens, CacheReadTokens}
|
||||
// result.FinalUsage — tokens from last API call only (*kit.LLMUsage)
|
||||
// For context window fill, sum: InputTokens + CacheReadTokens +
|
||||
// CacheCreationTokens + OutputTokens (with prompt caching,
|
||||
// InputTokens alone understates the context)
|
||||
// result.Messages — full updated conversation ([]kit.LLMMessage)
|
||||
// LLMMessage{Role kit.LLMMessageRole, Content string}
|
||||
```
|
||||
@@ -466,6 +480,7 @@ names := host.GetToolNames() // []string of all tool names
|
||||
tools := host.GetTools() // []kit.Tool (full tool objects)
|
||||
mcpCount := host.GetMCPToolCount() // tools from MCP servers
|
||||
extCount := host.GetExtensionToolCount() // tools from extensions
|
||||
ready := host.MCPToolsReady() // true when async MCP tool loading is complete
|
||||
```
|
||||
|
||||
---
|
||||
@@ -618,6 +633,56 @@ Always `"provider/model"`: `"anthropic/claude-sonnet-4-5-20250929"`, `"openai/gp
|
||||
provider, modelID, err := kit.ParseModelString("anthropic/claude-sonnet-4-5-20250929")
|
||||
```
|
||||
|
||||
### Per-model system prompts
|
||||
|
||||
Models can have per-model system prompts configured via `modelSettings` or `customModels` in `.kit.yml`. When the user hasn't explicitly set a system prompt (via `--system-prompt`, config, or `Options.SystemPrompt`), the per-model prompt is used as the base and composed with AGENTS.md context and skills.
|
||||
|
||||
On `SetModel()`, if the new model has a per-model system prompt and no custom global prompt was set, the per-model prompt automatically replaces the previous one.
|
||||
|
||||
### Per-model generation parameters
|
||||
|
||||
Models can define default generation parameters (`temperature`, `top_p`, `top_k`, `frequency_penalty`, `presence_penalty`) via `modelSettings` or `customModels` `params` in `.kit.yml`. These defaults apply when the user hasn't explicitly set the parameter. Explicit CLI flags or config values always take priority.
|
||||
|
||||
---
|
||||
|
||||
## Dynamic MCP Server Management
|
||||
|
||||
Add, remove, and inspect MCP servers at runtime without restarting Kit:
|
||||
|
||||
```go
|
||||
// Add a new MCP server — tools become available immediately
|
||||
n, err := host.AddMCPServer(ctx, "github", kit.MCPServerConfig{
|
||||
Command: []string{"npx", "-y", "@modelcontextprotocol/server-github"},
|
||||
Environment: map[string]string{"GITHUB_TOKEN": os.Getenv("GITHUB_TOKEN")},
|
||||
})
|
||||
fmt.Printf("Loaded %d tools from github server\n", n)
|
||||
|
||||
// Remove an MCP server — its tools are no longer available
|
||||
err = host.RemoveMCPServer("github")
|
||||
|
||||
// List all currently loaded MCP servers
|
||||
servers := host.ListMCPServers()
|
||||
for _, s := range servers {
|
||||
fmt.Printf("Server %s: %d tools\n", s.Name, s.ToolCount)
|
||||
}
|
||||
```
|
||||
|
||||
`AddMCPServer` is safe to call while the agent is idle. If a turn is in progress, new tools are visible starting from the next LLM step. Tool names are prefixed with the server name (e.g. `"github__create_issue"`).
|
||||
|
||||
### MCP OAuth Token Storage
|
||||
|
||||
For remote MCP servers that use OAuth, you can provide a custom token store:
|
||||
|
||||
```go
|
||||
host, _ := kit.New(ctx, &kit.Options{
|
||||
MCPTokenStoreFactory: func(serverURL string) (kit.MCPTokenStore, error) {
|
||||
return &MyDatabaseTokenStore{serverURL: serverURL}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
The `MCPTokenStore` interface requires `GetToken`/`SetToken`/`DeleteToken` methods. Return `kit.ErrMCPNoToken` from `GetToken` when no token is stored. When nil (default), tokens are persisted to `$XDG_CONFIG_HOME/.kit/mcp_tokens.json`.
|
||||
|
||||
---
|
||||
|
||||
## Context & Compaction
|
||||
@@ -625,9 +690,12 @@ provider, modelID, err := kit.ParseModelString("anthropic/claude-sonnet-4-5-2025
|
||||
```go
|
||||
tokens := host.EstimateContextTokens() // heuristic token count
|
||||
shouldCompact := host.ShouldCompact() // true if near context limit
|
||||
// ShouldCompact() uses API-reported token counts (including cache tokens)
|
||||
// when available, falling back to text-based heuristic before the first turn.
|
||||
|
||||
stats := host.GetContextStats()
|
||||
// stats.EstimatedTokens — uses API-reported count when available (more accurate)
|
||||
// stats.EstimatedTokens — uses API-reported count when available (more accurate;
|
||||
// includes system prompts, tool definitions, cache tokens)
|
||||
// stats.ContextLimit — model's context window size
|
||||
// stats.UsagePercent — fraction used (0.0–1.0)
|
||||
// stats.MessageCount — number of messages
|
||||
@@ -787,13 +855,21 @@ kit.ProviderConfig, kit.ProviderResult, kit.ModelInfo, kit.ModelCost, kit.ModelL
|
||||
// LLM types — concrete Kit-owned structs (no external library dependency)
|
||||
kit.LLMMessage // {Role LLMMessageRole, Content string}
|
||||
kit.LLMMessageRole // "user" | "assistant" | "system" | "tool"
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ReasoningTokens, ...}
|
||||
kit.LLMUsage // {InputTokens, OutputTokens, TotalTokens, ReasoningTokens,
|
||||
// CacheCreationTokens, CacheReadTokens}
|
||||
kit.LLMResponse // {Content, FinishReason, Usage}
|
||||
kit.LLMFilePart // {Filename, Data []byte, MediaType}
|
||||
|
||||
// Compaction types
|
||||
kit.CompactionResult, kit.CompactionOptions
|
||||
|
||||
// MCP OAuth types
|
||||
kit.MCPTokenStore // interface for custom OAuth token storage
|
||||
kit.MCPToken // OAuth token struct (access, refresh, expiry)
|
||||
kit.MCPTokenStoreFactory // func(serverURL string) (MCPTokenStore, error)
|
||||
kit.ErrMCPNoToken // sentinel error for "no token stored"
|
||||
kit.MCPServerStatus // {Name string, ToolCount int}
|
||||
|
||||
// Conversion helpers
|
||||
msgs := kit.ConvertToLLMMessages(&msg) // SDK Message → []LLMMessage
|
||||
msg := kit.ConvertFromLLMMessage(lMsg) // LLMMessage → SDK Message
|
||||
|
||||
Reference in New Issue
Block a user